fix(nn/init): fixed KaiminigUnitform Init missing case dims = 1

This commit is contained in:
sugarme 2020-08-05 17:10:17 +10:00
parent 282c54f185
commit 17b9e6c46f
2 changed files with 12 additions and 6 deletions

2
.gitignore vendored
View File

@ -19,6 +19,7 @@ target/
_build/
data/
example/testdata/
example/test/
tmp/
bak/
gen/.merlin
@ -32,3 +33,4 @@ libtch/dummy_cuda_dependency.cpp
libtch/fake_cuda_dependency.cpp.cpu
libtch/lib.go.cpu
libtch/lib.go

View File

@ -154,13 +154,16 @@ func NewKaimingUniformInit() kaimingUniformInit {
func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
var fanIn int64
if len(dims) == 1 {
log.Fatalf("KaimingUniformInit method call: dims (%v) should have length > 1", dims)
if len(dims) == 0 {
log.Fatalf("KaimingUniformInit method call: dims (%v) should have length >= 1", dims)
} else if len(dims) == 1 {
fanIn = factorial(dims[0])
} else {
fanIn = product(dims[1:])
}
bound := math.Sqrt(1.0 / float64(fanIn))
log.Println(fanIn)
kind := gotch.Float
retVal = ts.MustZeros(dims, kind, device)
retVal.Uniform_(-bound, bound)
@ -170,7 +173,6 @@ func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVa
// product calculates product by multiplying elements
func product(dims []int64) (retVal int64) {
for i, v := range dims {
if i == 0 {
retVal = v
@ -182,7 +184,7 @@ func product(dims []int64) (retVal int64) {
return retVal
}
func factorial(n uint64) (result uint64) {
func factorial(n int64) (result int64) {
if n > 0 {
result = n * factorial(n-1)
return result
@ -197,8 +199,10 @@ func (k kaimingUniformInit) Set(tensor ts.Tensor) {
}
var fanIn int64
if len(dims) == 1 {
log.Fatalf("KaimingUniformInit Set method call: Tensor (%v) should have length > 1", tensor.MustSize())
if len(dims) == 0 {
log.Fatalf("KaimingUniformInit Set method call: Tensor (%v) should have length >= 1", tensor.MustSize())
} else if len(dims) == 1 {
fanIn = factorial(dims[0])
} else {
fanIn = product(dims[1:])
}