fix(nn/init): fixed KaiminigUnitform Init missing case dims = 1
This commit is contained in:
parent
282c54f185
commit
17b9e6c46f
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -19,6 +19,7 @@ target/
|
|||
_build/
|
||||
data/
|
||||
example/testdata/
|
||||
example/test/
|
||||
tmp/
|
||||
bak/
|
||||
gen/.merlin
|
||||
|
@ -32,3 +33,4 @@ libtch/dummy_cuda_dependency.cpp
|
|||
libtch/fake_cuda_dependency.cpp.cpu
|
||||
libtch/lib.go.cpu
|
||||
libtch/lib.go
|
||||
|
||||
|
|
16
nn/init.go
16
nn/init.go
|
@ -154,13 +154,16 @@ func NewKaimingUniformInit() kaimingUniformInit {
|
|||
|
||||
func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
||||
var fanIn int64
|
||||
if len(dims) == 1 {
|
||||
log.Fatalf("KaimingUniformInit method call: dims (%v) should have length > 1", dims)
|
||||
if len(dims) == 0 {
|
||||
log.Fatalf("KaimingUniformInit method call: dims (%v) should have length >= 1", dims)
|
||||
} else if len(dims) == 1 {
|
||||
fanIn = factorial(dims[0])
|
||||
} else {
|
||||
fanIn = product(dims[1:])
|
||||
}
|
||||
|
||||
bound := math.Sqrt(1.0 / float64(fanIn))
|
||||
log.Println(fanIn)
|
||||
kind := gotch.Float
|
||||
retVal = ts.MustZeros(dims, kind, device)
|
||||
retVal.Uniform_(-bound, bound)
|
||||
|
@ -170,7 +173,6 @@ func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVa
|
|||
|
||||
// product calculates product by multiplying elements
|
||||
func product(dims []int64) (retVal int64) {
|
||||
|
||||
for i, v := range dims {
|
||||
if i == 0 {
|
||||
retVal = v
|
||||
|
@ -182,7 +184,7 @@ func product(dims []int64) (retVal int64) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func factorial(n uint64) (result uint64) {
|
||||
func factorial(n int64) (result int64) {
|
||||
if n > 0 {
|
||||
result = n * factorial(n-1)
|
||||
return result
|
||||
|
@ -197,8 +199,10 @@ func (k kaimingUniformInit) Set(tensor ts.Tensor) {
|
|||
}
|
||||
|
||||
var fanIn int64
|
||||
if len(dims) == 1 {
|
||||
log.Fatalf("KaimingUniformInit Set method call: Tensor (%v) should have length > 1", tensor.MustSize())
|
||||
if len(dims) == 0 {
|
||||
log.Fatalf("KaimingUniformInit Set method call: Tensor (%v) should have length >= 1", tensor.MustSize())
|
||||
} else if len(dims) == 1 {
|
||||
fanIn = factorial(dims[0])
|
||||
} else {
|
||||
fanIn = product(dims[1:])
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user