From 17b9e6c46f2642358ffe94e1877d1f2b48af53dd Mon Sep 17 00:00:00 2001 From: sugarme Date: Wed, 5 Aug 2020 17:10:17 +1000 Subject: [PATCH] fix(nn/init): fixed KaiminigUnitform Init missing case dims = 1 --- .gitignore | 2 ++ nn/init.go | 16 ++++++++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index af11752..42ef99b 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ target/ _build/ data/ example/testdata/ +example/test/ tmp/ bak/ gen/.merlin @@ -32,3 +33,4 @@ libtch/dummy_cuda_dependency.cpp libtch/fake_cuda_dependency.cpp.cpu libtch/lib.go.cpu libtch/lib.go + diff --git a/nn/init.go b/nn/init.go index b718b62..9ba49cd 100644 --- a/nn/init.go +++ b/nn/init.go @@ -154,13 +154,16 @@ func NewKaimingUniformInit() kaimingUniformInit { func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) { var fanIn int64 - if len(dims) == 1 { - log.Fatalf("KaimingUniformInit method call: dims (%v) should have length > 1", dims) + if len(dims) == 0 { + log.Fatalf("KaimingUniformInit method call: dims (%v) should have length >= 1", dims) + } else if len(dims) == 1 { + fanIn = factorial(dims[0]) } else { fanIn = product(dims[1:]) } bound := math.Sqrt(1.0 / float64(fanIn)) + log.Println(fanIn) kind := gotch.Float retVal = ts.MustZeros(dims, kind, device) retVal.Uniform_(-bound, bound) @@ -170,7 +173,6 @@ func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVa // product calculates product by multiplying elements func product(dims []int64) (retVal int64) { - for i, v := range dims { if i == 0 { retVal = v @@ -182,7 +184,7 @@ func product(dims []int64) (retVal int64) { return retVal } -func factorial(n uint64) (result uint64) { +func factorial(n int64) (result int64) { if n > 0 { result = n * factorial(n-1) return result @@ -197,8 +199,10 @@ func (k kaimingUniformInit) Set(tensor ts.Tensor) { } var fanIn int64 - if len(dims) == 1 { - log.Fatalf("KaimingUniformInit Set method call: Tensor (%v) should have length > 1", tensor.MustSize()) + if len(dims) == 0 { + log.Fatalf("KaimingUniformInit Set method call: Tensor (%v) should have length >= 1", tensor.MustSize()) + } else if len(dims) == 1 { + fanIn = factorial(dims[0]) } else { fanIn = product(dims[1:]) }