Fixed memory blow-up caused by gradient accumulation when loading a pretrained model
This commit is contained in:
parent
ef00723027
commit
163e625426
|
@ -5,6 +5,7 @@ import (
|
|||
"log"
|
||||
"runtime"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
|
@ -14,14 +15,13 @@ const (
|
|||
ImageDimNN int64 = 784
|
||||
HiddenNodesNN int64 = 128
|
||||
LabelNN int64 = 10
|
||||
// MnistDirNN string = "../../data/mnist"
|
||||
MnistDirNN string = "/mnt/projects/numbat/data/mnist"
|
||||
|
||||
epochsNN = 200
|
||||
|
||||
LrNN = 1e-3
|
||||
)
|
||||
|
||||
var MnistDirNN string = fmt.Sprintf("%s/%s", gotch.CachedDir, "mnist")
|
||||
var l nn.Linear
|
||||
|
||||
func netInit(vs *nn.Path) ts.Module {
|
||||
|
|
146
mem-util.go
Normal file
146
mem-util.go
Normal file
|
@ -0,0 +1,146 @@
|
|||
package gotch
|
||||
|
||||
// helper to debug memory blow-up
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
)
|
||||
|
||||
func PrintMemStats(messageOpt ...string) {
|
||||
message := "Memory Stats"
|
||||
if len(messageOpt) > 0 {
|
||||
message = fmt.Sprintf("%s: %s", message, messageOpt[0])
|
||||
}
|
||||
|
||||
var rtm runtime.MemStats
|
||||
runtime.ReadMemStats(&rtm)
|
||||
|
||||
tp := newTablePrinter()
|
||||
tp.title = message
|
||||
|
||||
tp.AddRecord("|", "Allocated heap objects", padRight(fmt.Sprintf("%v", rtm.Mallocs), 10), "|")
|
||||
tp.AddRecord("|", "Released heap objects", padRight(fmt.Sprintf("%v", rtm.Frees), 10), "|")
|
||||
tp.AddRecord("|", "Living heap objects", padRight(fmt.Sprintf("%v", rtm.HeapObjects), 10), "|")
|
||||
tp.AddRecord("|", "Memory in use by heap objects (bytes)", padRight(fmt.Sprintf("%v", rtm.HeapAlloc), 10), "|")
|
||||
tp.AddRecord("|", "Reserved memory (by Go runtime for heap, stack,...) (bytes)", padRight(fmt.Sprintf("%v", rtm.Sys), 10), "|")
|
||||
tp.AddRecord("|", "Total pause time by GC (nanoseconds)", padRight(fmt.Sprintf("%v", rtm.PauseTotalNs), 10), "|")
|
||||
tp.AddRecord("|", "Number of GC called", padRight(fmt.Sprintf("%v", rtm.NumGC), 10), "|")
|
||||
// tp.AddRecord("Last GC called", fmt.Sprintf("%v", time.UnixMilli(int64(rtm.LastGC/1_000_000))))
|
||||
|
||||
tp.Print()
|
||||
|
||||
}
|
||||
|
||||
// tablePrinter collects tab-separated records in a tabwriter and prints them
// framed by a title line and borders.
type tablePrinter struct {
	// w buffers the records until Print flushes them.
	w *tabwriter.Writer
	// maxLength is the largest byte count reported by a single record write;
	// Print uses it as the width of the borders and the title line.
	maxLength int
	// title is the text shown in the header line of the table.
	title string
}
|
||||
|
||||
// printItem pairs a cell value with an alignment flag.
//
// NOTE(review): nothing in the code visible here consumes printItem besides
// the item constructor; confirm whether it is still needed.
type printItem struct {
	// val is the text of the cell.
	val string
	// alignRight presumably requests right alignment of the cell; verify
	// against whatever code ends up consuming it.
	alignRight bool
}
|
||||
|
||||
func item(val string, alignRightOpt ...bool) printItem {
|
||||
alignRight := false
|
||||
if len(alignRightOpt) > 0 {
|
||||
alignRight = alignRightOpt[0]
|
||||
}
|
||||
return printItem{
|
||||
val: val,
|
||||
alignRight: alignRight,
|
||||
}
|
||||
}
|
||||
|
||||
func newTablePrinter() *tablePrinter {
|
||||
w := tabwriter.NewWriter(
|
||||
os.Stdout, //output
|
||||
0, // min width
|
||||
1, // tabwidth
|
||||
2, // padding
|
||||
' ', // padding character
|
||||
0, // align left
|
||||
)
|
||||
|
||||
return &tablePrinter{
|
||||
w: w,
|
||||
maxLength: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (tp *tablePrinter) AddRecord(items ...string) {
|
||||
tp.printRecord(items...)
|
||||
}
|
||||
|
||||
func (tp *tablePrinter) AlignRight() {
|
||||
tp.w.Init(
|
||||
os.Stdout, //output
|
||||
0, // min width
|
||||
1, // tabwidth
|
||||
2, // padding
|
||||
' ', // padding character
|
||||
tabwriter.AlignRight,
|
||||
) // flags
|
||||
}
|
||||
|
||||
func (tp *tablePrinter) AlignLeft() {
|
||||
tp.w.Init(
|
||||
os.Stdout, //output
|
||||
0, // min width
|
||||
1, // tabwidth
|
||||
2, // padding
|
||||
' ', // padding character
|
||||
0, // align left
|
||||
) // flags
|
||||
}
|
||||
|
||||
func (tp *tablePrinter) printRecord(rec ...string) {
|
||||
var val string
|
||||
for i, item := range rec {
|
||||
switch i {
|
||||
case 0:
|
||||
val = item
|
||||
case len(rec) - 1:
|
||||
val += fmt.Sprintf("\t%s\n", item)
|
||||
default:
|
||||
val += fmt.Sprintf("\t%s", item)
|
||||
}
|
||||
}
|
||||
|
||||
nbytes, err := tp.w.Write([]byte(val))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if nbytes > tp.maxLength {
|
||||
tp.maxLength = nbytes
|
||||
}
|
||||
}
|
||||
|
||||
func (tp *tablePrinter) Print() {
|
||||
printBorder(tp.maxLength)
|
||||
printLine(tp.maxLength, tp.title)
|
||||
printBorder(tp.maxLength)
|
||||
tp.w.Flush()
|
||||
printBorder(tp.maxLength)
|
||||
}
|
||||
|
||||
// padRight formats val with %v and right-aligns it in a field of rightEnd
// bytes by prepending spaces.
//
// When the formatted value is already rightEnd bytes or longer (or rightEnd
// is zero or negative) the value is returned unchanged rather than panicking
// on a negative repeat count.
func padRight(val interface{}, rightEnd int) string {
	value := fmt.Sprintf("%v", val)

	pad := rightEnd - len(value)
	if pad <= 0 {
		return value
	}

	return strings.Repeat(" ", pad) + value
}
|
||||
|
||||
// printLine prints value to standard output as a table line: "| ", the
// value, a space, padding spaces and a closing "|", so that the line is
// lineLength+2 bytes wide and lines up with printBorder(lineLength).
//
// When value is too long to fit, no padding is added and the line simply
// becomes wider instead of panicking on a negative padding width.
func printLine(lineLength int, value string) {
	pad := lineLength - len(value) - 2
	if pad < 0 {
		pad = 0
	}

	fmt.Printf("| %s %s|\n", value, strings.Repeat(" ", pad))
}
|
||||
|
||||
// printBorder prints a horizontal table border to standard output: a "+",
// length dashes and a closing "+". A negative length is treated as zero.
func printBorder(length int) {
	if length < 0 {
		length = 0
	}

	fmt.Printf("+%s+\n", strings.Repeat("-", length))
}
|
42
nn/init.go
42
nn/init.go
|
@ -86,13 +86,17 @@ func (r randnInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...got
|
|||
dtype = dtypeOpt[0]
|
||||
}
|
||||
|
||||
// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
|
||||
if r.mean == 0 {
|
||||
return ts.MustRandn(dims, dtype, device)
|
||||
}
|
||||
ts.NoGrad(func() {
|
||||
// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
|
||||
if r.mean == 0 {
|
||||
retVal = ts.MustRandn(dims, dtype, device)
|
||||
}
|
||||
|
||||
initTs := ts.MustRandn(dims, dtype, device)
|
||||
return initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
|
||||
initTs := ts.MustRandn(dims, dtype, device)
|
||||
retVal = initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
|
||||
})
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (r randnInit) Set(tensor *ts.Tensor) {
|
||||
|
@ -101,9 +105,11 @@ func (r randnInit) Set(tensor *ts.Tensor) {
|
|||
log.Fatalf("randInit - Set method call error: %v\n", err)
|
||||
}
|
||||
|
||||
initTs := r.InitTensor(dims, tensor.MustDevice())
|
||||
tensor.Copy_(initTs)
|
||||
initTs.MustDrop()
|
||||
ts.NoGrad(func() {
|
||||
initTs := r.InitTensor(dims, tensor.MustDevice())
|
||||
tensor.Copy_(initTs)
|
||||
initTs.MustDrop()
|
||||
})
|
||||
}
|
||||
|
||||
// uniformInit :
|
||||
|
@ -127,11 +133,13 @@ func (u uniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...g
|
|||
}
|
||||
|
||||
var err error
|
||||
retVal = ts.MustZeros(dims, dtype, device)
|
||||
retVal.Uniform_(u.lo, u.up)
|
||||
if err != nil {
|
||||
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
|
||||
}
|
||||
ts.NoGrad(func() {
|
||||
retVal = ts.MustZeros(dims, dtype, device)
|
||||
retVal.Uniform_(u.lo, u.up)
|
||||
if err != nil {
|
||||
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
|
||||
}
|
||||
})
|
||||
return retVal
|
||||
}
|
||||
|
||||
|
@ -230,8 +238,10 @@ func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device, dtype
|
|||
// Calculate uniform bounds from standard deviation
|
||||
bound := math.Sqrt(3.0) * std
|
||||
|
||||
retVal = ts.MustZeros(dims, dtype, device)
|
||||
retVal.Uniform_(-bound, bound)
|
||||
ts.NoGrad(func() {
|
||||
retVal = ts.MustZeros(dims, dtype, device)
|
||||
retVal.Uniform_(-bound, bound)
|
||||
})
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
44
nn/init_test.go
Normal file
44
nn/init_test.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package nn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// Test whether InitTensor() can cause memory blow-up due to accumulate gradient.
|
||||
func TestInitTensor_Memcheck(t *testing.T) {
|
||||
gotch.PrintMemStats("Start")
|
||||
device := gotch.CPU
|
||||
vs := NewVarStore(device)
|
||||
params := 500
|
||||
|
||||
path := vs.Root()
|
||||
dims := []int64{1024, 1024}
|
||||
for i := 0; i < params; i++ {
|
||||
ts.NoGrad(func() {
|
||||
name := fmt.Sprintf("param_%v", i)
|
||||
x := ts.MustRandn(dims, gotch.DefaultDType, device)
|
||||
path.MustAdd(name, x, false)
|
||||
x.MustDrop()
|
||||
})
|
||||
}
|
||||
|
||||
// vs.Summary()
|
||||
|
||||
fmt.Printf("vs created...\n")
|
||||
// printMemStats("After varstore created")
|
||||
|
||||
vs.Destroy()
|
||||
ts.CleanUp()
|
||||
|
||||
fmt.Printf("vs deleted...\n")
|
||||
|
||||
// printMemStats("After varstore deleted")
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
gotch.PrintMemStats("Final")
|
||||
}
|
|
@ -458,6 +458,20 @@ func (vs *VarStore) Summary() {
|
|||
fmt.Printf("DType: %v\n", dtype)
|
||||
}
|
||||
|
||||
// Destroy deletes all tensors in varstore and set it to nil.
|
||||
func (vs *VarStore) Destroy() {
|
||||
vs.Lock()
|
||||
for n, v := range vs.vars {
|
||||
v.Tensor.MustDrop()
|
||||
|
||||
delete(vs.vars, n)
|
||||
}
|
||||
|
||||
vs.Unlock()
|
||||
|
||||
vs = nil
|
||||
}
|
||||
|
||||
// ToDType casts all variables in VarStore to specified DType.
|
||||
//
|
||||
// NOTE. only float-like types (Half, BFloat16, Float, Double) can ensure convertible.
|
||||
|
|
Loading…
Reference in New Issue
Block a user