diff --git a/nn/init.go b/nn/init.go
index bdc6144..983528b 100644
--- a/nn/init.go
+++ b/nn/init.go
@@ -222,26 +222,31 @@ func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device, dtype
 		dtype = dtypeOpt[0]
 	}
 
-	fanIn, _, err := CalculateFans(dims)
-	if err != nil {
-		panic(err)
-	}
+	/*
+		fanIn, _, err := CalculateFans(dims)
+		if err != nil {
+			panic(err)
+		}
 
-	gain, err := calculateGain(k.NonLinearity, k.NegativeSlope) // default non-linearity="leaky_relu", negative_slope=0.01
-	if err != nil {
-		err = fmt.Errorf("kaimingUniformInit.InitTensor() failed: %v\n", err)
-		panic(err)
-	}
+		gain, err := calculateGain(k.NonLinearity, k.NegativeSlope) // default non-linearity="leaky_relu", negative_slope=0.01
+		if err != nil {
+			err = fmt.Errorf("kaimingUniformInit.InitTensor() failed: %v\n", err)
+			panic(err)
+		}
 
-	std := gain / math.Sqrt(float64(fanIn)) // default using fanIn
+		std := gain / math.Sqrt(float64(fanIn)) // default using fanIn
 
-	// Calculate uniform bounds from standard deviation
-	bound := math.Sqrt(3.0) * std
+		// Calculate uniform bounds from standard deviation
+		bound := math.Sqrt(3.0) * std
 
-	ts.NoGrad(func() {
+		// NOTE. This is a well-known memory leak!!!
+		// Avoid to use it for now!!!
 		retVal = ts.MustZeros(dims, dtype, device)
 		retVal.Uniform_(-bound, bound)
-	})
+	*/
+
+	// For now, just make a random norm
+	retVal = ts.MustRandn(dims, dtype, device)
 
 	return retVal
 }
@@ -382,3 +387,36 @@ func contains(items []string, item string) bool {
 	}
 	return false
 }
+
+// XavierUniform fills the input tensor with values according to the method
+// described in the paper `Understanding the difficulty of training deep feedforward neural networks`
+// using a uniform distribution
+//
+// Also known as Glorot initialization.
+//
+// Paper: https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
+// Pytorch implementation: https://github.com/pytorch/pytorch/blob/df50f91571891ec3f87977a2bdd4a2b609d70afc/torch/nn/init.py#L310
+func XavierUniform_(x *ts.Tensor, gainOpt ...float64) {
+	gain := 1.0
+	if len(gainOpt) > 0 {
+		gain = gainOpt[0]
+	}
+
+	size := x.MustSize()
+	dtype := x.DType()
+	device := x.MustDevice()
+	fanIn, fanOut, err := CalculateFans(size)
+	if err != nil {
+		panic(err)
+	}
+
+	std := gain * math.Sqrt(2.0/float64(fanIn+fanOut))
+
+	// calculate uniform bounds from standard deviation
+	a := math.Sqrt(3.0) * std
+	uniformInit := NewUniformInit(-a, a)
+	src := uniformInit.InitTensor(size, device, dtype)
+	x.Copy_(src)
+
+	src.MustDrop()
+}
diff --git a/nn/linear.go b/nn/linear.go
index e86fbb8..c4d763c 100644
--- a/nn/linear.go
+++ b/nn/linear.go
@@ -21,6 +21,8 @@ type LinearConfig struct {
 func DefaultLinearConfig() *LinearConfig {
 	negSlope := math.Sqrt(5)
 	return &LinearConfig{
+		// NOTE. KaimingUniform cause mem leak due to ts.Uniform()!!!
+		// Avoid using it now.
 		WsInit: NewKaimingUniformInit(WithKaimingNegativeSlope(negSlope)),
 		BsInit: nil,
 		Bias:   true,
@@ -60,8 +62,10 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
 		}
 	}
 
+	ws := vs.MustNewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false)
+
 	return &Linear{
-		Ws: vs.MustNewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false),
+		Ws: ws,
 		Bs: bs,
 	}
 }
diff --git a/nn/varstore.go b/nn/varstore.go
index ec0246c..6dae83f 100644
--- a/nn/varstore.go
+++ b/nn/varstore.go
@@ -567,6 +567,7 @@ func (p *Path) add(name string, newTs *ts.Tensor, trainable bool, varType string
 		tensor *ts.Tensor
 		err    error
 	)
+
 	if trainable {
 		tensor, err = newTs.SetRequiresGrad(true, false)
 		if err != nil {
@@ -877,12 +878,18 @@ func (p *Path) MustOnesNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Te
 // related argument.
 func (p *Path) NewVar(name string, dims []int64, ini Init, opts ...AddOpt) (*ts.Tensor, error) {
 	dtype := gotch.DefaultDType
-	v := ini.InitTensor(dims, p.varstore.device, dtype)
+	// v := ini.InitTensor(dims, p.varstore.device, dtype)
+	var v *ts.Tensor
+
+	v = ini.InitTensor(dims, p.varstore.device, dtype)
+
 	out, err := p.Add(name, v, true, opts...)
 	if err != nil {
 		return nil, err
 	}
 
+	v.MustDrop()
+
 	return out, err
 }
diff --git a/nn/varstore_test.go b/nn/varstore_test.go
index 616acf5..c862e68 100644
--- a/nn/varstore_test.go
+++ b/nn/varstore_test.go
@@ -1,10 +1,12 @@
 package nn_test
 
 import (
+	"fmt"
 	"os"
 	"path/filepath"
 	"reflect"
 	"testing"
+	"time"
 
 	"github.com/sugarme/gotch"
 	"github.com/sugarme/gotch/nn"
@@ -133,3 +135,43 @@ func TestSaveLoad(t *testing.T) {
 		t.Errorf("Failed deleting varstore saved file: %v\n", filenameAbs)
 	}
 }
+
+// Test whether create params in varstore can cause memory blow-up due to accumulate gradient.
+func TestVarstore_Memcheck(t *testing.T) {
+	gotch.PrintMemStats("Start")
+	device := gotch.CPU
+	vs := nn.NewVarStore(device)
+	params := 1000
+
+	path := vs.Root()
+	// dims := []int64{1024, 1024}
+	config := nn.DefaultLinearConfig()
+	inDim := int64(1024)
+	outDim := int64(1024)
+	var layers []nn.Linear
+	for i := 0; i < params; i++ {
+		ts.NoGrad(func() {
+			name := fmt.Sprintf("param_%v", i)
+			l := nn.NewLinear(path.Sub(name), inDim, outDim, config)
+			layers = append(layers, *l)
+			// x := ts.MustRandn(dims, gotch.DefaultDType, device)
+			// path.MustAdd(name, x, false)
+			// x.MustDrop()
+		})
+	}
+
+	// vs.Summary()
+
+	fmt.Printf("vs created...\n")
+	// printMemStats("After varstore created")
+
+	vs.Destroy()
+	ts.CleanUp()
+
+	fmt.Printf("vs deleted...\n")
+
+	// printMemStats("After varstore deleted")
+
+	time.Sleep(time.Second * 10)
+	gotch.PrintMemStats("Final")
+}
diff --git a/ts/tensor.go b/ts/tensor.go
index 81a39ba..3d3bbd8 100644
--- a/ts/tensor.go
+++ b/ts/tensor.go
@@ -129,6 +129,7 @@ func freeCTensor(ts *Tensor) error {
 
 	// Just return if it has been deleted previously!
 	if unsafe.Pointer(ts.ctensor) == nil {
+		log.Printf("INFO: ctensor is nil. Nothing to delete here...\n")
 		return nil
 	}
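
Note (not part of the patch): a minimal usage sketch of the new XavierUniform_ helper, showing how a caller could re-initialize a Linear layer's weights in place while the Kaiming-uniform initializer is disabled. The layer name "fc1" and the 1024x1024 dimensions are illustrative only; the API calls are those shown in the diff above.

package main

import (
	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
)

func main() {
	vs := nn.NewVarStore(gotch.CPU)
	root := vs.Root()

	// With this patch, DefaultLinearConfig's WsInit falls back to a plain
	// random-normal tensor, so overwrite the weights explicitly if
	// Xavier/Glorot-uniform initialization is wanted.
	l := nn.NewLinear(root.Sub("fc1"), 1024, 1024, nn.DefaultLinearConfig())

	// XavierUniform_ copies a uniformly initialized tensor into l.Ws and
	// drops its temporary (src.MustDrop() in the diff), so no intermediate
	// tensor is left behind.
	nn.XavierUniform_(l.Ws)
}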