diff --git a/example/mnist/nn.go b/example/mnist/nn.go
index 79cf297..da38775 100644
--- a/example/mnist/nn.go
+++ b/example/mnist/nn.go
@@ -5,6 +5,7 @@ import (
 	"log"
 	"runtime"
 
+	"github.com/sugarme/gotch"
 	"github.com/sugarme/gotch/nn"
 	"github.com/sugarme/gotch/ts"
 	"github.com/sugarme/gotch/vision"
@@ -14,14 +15,13 @@ const (
 	ImageDimNN    int64 = 784
 	HiddenNodesNN int64 = 128
 	LabelNN       int64 = 10
-	// MnistDirNN string = "../../data/mnist"
-	MnistDirNN string = "/mnt/projects/numbat/data/mnist"
 
 	epochsNN = 200
 
 	LrNN = 1e-3
 )
 
+var MnistDirNN string = fmt.Sprintf("%s/%s", gotch.CachedDir, "mnist")
 var l nn.Linear
 
 func netInit(vs *nn.Path) ts.Module {
diff --git a/mem-util.go b/mem-util.go
new file mode 100644
index 0000000..78682bd
--- /dev/null
+++ b/mem-util.go
@@ -0,0 +1,158 @@
+package gotch
+
+// helper to debug memory blow-up
+
+import (
+	"fmt"
+	"os"
+	"runtime"
+	"strings"
+	"text/tabwriter"
+)
+
+// PrintMemStats prints a table of Go runtime memory statistics to stdout.
+// The first element of messageOpt, if any, is appended to the table title.
+func PrintMemStats(messageOpt ...string) {
+	message := "Memory Stats"
+	if len(messageOpt) > 0 {
+		message = fmt.Sprintf("%s: %s", message, messageOpt[0])
+	}
+
+	var rtm runtime.MemStats
+	runtime.ReadMemStats(&rtm)
+
+	tp := newTablePrinter()
+	tp.title = message
+
+	tp.AddRecord("|", "Allocated heap objects", padRight(fmt.Sprintf("%v", rtm.Mallocs), 10), "|")
+	tp.AddRecord("|", "Released heap objects", padRight(fmt.Sprintf("%v", rtm.Frees), 10), "|")
+	tp.AddRecord("|", "Living heap objects", padRight(fmt.Sprintf("%v", rtm.HeapObjects), 10), "|")
+	tp.AddRecord("|", "Memory in use by heap objects (bytes)", padRight(fmt.Sprintf("%v", rtm.HeapAlloc), 10), "|")
+	tp.AddRecord("|", "Reserved memory (by Go runtime for heap, stack,...) (bytes)", padRight(fmt.Sprintf("%v", rtm.Sys), 10), "|")
+	tp.AddRecord("|", "Total pause time by GC (nanoseconds)", padRight(fmt.Sprintf("%v", rtm.PauseTotalNs), 10), "|")
+	tp.AddRecord("|", "Number of GC called", padRight(fmt.Sprintf("%v", rtm.NumGC), 10), "|")
+	// tp.AddRecord("Last GC called", fmt.Sprintf("%v", time.UnixMilli(int64(rtm.LastGC/1_000_000))))
+
+	tp.Print()
+
+}
+
+type tablePrinter struct {
+	w         *tabwriter.Writer
+	maxLength int
+	title     string
+}
+
+type printItem struct {
+	val        string
+	alignRight bool
+}
+
+func item(val string, alignRightOpt ...bool) printItem {
+	alignRight := false
+	if len(alignRightOpt) > 0 {
+		alignRight = alignRightOpt[0]
+	}
+	return printItem{
+		val:        val,
+		alignRight: alignRight,
+	}
+}
+
+func newTablePrinter() *tablePrinter {
+	w := tabwriter.NewWriter(
+		os.Stdout, //output
+		0,         // min width
+		1,         // tabwidth
+		2,         // padding
+		' ',       // padding character
+		0,         // align left
+	)
+
+	return &tablePrinter{
+		w:         w,
+		maxLength: 0,
+	}
+}
+
+func (tp *tablePrinter) AddRecord(items ...string) {
+	tp.printRecord(items...)
+}
+
+func (tp *tablePrinter) AlignRight() {
+	tp.w.Init(
+		os.Stdout, //output
+		0,         // min width
+		1,         // tabwidth
+		2,         // padding
+		' ',       // padding character
+		tabwriter.AlignRight,
+	) // flags
+}
+
+func (tp *tablePrinter) AlignLeft() {
+	tp.w.Init(
+		os.Stdout, //output
+		0,         // min width
+		1,         // tabwidth
+		2,         // padding
+		' ',       // padding character
+		0,         // align left
+	) // flags
+}
+
+// printRecord joins rec with tabs, writes the row to the tabwriter and keeps
+// the largest byte count written for a single row in tp.maxLength.
+func (tp *tablePrinter) printRecord(rec ...string) {
+	var val string
+	for i, item := range rec {
+		switch i {
+		case 0:
+			val = item
+		case len(rec) - 1:
+			val += fmt.Sprintf("\t%s\n", item)
+		default:
+			val += fmt.Sprintf("\t%s", item)
+		}
+	}
+
+	nbytes, err := tp.w.Write([]byte(val))
+	if err != nil {
+		panic(err)
+	}
+
+	if nbytes > tp.maxLength {
+		tp.maxLength = nbytes
+	}
+}
+
+// Print writes the title row between borders, flushes the buffered rows and closes the table with a border.
+func (tp *tablePrinter) Print() {
+	printBorder(tp.maxLength)
+	printLine(tp.maxLength, tp.title)
+	printBorder(tp.maxLength)
+	tp.w.Flush()
+	printBorder(tp.maxLength)
+}
+
+// padRight right-aligns val by left-padding it with spaces up to rightEnd
+// characters. A value that is already rightEnd characters or longer is returned
+// unpadded, because strings.Repeat panics on a negative count.
+func padRight(val interface{}, rightEnd int) string {
+	value := fmt.Sprintf("%v", val)
+	if n := rightEnd - len(value); n > 0 {
+		return strings.Repeat(" ", n) + value
+	}
+	return value
+}
+
+// printLine prints value as a table row, closed by a right-aligned "|".
+func printLine(lineLength int, value string) {
+	fmt.Printf("| %s %s\n", value, padRight("|", lineLength-len(value)-1))
+}
+
+// printBorder prints a horizontal border made of length dashes between "+" corners.
+func printBorder(length int) {
+	line := strings.Repeat("-", length)
+	fmt.Printf("+%s+\n", line)
+}
diff --git a/nn/init.go b/nn/init.go
index 87add04..bdc6144 100644
--- a/nn/init.go
+++ b/nn/init.go
@@ -86,13 +86,19 @@ func (r randnInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...got
 		dtype = dtypeOpt[0]
 	}
 
-	// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
-	if r.mean == 0 {
-		return ts.MustRandn(dims, dtype, device)
-	}
+	ts.NoGrad(func() {
+		// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
+		if r.mean == 0 {
+			retVal = ts.MustRandn(dims, dtype, device)
+			// Leave the closure here: falling through would overwrite retVal and leak this tensor.
+			return
+		}
 
-	initTs := ts.MustRandn(dims, dtype, device)
-	return initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
+		initTs := ts.MustRandn(dims, dtype, device)
+		retVal = initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
+	})
+
+	return retVal
 }
 
 func (r randnInit) Set(tensor *ts.Tensor) {
@@ -101,9 +107,11 @@ func (r randnInit) Set(tensor *ts.Tensor) {
 		log.Fatalf("randInit - Set method call error: %v\n", err)
 	}
 
-	initTs := r.InitTensor(dims, tensor.MustDevice())
-	tensor.Copy_(initTs)
-	initTs.MustDrop()
+	ts.NoGrad(func() {
+		initTs := r.InitTensor(dims, tensor.MustDevice())
+		tensor.Copy_(initTs)
+		initTs.MustDrop()
+	})
 }
 
 // uniformInit :
@@ -127,11 +135,13 @@ func (u uniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...g
 	}
 
 	var err error
-	retVal = ts.MustZeros(dims, dtype, device)
-	retVal.Uniform_(u.lo, u.up)
-	if err != nil {
-		log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
-	}
+	ts.NoGrad(func() {
+		retVal = ts.MustZeros(dims, dtype, device)
+		retVal.Uniform_(u.lo, u.up)
+		if err != nil {
+			log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
+		}
+	})
 
 	return retVal
 }
@@ -230,8 +240,10 @@ func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device, dtype
 	// Calculate uniform bounds from standard deviation
 	bound := math.Sqrt(3.0) * std
 
-	retVal = ts.MustZeros(dims, dtype, device)
-	retVal.Uniform_(-bound, bound)
+	ts.NoGrad(func() {
+		retVal = ts.MustZeros(dims, dtype, device)
+		retVal.Uniform_(-bound, bound)
+	})
 
 	return retVal
 }
diff --git a/nn/init_test.go b/nn/init_test.go
new file mode 100644
index 0000000..88cf3e5
--- /dev/null
+++ b/nn/init_test.go
@@ -0,0 +1,44 @@
+package nn
+
+import (
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/sugarme/gotch"
+	"github.com/sugarme/gotch/ts"
+)
+
+// Test whether InitTensor() can cause memory blow-up due to accumulated gradients.
+func TestInitTensor_Memcheck(t *testing.T) {
+	gotch.PrintMemStats("Start")
+	device := gotch.CPU
+	vs := NewVarStore(device)
+	params := 500
+
+	path := vs.Root()
+	dims := []int64{1024, 1024}
+	for i := 0; i < params; i++ {
+		ts.NoGrad(func() {
+			name := fmt.Sprintf("param_%v", i)
+			x := ts.MustRandn(dims, gotch.DefaultDType, device)
+			path.MustAdd(name, x, false)
+			x.MustDrop()
+		})
+	}
+
+	// vs.Summary()
+
+	fmt.Printf("vs created...\n")
+	// printMemStats("After varstore created")
+
+	vs.Destroy()
+	ts.CleanUp()
+
+	fmt.Printf("vs deleted...\n")
+
+	// printMemStats("After varstore deleted")
+
+	time.Sleep(time.Second * 10)
+	gotch.PrintMemStats("Final")
+}
diff --git a/nn/varstore.go b/nn/varstore.go
index 15edbed..ec0246c 100644
--- a/nn/varstore.go
+++ b/nn/varstore.go
@@ -458,6 +458,23 @@ func (vs *VarStore) Summary() {
 	fmt.Printf("DType: %v\n", dtype)
 }
 
+// Destroy drops every tensor held by the VarStore and removes it from the
+// store, leaving the VarStore empty.
+//
+// NOTE. Destroy cannot nil the caller's *VarStore: assigning to the receiver
+// only changes the method's local copy. Callers should discard their own
+// reference once Destroy returns.
+func (vs *VarStore) Destroy() {
+	vs.Lock()
+	// Deferred so the lock is released on every exit path, including a panic while dropping.
+	defer vs.Unlock()
+
+	for n, v := range vs.vars {
+		v.Tensor.MustDrop()
+		delete(vs.vars, n)
+	}
+}
+
 // ToDType casts all variables in VarStore to specified DType.
 //
 // NOTE. only float-like types (Half, BFloat16, Float, Double) can ensure convertible.