Temporarily use ts.Randn() instead of ts.Uniform(), because ts.Uniform() causes a memory leak
This commit is contained in:
parent
163e625426
commit
b3d821d34e
66
nn/init.go
66
nn/init.go
|
@ -222,26 +222,31 @@ func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device, dtype
|
|||
dtype = dtypeOpt[0]
|
||||
}
|
||||
|
||||
fanIn, _, err := CalculateFans(dims)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
/*
|
||||
fanIn, _, err := CalculateFans(dims)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
gain, err := calculateGain(k.NonLinearity, k.NegativeSlope) // default non-linearity="leaky_relu", negative_slope=0.01
|
||||
if err != nil {
|
||||
err = fmt.Errorf("kaimingUniformInit.InitTensor() failed: %v\n", err)
|
||||
panic(err)
|
||||
}
|
||||
gain, err := calculateGain(k.NonLinearity, k.NegativeSlope) // default non-linearity="leaky_relu", negative_slope=0.01
|
||||
if err != nil {
|
||||
err = fmt.Errorf("kaimingUniformInit.InitTensor() failed: %v\n", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
std := gain / math.Sqrt(float64(fanIn)) // default using fanIn
|
||||
std := gain / math.Sqrt(float64(fanIn)) // default using fanIn
|
||||
|
||||
// Calculate uniform bounds from standard deviation
|
||||
bound := math.Sqrt(3.0) * std
|
||||
// Calculate uniform bounds from standard deviation
|
||||
bound := math.Sqrt(3.0) * std
|
||||
|
||||
ts.NoGrad(func() {
|
||||
// NOTE. This is a well-known memory leak!!!
|
||||
// Avoid using it for now!!!
|
||||
retVal = ts.MustZeros(dims, dtype, device)
|
||||
retVal.Uniform_(-bound, bound)
|
||||
})
|
||||
*/
|
||||
|
||||
// For now, just make a random norm
|
||||
retVal = ts.MustRandn(dims, dtype, device)
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
@ -382,3 +387,36 @@ func contains(items []string, item string) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// XavierUniform fills the input tensor with values according to the method
|
||||
// described in the paper `Understanding the difficulty of training deep feedforward neural networks`
|
||||
// using a uniform distribution
|
||||
//
|
||||
// Also known as Glorot initialization.
|
||||
//
|
||||
// Paper: https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
|
||||
// Pytorch implementation: https://github.com/pytorch/pytorch/blob/df50f91571891ec3f87977a2bdd4a2b609d70afc/torch/nn/init.py#L310
|
||||
func XavierUniform_(x *ts.Tensor, gainOpt ...float64) {
|
||||
gain := 1.0
|
||||
if len(gainOpt) > 0 {
|
||||
gain = gainOpt[0]
|
||||
}
|
||||
|
||||
size := x.MustSize()
|
||||
dtype := x.DType()
|
||||
device := x.MustDevice()
|
||||
fanIn, fanOut, err := CalculateFans(size)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
std := gain * math.Sqrt(2.0/float64(fanIn+fanOut))
|
||||
|
||||
// calculate uniform bounds from standard deviation
|
||||
a := math.Sqrt(3.0) * std
|
||||
uniformInit := NewUniformInit(-a, a)
|
||||
src := uniformInit.InitTensor(size, device, dtype)
|
||||
x.Copy_(src)
|
||||
|
||||
src.MustDrop()
|
||||
}
|
||||
|
|
|
@ -21,6 +21,8 @@ type LinearConfig struct {
|
|||
func DefaultLinearConfig() *LinearConfig {
|
||||
negSlope := math.Sqrt(5)
|
||||
return &LinearConfig{
|
||||
// NOTE: KaimingUniform causes a memory leak due to ts.Uniform()!!!
|
||||
// Avoid using it now.
|
||||
WsInit: NewKaimingUniformInit(WithKaimingNegativeSlope(negSlope)),
|
||||
BsInit: nil,
|
||||
Bias: true,
|
||||
|
@ -60,8 +62,10 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
|
|||
}
|
||||
}
|
||||
|
||||
ws := vs.MustNewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false)
|
||||
|
||||
return &Linear{
|
||||
Ws: vs.MustNewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false),
|
||||
Ws: ws,
|
||||
Bs: bs,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -567,6 +567,7 @@ func (p *Path) add(name string, newTs *ts.Tensor, trainable bool, varType string
|
|||
tensor *ts.Tensor
|
||||
err error
|
||||
)
|
||||
|
||||
if trainable {
|
||||
tensor, err = newTs.SetRequiresGrad(true, false)
|
||||
if err != nil {
|
||||
|
@ -877,12 +878,18 @@ func (p *Path) MustOnesNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Te
|
|||
// related argument.
|
||||
func (p *Path) NewVar(name string, dims []int64, ini Init, opts ...AddOpt) (*ts.Tensor, error) {
|
||||
dtype := gotch.DefaultDType
|
||||
v := ini.InitTensor(dims, p.varstore.device, dtype)
|
||||
// v := ini.InitTensor(dims, p.varstore.device, dtype)
|
||||
var v *ts.Tensor
|
||||
|
||||
v = ini.InitTensor(dims, p.varstore.device, dtype)
|
||||
|
||||
out, err := p.Add(name, v, true, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v.MustDrop()
|
||||
|
||||
return out, err
|
||||
}
|
||||
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
package nn_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
|
@ -133,3 +135,43 @@ func TestSaveLoad(t *testing.T) {
|
|||
t.Errorf("Failed deleting varstore saved file: %v\n", filenameAbs)
|
||||
}
|
||||
}
|
||||
|
||||
// Test whether create params in varstore can cause memory blow-up due to accumulate gradient.
|
||||
func TestVarstore_Memcheck(t *testing.T) {
|
||||
gotch.PrintMemStats("Start")
|
||||
device := gotch.CPU
|
||||
vs := nn.NewVarStore(device)
|
||||
params := 1000
|
||||
|
||||
path := vs.Root()
|
||||
// dims := []int64{1024, 1024}
|
||||
config := nn.DefaultLinearConfig()
|
||||
inDim := int64(1024)
|
||||
outDim := int64(1024)
|
||||
var layers []nn.Linear
|
||||
for i := 0; i < params; i++ {
|
||||
ts.NoGrad(func() {
|
||||
name := fmt.Sprintf("param_%v", i)
|
||||
l := nn.NewLinear(path.Sub(name), inDim, outDim, config)
|
||||
layers = append(layers, *l)
|
||||
// x := ts.MustRandn(dims, gotch.DefaultDType, device)
|
||||
// path.MustAdd(name, x, false)
|
||||
// x.MustDrop()
|
||||
})
|
||||
}
|
||||
|
||||
// vs.Summary()
|
||||
|
||||
fmt.Printf("vs created...\n")
|
||||
// printMemStats("After varstore created")
|
||||
|
||||
vs.Destroy()
|
||||
ts.CleanUp()
|
||||
|
||||
fmt.Printf("vs deleted...\n")
|
||||
|
||||
// printMemStats("After varstore deleted")
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
gotch.PrintMemStats("Final")
|
||||
}
|
||||
|
|
|
@ -129,6 +129,7 @@ func freeCTensor(ts *Tensor) error {
|
|||
|
||||
// Just return if it has been deleted previously!
|
||||
if unsafe.Pointer(ts.ctensor) == nil {
|
||||
log.Printf("INFO: ctensor is nil. Nothing to delete here...\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user