fixed #45 #48 RandInit

This commit is contained in:
sugarme 2021-07-21 00:04:53 +10:00
parent 4fd5c5059d
commit 731513a986

View File

@ -3,7 +3,6 @@ package nn
import (
"log"
"math"
"math/rand"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
@ -72,48 +71,24 @@ func NewRandnInit(mean, stdev float64) randnInit {
}
func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
var err error
rand.Seed(86)
data := make([]float32, ts.FlattenDim(dims))
for i := range data {
// NOTE. tensor will have DType = Float (float32)
data[i] = float32(rand.NormFloat64()*r.mean + r.stdev)
// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
if r.mean == 0 {
return ts.MustRandn(dims, gotch.Float, device)
}
newTs, err := ts.NewTensorFromData(data, dims)
if err != nil {
log.Fatalf("randInit - InitTensor method call error: %v\n", err)
}
retVal = newTs.MustTo(device, true)
return retVal
initTs := ts.MustRandn(dims, gotch.Float, device)
return initTs.MustMul1(ts.FloatScalar(r.stdev), true).MustAdd1(ts.FloatScalar(r.mean), true)
}
func (r randnInit) Set(tensor *ts.Tensor) {
var (
randnTs *ts.Tensor
err error
)
dims, err := tensor.Size()
if err != nil {
log.Fatalf("randInit - Set method call error: %v\n", err)
}
rand.Seed(86)
data := make([]float64, ts.FlattenDim(dims))
for i := range data {
data[i] = rand.NormFloat64()*r.mean + r.stdev
}
randnTs, err = ts.NewTensorFromData(data, dims)
if err != nil {
log.Fatalf("randInit - InitTensor method call error: %v\n", err)
}
tensor.Copy_(randnTs)
initTs := r.InitTensor(dims, tensor.MustDevice())
tensor.Copy_(initTs)
initTs.MustDrop()
}
// uniformInit :