gotch/nn/init_test.go
Goncalves Henriques, Andre (UG - Computer Science) 9257404edd Move the name of the module
2024-04-21 15:15:00 +01:00

45 lines
863 B
Go

package nn
import (
"fmt"
"testing"
"time"
"git.andr3h3nriqu3s.com/andr3/gotch"
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
)
// TestInitTensor_Memcheck is a manual memory check. It fills a VarStore with
// many large random tensors, destroys the VarStore, and prints memory
// statistics at the start and at the end so a human can spot a leak or a
// memory blow-up.
//
// The original concern was memory growth caused by accumulated gradients when
// initialising tensors. The tensors here are created inside ts.NoGrad. They
// are registered via Path.MustAdd with its third argument set to false
// (presumably "not trainable" — confirm against Path.MustAdd).
//
// The test allocates 500 tensors of 1024x1024 elements and sleeps for 10
// seconds, so it is skipped when running with -short.
func TestInitTensor_Memcheck(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping memory check in short mode: allocates many large tensors and sleeps 10s")
	}

	gotch.PrintMemStats("Start")

	const params = 500
	device := gotch.CPU
	dims := []int64{1024, 1024}

	vs := NewVarStore(device)
	path := vs.Root()
	for i := 0; i < params; i++ {
		ts.NoGrad(func() {
			name := fmt.Sprintf("param_%v", i)
			x := ts.MustRandn(dims, gotch.DefaultDType, device)
			// NOTE(review): x is dropped immediately after MustAdd. This assumes
			// MustAdd keeps its own clone of x rather than the original handle;
			// confirm in the VarStore implementation.
			path.MustAdd(name, x, false)
			x.MustDrop()
		})
	}
	fmt.Printf("vs created...\n")

	vs.Destroy()
	ts.CleanUp()
	fmt.Printf("vs deleted...\n")

	// Presumably gives the Go runtime and native allocator time to settle
	// before the final sample; the sleep itself is what the original did.
	time.Sleep(10 * time.Second)
	gotch.PrintMemStats("Final")
}