fix(nn/init): fixed KaiminigUnitform Init missing case dims = 1

This commit is contained in:
sugarme 2020-08-05 17:10:17 +10:00
parent 282c54f185
commit 17b9e6c46f
2 changed files with 12 additions and 6 deletions

2
.gitignore vendored
View File

@ -19,6 +19,7 @@ target/
_build/ _build/
data/ data/
example/testdata/ example/testdata/
example/test/
tmp/ tmp/
bak/ bak/
gen/.merlin gen/.merlin
@ -32,3 +33,4 @@ libtch/dummy_cuda_dependency.cpp
libtch/fake_cuda_dependency.cpp.cpu libtch/fake_cuda_dependency.cpp.cpu
libtch/lib.go.cpu libtch/lib.go.cpu
libtch/lib.go libtch/lib.go

View File

@ -154,13 +154,16 @@ func NewKaimingUniformInit() kaimingUniformInit {
func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) { func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
var fanIn int64 var fanIn int64
if len(dims) == 1 { if len(dims) == 0 {
log.Fatalf("KaimingUniformInit method call: dims (%v) should have length > 1", dims) log.Fatalf("KaimingUniformInit method call: dims (%v) should have length >= 1", dims)
} else if len(dims) == 1 {
fanIn = factorial(dims[0])
} else { } else {
fanIn = product(dims[1:]) fanIn = product(dims[1:])
} }
bound := math.Sqrt(1.0 / float64(fanIn)) bound := math.Sqrt(1.0 / float64(fanIn))
log.Println(fanIn)
kind := gotch.Float kind := gotch.Float
retVal = ts.MustZeros(dims, kind, device) retVal = ts.MustZeros(dims, kind, device)
retVal.Uniform_(-bound, bound) retVal.Uniform_(-bound, bound)
@ -170,7 +173,6 @@ func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVa
// product calculates product by multiplying elements // product calculates product by multiplying elements
func product(dims []int64) (retVal int64) { func product(dims []int64) (retVal int64) {
for i, v := range dims { for i, v := range dims {
if i == 0 { if i == 0 {
retVal = v retVal = v
@ -182,7 +184,7 @@ func product(dims []int64) (retVal int64) {
return retVal return retVal
} }
func factorial(n uint64) (result uint64) { func factorial(n int64) (result int64) {
if n > 0 { if n > 0 {
result = n * factorial(n-1) result = n * factorial(n-1)
return result return result
@ -197,8 +199,10 @@ func (k kaimingUniformInit) Set(tensor ts.Tensor) {
} }
var fanIn int64 var fanIn int64
if len(dims) == 1 { if len(dims) == 0 {
log.Fatalf("KaimingUniformInit Set method call: Tensor (%v) should have length > 1", tensor.MustSize()) log.Fatalf("KaimingUniformInit Set method call: Tensor (%v) should have length >= 1", tensor.MustSize())
} else if len(dims) == 1 {
fanIn = factorial(dims[0])
} else { } else {
fanIn = product(dims[1:]) fanIn = product(dims[1:])
} }