commit
37f3a59854
41
nn/init.go
41
nn/init.go
|
@ -3,7 +3,6 @@ package nn
|
|||
import (
|
||||
"log"
|
||||
"math"
|
||||
"math/rand"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
|
@ -72,48 +71,24 @@ func NewRandnInit(mean, stdev float64) randnInit {
|
|||
}
|
||||
|
||||
func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
|
||||
var err error
|
||||
rand.Seed(86)
|
||||
|
||||
data := make([]float32, ts.FlattenDim(dims))
|
||||
for i := range data {
|
||||
// NOTE. tensor will have DType = Float (float32)
|
||||
data[i] = float32(rand.NormFloat64()*r.mean + r.stdev)
|
||||
// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
|
||||
if r.mean == 0 {
|
||||
return ts.MustRandn(dims, gotch.Float, device)
|
||||
}
|
||||
|
||||
newTs, err := ts.NewTensorFromData(data, dims)
|
||||
if err != nil {
|
||||
log.Fatalf("randInit - InitTensor method call error: %v\n", err)
|
||||
}
|
||||
|
||||
retVal = newTs.MustTo(device, true)
|
||||
|
||||
return retVal
|
||||
|
||||
initTs := ts.MustRandn(dims, gotch.Float, device)
|
||||
return initTs.MustMul1(ts.FloatScalar(r.stdev), true).MustAdd1(ts.FloatScalar(r.mean), true)
|
||||
}
|
||||
|
||||
func (r randnInit) Set(tensor *ts.Tensor) {
|
||||
var (
|
||||
randnTs *ts.Tensor
|
||||
err error
|
||||
)
|
||||
|
||||
dims, err := tensor.Size()
|
||||
if err != nil {
|
||||
log.Fatalf("randInit - Set method call error: %v\n", err)
|
||||
}
|
||||
|
||||
rand.Seed(86)
|
||||
data := make([]float64, ts.FlattenDim(dims))
|
||||
for i := range data {
|
||||
data[i] = rand.NormFloat64()*r.mean + r.stdev
|
||||
}
|
||||
randnTs, err = ts.NewTensorFromData(data, dims)
|
||||
if err != nil {
|
||||
log.Fatalf("randInit - InitTensor method call error: %v\n", err)
|
||||
}
|
||||
|
||||
tensor.Copy_(randnTs)
|
||||
initTs := r.InitTensor(dims, tensor.MustDevice())
|
||||
tensor.Copy_(initTs)
|
||||
initTs.MustDrop()
|
||||
}
|
||||
|
||||
// uniformInit :
|
||||
|
|
Loading…
Reference in New Issue
Block a user