fixed memory blow-up caused by gradient accumulation when loading a pretrained model

This commit is contained in:
sugarme 2023-08-12 15:46:51 +10:00
parent ef00723027
commit 163e625426
5 changed files with 232 additions and 18 deletions
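The diffs below wrap initializer creation and the copy into variables in ts.NoGrad, so those tensor operations run with gradient tracking disabled while weights are being set up. A minimal standalone sketch of that pattern, using only the gotch/ts calls that appear in this commit (the function name, the 0.02 scale and the loop are made up for illustration and are not part of the commit):

package main

import (
	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/ts"
)

// newInitTensor mirrors the pattern applied in nn/init.go in this commit:
// create and transform the initial values inside ts.NoGrad so nothing is
// recorded for autograd while weights are being initialized.
func newInitTensor(dims []int64, device gotch.Device) *ts.Tensor {
	var retVal *ts.Tensor
	ts.NoGrad(func() {
		x := ts.MustRandn(dims, gotch.DefaultDType, device)
		// MustMulScalar with del=true consumes x, so no temporary is leaked.
		retVal = x.MustMulScalar(ts.FloatScalar(0.02), true)
	})
	return retVal
}

func main() {
	gotch.PrintMemStats("before")
	for i := 0; i < 100; i++ {
		w := newInitTensor([]int64{1024, 1024}, gotch.CPU)
		w.MustDrop()
	}
	gotch.PrintMemStats("after")
}

The point of the sketch: both the temporary tensor and the in-place scaling are created under NoGrad and the temporary is consumed, so repeated initialization does not grow memory.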

View File

@@ -5,6 +5,7 @@ import (
"log"
"runtime"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
"github.com/sugarme/gotch/ts"
"github.com/sugarme/gotch/vision"
@@ -14,14 +15,13 @@ const (
ImageDimNN int64 = 784
HiddenNodesNN int64 = 128
LabelNN int64 = 10
// MnistDirNN string = "../../data/mnist"
MnistDirNN string = "/mnt/projects/numbat/data/mnist"
epochsNN = 200
LrNN = 1e-3
)
var MnistDirNN string = fmt.Sprintf("%s/%s", gotch.CachedDir, "mnist")
var l nn.Linear
func netInit(vs *nn.Path) ts.Module {

mem-util.go (new file, 146 lines added)
View File

@@ -0,0 +1,146 @@
package gotch

// helper to debug memory blow-up

import (
	"fmt"
	"os"
	"runtime"
	"strings"
	"text/tabwriter"
)

// PrintMemStats reads runtime.MemStats and prints a summary table.
// An optional message is appended to the table title.
func PrintMemStats(messageOpt ...string) {
	message := "Memory Stats"
	if len(messageOpt) > 0 {
		message = fmt.Sprintf("%s: %s", message, messageOpt[0])
	}

	var rtm runtime.MemStats
	runtime.ReadMemStats(&rtm)

	tp := newTablePrinter()
	tp.title = message
	tp.AddRecord("|", "Allocated heap objects", padRight(fmt.Sprintf("%v", rtm.Mallocs), 10), "|")
	tp.AddRecord("|", "Released heap objects", padRight(fmt.Sprintf("%v", rtm.Frees), 10), "|")
	tp.AddRecord("|", "Living heap objects", padRight(fmt.Sprintf("%v", rtm.HeapObjects), 10), "|")
	tp.AddRecord("|", "Memory in use by heap objects (bytes)", padRight(fmt.Sprintf("%v", rtm.HeapAlloc), 10), "|")
	tp.AddRecord("|", "Reserved memory (by Go runtime for heap, stack,...) (bytes)", padRight(fmt.Sprintf("%v", rtm.Sys), 10), "|")
	tp.AddRecord("|", "Total pause time by GC (nanoseconds)", padRight(fmt.Sprintf("%v", rtm.PauseTotalNs), 10), "|")
	tp.AddRecord("|", "Number of GC called", padRight(fmt.Sprintf("%v", rtm.NumGC), 10), "|")
	// tp.AddRecord("Last GC called", fmt.Sprintf("%v", time.UnixMilli(int64(rtm.LastGC/1_000_000))))
	tp.Print()
}
type tablePrinter struct {
	w         *tabwriter.Writer
	maxLength int
	title     string
}

type printItem struct {
	val        string
	alignRight bool
}

func item(val string, alignRightOpt ...bool) printItem {
	alignRight := false
	if len(alignRightOpt) > 0 {
		alignRight = alignRightOpt[0]
	}

	return printItem{
		val:        val,
		alignRight: alignRight,
	}
}

func newTablePrinter() *tablePrinter {
	w := tabwriter.NewWriter(
		os.Stdout, // output
		0,         // min width
		1,         // tab width
		2,         // padding
		' ',       // padding character
		0,         // flags (align left)
	)

	return &tablePrinter{
		w:         w,
		maxLength: 0,
	}
}

func (tp *tablePrinter) AddRecord(items ...string) {
	tp.printRecord(items...)
}

func (tp *tablePrinter) AlignRight() {
	tp.w.Init(
		os.Stdout, // output
		0,         // min width
		1,         // tab width
		2,         // padding
		' ',       // padding character
		tabwriter.AlignRight, // flags
	)
}

func (tp *tablePrinter) AlignLeft() {
	tp.w.Init(
		os.Stdout, // output
		0,         // min width
		1,         // tab width
		2,         // padding
		' ',       // padding character
		0,         // flags (align left)
	)
}
func (tp *tablePrinter) printRecord(rec ...string) {
	var val string
	for i, item := range rec {
		switch i {
		case 0:
			val = item
		case len(rec) - 1:
			val += fmt.Sprintf("\t%s\n", item)
		default:
			val += fmt.Sprintf("\t%s", item)
		}
	}

	nbytes, err := tp.w.Write([]byte(val))
	if err != nil {
		panic(err)
	}

	if nbytes > tp.maxLength {
		tp.maxLength = nbytes
	}
}

func (tp *tablePrinter) Print() {
	printBorder(tp.maxLength)
	printLine(tp.maxLength, tp.title)
	printBorder(tp.maxLength)
	tp.w.Flush()
	printBorder(tp.maxLength)
}
// padRight right-aligns val within a field of rightEnd characters by
// left-padding it with spaces.
func padRight(val interface{}, rightEnd int) string {
	value := fmt.Sprintf("%v", val)
	n := rightEnd - len(value)
	if n < 0 {
		// Guard against a negative repeat count (strings.Repeat panics)
		// when the value is wider than the field.
		n = 0
	}
	return fmt.Sprintf("%s%s", strings.Repeat(" ", n), value)
}

func printLine(lineLength int, value string) {
	fmt.Printf("| %s %s\n", value, padRight("|", lineLength-len(value)-1))
}

func printBorder(length int) {
	line := strings.Repeat("-", length)
	fmt.Printf("+%s+\n", line)
}
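
A small usage sketch for this helper; the report wrapper and its call sites are assumptions, not part of the commit. Forcing a garbage collection before reading the stats makes HeapObjects and HeapAlloc easier to compare between snapshots.

package main

import (
	"fmt"
	"runtime"

	"github.com/sugarme/gotch"
)

// report forces a GC and then prints gotch's memory summary so that
// successive snapshots are comparable.
func report(tag string) {
	runtime.GC()
	gotch.PrintMemStats(fmt.Sprintf("after %s", tag))
}

func main() {
	report("startup")
	// ... build a VarStore, load pretrained weights, train ...
	report("teardown")
}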

View File

@@ -86,13 +86,17 @@ func (r randnInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...got
 		dtype = dtypeOpt[0]
 	}
 
-	// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
-	if r.mean == 0 {
-		return ts.MustRandn(dims, dtype, device)
-	}
+	ts.NoGrad(func() {
+		// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
+		if r.mean == 0 {
+			retVal = ts.MustRandn(dims, dtype, device)
+		}
 
-	initTs := ts.MustRandn(dims, dtype, device)
-	return initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
+		initTs := ts.MustRandn(dims, dtype, device)
+		retVal = initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
+	})
+
+	return retVal
 }
 
 func (r randnInit) Set(tensor *ts.Tensor) {
@@ -101,9 +105,11 @@ func (r randnInit) Set(tensor *ts.Tensor) {
 		log.Fatalf("randInit - Set method call error: %v\n", err)
 	}
 
-	initTs := r.InitTensor(dims, tensor.MustDevice())
-	tensor.Copy_(initTs)
-	initTs.MustDrop()
+	ts.NoGrad(func() {
+		initTs := r.InitTensor(dims, tensor.MustDevice())
+		tensor.Copy_(initTs)
+		initTs.MustDrop()
+	})
 }
 
 // uniformInit :
@@ -127,11 +133,13 @@ func (u uniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...g
 	}
 
 	var err error
-	retVal = ts.MustZeros(dims, dtype, device)
-	retVal.Uniform_(u.lo, u.up)
-	if err != nil {
-		log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
-	}
+	ts.NoGrad(func() {
+		retVal = ts.MustZeros(dims, dtype, device)
+		retVal.Uniform_(u.lo, u.up)
+		if err != nil {
+			log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
+		}
+	})
 
 	return retVal
 }
@@ -230,8 +238,10 @@ func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device, dtype
 	// Calculate uniform bounds from standard deviation
 	bound := math.Sqrt(3.0) * std
 
-	retVal = ts.MustZeros(dims, dtype, device)
-	retVal.Uniform_(-bound, bound)
+	ts.NoGrad(func() {
+		retVal = ts.MustZeros(dims, dtype, device)
+		retVal.Uniform_(-bound, bound)
+	})
 
 	return retVal
 }

nn/init_test.go (new file, 44 lines added)
View File

@@ -0,0 +1,44 @@
package nn

import (
	"fmt"
	"testing"
	"time"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/ts"
)

// Test whether InitTensor() can cause memory blow-up due to gradient accumulation.
func TestInitTensor_Memcheck(t *testing.T) {
	gotch.PrintMemStats("Start")

	device := gotch.CPU
	vs := NewVarStore(device)
	params := 500

	path := vs.Root()
	dims := []int64{1024, 1024}
	for i := 0; i < params; i++ {
		ts.NoGrad(func() {
			name := fmt.Sprintf("param_%v", i)
			x := ts.MustRandn(dims, gotch.DefaultDType, device)
			path.MustAdd(name, x, false)
			x.MustDrop()
		})
	}
	// vs.Summary()

	fmt.Printf("vs created...\n")
	// printMemStats("After varstore created")

	vs.Destroy()
	ts.CleanUp()

	fmt.Printf("vs deleted...\n")
	// printMemStats("After varstore deleted")

	time.Sleep(time.Second * 10)

	gotch.PrintMemStats("Final")
}
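
To exercise just this check locally (assuming libtorch and gotch's CGO setup are already in place), a command along the lines of go test -v -run TestInitTensor_Memcheck ./nn/ should run it in isolation, with the two PrintMemStats reports bracketing the create/destroy cycle.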

View File

@@ -458,6 +458,20 @@ func (vs *VarStore) Summary() {
 	fmt.Printf("DType: %v\n", dtype)
 }
 
+// Destroy deletes all tensors in varstore and sets it to nil.
+func (vs *VarStore) Destroy() {
+	vs.Lock()
+	for n, v := range vs.vars {
+		v.Tensor.MustDrop()
+		delete(vs.vars, n)
+	}
+	vs.Unlock()
+
+	vs = nil
+}
+
 // ToDType casts all variables in VarStore to specified DType.
 //
 // NOTE. only float-like types (Half, BFloat16, Float, Double) can ensure convertible.
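
Destroy pairs with ts.CleanUp() as in the new test above. A hedged sketch of the intended teardown when a model is rebuilt or reloaded repeatedly; the reload loop and the placeholder comment are illustrative, while NewVarStore, Destroy, CleanUp and PrintMemStats are the calls that appear in this commit and the nn package.

package main

import (
	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	"github.com/sugarme/gotch/ts"
)

func main() {
	device := gotch.CPU

	for i := 0; i < 3; i++ {
		vs := nn.NewVarStore(device)

		// ... build the model on vs.Root() and load pretrained weights here ...

		// Tear down as in TestInitTensor_Memcheck: drop every variable so
		// tensors do not pile up across reloads, then call ts.CleanUp().
		vs.Destroy()
		ts.CleanUp()

		gotch.PrintMemStats("after reload")
	}
}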