temp use of ts.Randn() instead of ts.Uniform() as it causes mem leak
This commit is contained in:
parent
163e625426
commit
b3d821d34e
66
nn/init.go
66
nn/init.go
|
@ -222,26 +222,31 @@ func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device, dtype
|
||||||
dtype = dtypeOpt[0]
|
dtype = dtypeOpt[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
fanIn, _, err := CalculateFans(dims)
|
/*
|
||||||
if err != nil {
|
fanIn, _, err := CalculateFans(dims)
|
||||||
panic(err)
|
if err != nil {
|
||||||
}
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
gain, err := calculateGain(k.NonLinearity, k.NegativeSlope) // default non-linearity="leaky_relu", negative_slope=0.01
|
gain, err := calculateGain(k.NonLinearity, k.NegativeSlope) // default non-linearity="leaky_relu", negative_slope=0.01
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("kaimingUniformInit.InitTensor() failed: %v\n", err)
|
err = fmt.Errorf("kaimingUniformInit.InitTensor() failed: %v\n", err)
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
std := gain / math.Sqrt(float64(fanIn)) // default using fanIn
|
std := gain / math.Sqrt(float64(fanIn)) // default using fanIn
|
||||||
|
|
||||||
// Calculate uniform bounds from standard deviation
|
// Calculate uniform bounds from standard deviation
|
||||||
bound := math.Sqrt(3.0) * std
|
bound := math.Sqrt(3.0) * std
|
||||||
|
|
||||||
ts.NoGrad(func() {
|
// NOTE. This is a well-known memory leak!!!
|
||||||
|
// Avoid to use it for now!!!
|
||||||
retVal = ts.MustZeros(dims, dtype, device)
|
retVal = ts.MustZeros(dims, dtype, device)
|
||||||
retVal.Uniform_(-bound, bound)
|
retVal.Uniform_(-bound, bound)
|
||||||
})
|
*/
|
||||||
|
|
||||||
|
// For now, just make a random norm
|
||||||
|
retVal = ts.MustRandn(dims, dtype, device)
|
||||||
|
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
@ -382,3 +387,36 @@ func contains(items []string, item string) bool {
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// XavierUniform fills the input tensor with values according to the method
|
||||||
|
// described in the paper `Understanding the difficulty of training deep feedforward neural networks`
|
||||||
|
// using a uniform distribution
|
||||||
|
//
|
||||||
|
// Also known as Glorot initialization.
|
||||||
|
//
|
||||||
|
// Paper: https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
|
||||||
|
// Pytorch implementation: https://github.com/pytorch/pytorch/blob/df50f91571891ec3f87977a2bdd4a2b609d70afc/torch/nn/init.py#L310
|
||||||
|
func XavierUniform_(x *ts.Tensor, gainOpt ...float64) {
|
||||||
|
gain := 1.0
|
||||||
|
if len(gainOpt) > 0 {
|
||||||
|
gain = gainOpt[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
size := x.MustSize()
|
||||||
|
dtype := x.DType()
|
||||||
|
device := x.MustDevice()
|
||||||
|
fanIn, fanOut, err := CalculateFans(size)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
std := gain * math.Sqrt(2.0/float64(fanIn+fanOut))
|
||||||
|
|
||||||
|
// calculate uniform bounds from standard deviation
|
||||||
|
a := math.Sqrt(3.0) * std
|
||||||
|
uniformInit := NewUniformInit(-a, a)
|
||||||
|
src := uniformInit.InitTensor(size, device, dtype)
|
||||||
|
x.Copy_(src)
|
||||||
|
|
||||||
|
src.MustDrop()
|
||||||
|
}
|
||||||
|
|
|
@ -21,6 +21,8 @@ type LinearConfig struct {
|
||||||
func DefaultLinearConfig() *LinearConfig {
|
func DefaultLinearConfig() *LinearConfig {
|
||||||
negSlope := math.Sqrt(5)
|
negSlope := math.Sqrt(5)
|
||||||
return &LinearConfig{
|
return &LinearConfig{
|
||||||
|
// NOTE. KaimingUniform cause mem leak due to ts.Uniform()!!!
|
||||||
|
// Avoid using it now.
|
||||||
WsInit: NewKaimingUniformInit(WithKaimingNegativeSlope(negSlope)),
|
WsInit: NewKaimingUniformInit(WithKaimingNegativeSlope(negSlope)),
|
||||||
BsInit: nil,
|
BsInit: nil,
|
||||||
Bias: true,
|
Bias: true,
|
||||||
|
@ -60,8 +62,10 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ws := vs.MustNewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false)
|
||||||
|
|
||||||
return &Linear{
|
return &Linear{
|
||||||
Ws: vs.MustNewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false),
|
Ws: ws,
|
||||||
Bs: bs,
|
Bs: bs,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -567,6 +567,7 @@ func (p *Path) add(name string, newTs *ts.Tensor, trainable bool, varType string
|
||||||
tensor *ts.Tensor
|
tensor *ts.Tensor
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
if trainable {
|
if trainable {
|
||||||
tensor, err = newTs.SetRequiresGrad(true, false)
|
tensor, err = newTs.SetRequiresGrad(true, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -877,12 +878,18 @@ func (p *Path) MustOnesNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Te
|
||||||
// related argument.
|
// related argument.
|
||||||
func (p *Path) NewVar(name string, dims []int64, ini Init, opts ...AddOpt) (*ts.Tensor, error) {
|
func (p *Path) NewVar(name string, dims []int64, ini Init, opts ...AddOpt) (*ts.Tensor, error) {
|
||||||
dtype := gotch.DefaultDType
|
dtype := gotch.DefaultDType
|
||||||
v := ini.InitTensor(dims, p.varstore.device, dtype)
|
// v := ini.InitTensor(dims, p.varstore.device, dtype)
|
||||||
|
var v *ts.Tensor
|
||||||
|
|
||||||
|
v = ini.InitTensor(dims, p.varstore.device, dtype)
|
||||||
|
|
||||||
out, err := p.Add(name, v, true, opts...)
|
out, err := p.Add(name, v, true, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
v.MustDrop()
|
v.MustDrop()
|
||||||
|
|
||||||
return out, err
|
return out, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
package nn_test
|
package nn_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
"github.com/sugarme/gotch"
|
||||||
"github.com/sugarme/gotch/nn"
|
"github.com/sugarme/gotch/nn"
|
||||||
|
@ -133,3 +135,43 @@ func TestSaveLoad(t *testing.T) {
|
||||||
t.Errorf("Failed deleting varstore saved file: %v\n", filenameAbs)
|
t.Errorf("Failed deleting varstore saved file: %v\n", filenameAbs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test whether create params in varstore can cause memory blow-up due to accumulate gradient.
|
||||||
|
func TestVarstore_Memcheck(t *testing.T) {
|
||||||
|
gotch.PrintMemStats("Start")
|
||||||
|
device := gotch.CPU
|
||||||
|
vs := nn.NewVarStore(device)
|
||||||
|
params := 1000
|
||||||
|
|
||||||
|
path := vs.Root()
|
||||||
|
// dims := []int64{1024, 1024}
|
||||||
|
config := nn.DefaultLinearConfig()
|
||||||
|
inDim := int64(1024)
|
||||||
|
outDim := int64(1024)
|
||||||
|
var layers []nn.Linear
|
||||||
|
for i := 0; i < params; i++ {
|
||||||
|
ts.NoGrad(func() {
|
||||||
|
name := fmt.Sprintf("param_%v", i)
|
||||||
|
l := nn.NewLinear(path.Sub(name), inDim, outDim, config)
|
||||||
|
layers = append(layers, *l)
|
||||||
|
// x := ts.MustRandn(dims, gotch.DefaultDType, device)
|
||||||
|
// path.MustAdd(name, x, false)
|
||||||
|
// x.MustDrop()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// vs.Summary()
|
||||||
|
|
||||||
|
fmt.Printf("vs created...\n")
|
||||||
|
// printMemStats("After varstore created")
|
||||||
|
|
||||||
|
vs.Destroy()
|
||||||
|
ts.CleanUp()
|
||||||
|
|
||||||
|
fmt.Printf("vs deleted...\n")
|
||||||
|
|
||||||
|
// printMemStats("After varstore deleted")
|
||||||
|
|
||||||
|
time.Sleep(time.Second * 10)
|
||||||
|
gotch.PrintMemStats("Final")
|
||||||
|
}
|
||||||
|
|
|
@ -129,6 +129,7 @@ func freeCTensor(ts *Tensor) error {
|
||||||
|
|
||||||
// Just return if it has been deleted previously!
|
// Just return if it has been deleted previously!
|
||||||
if unsafe.Pointer(ts.ctensor) == nil {
|
if unsafe.Pointer(ts.ctensor) == nil {
|
||||||
|
log.Printf("INFO: ctensor is nil. Nothing to delete here...\n")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user