chore(example/mnist): cleanup. The memory blow-out issue still remains.

This commit is contained in:
sugarme 2020-06-21 10:57:29 +10:00
parent e0d2e0ca7e
commit 4ffe5feb7a
6 changed files with 124 additions and 174 deletions

View File

@ -13,19 +13,13 @@ const (
Label int64 = 10
MnistDir string = "../../data/mnist"
epochs = 200
batchSize = 256
epochs = 200
)
func runLinear() {
var ds vision.Dataset
ds = vision.LoadMNISTDir(MnistDir)
// fmt.Printf("Train image size: %v\n", ds.TrainImages.MustSize())
// fmt.Printf("Train label size: %v\n", ds.TrainLabels.MustSize())
// fmt.Printf("Test image size: %v\n", ds.TestImages.MustSize())
// fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
device := (gotch.CPU).CInt()
dtype := (gotch.Float).CInt()
@ -33,51 +27,9 @@ func runLinear() {
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
for epoch := 0; epoch < epochs; epoch++ {
/*
* totalSize := ds.TrainImages.MustSize()[0]
* samples := int(totalSize)
* index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
* imagesTs := ds.TrainImages.MustIndexSelect(0, index)
* labelsTs := ds.TrainLabels.MustIndexSelect(0, index)
*
* batches := samples / batchSize
* batchIndex := 0
* var loss ts.Tensor
* for i := 0; i < batches; i++ {
* start := batchIndex * batchSize
* size := batchSize
* if samples-start < batchSize {
* // size = samples - start
* break
* }
* batchIndex += 1
*
* // Indexing
* narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
* bImages := ds.TrainImages.Idx(narrowIndex)
* bLabels := ds.TrainLabels.Idx(narrowIndex)
* // bImages := imagesTs.Idx(narrowIndex)
* // bLabels := labelsTs.Idx(narrowIndex)
*
* logits := bImages.MustMm(ws).MustAdd(bs)
* loss = logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels).MustSetRequiresGrad(true)
*
* ws.ZeroGrad()
* bs.ZeroGrad()
* loss.MustBackward()
*
* ts.NoGrad(func() {
* ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
* bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
* })
* }
*
* imagesTs.MustDrop()
* labelsTs.MustDrop()
* */
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
ws.ZeroGrad()
bs.ZeroGrad()
@ -91,7 +43,7 @@ func runLinear() {
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0})
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100)
}
}

View File

@ -40,12 +40,22 @@ func netInit(vs nn.Path) ts.Module {
return n
}
// train performs one optimisation step and prints the metrics for epoch.
//
// The cross-entropy loss of m's forward pass on trainX against trainY is
// handed to opt.BackwardStep. Afterwards m is run on testX, its accuracy
// against testY is read, and a single line with the epoch, the loss and the
// test accuracy (as a percentage) is written to standard output.
func train(trainX, trainY, testX, testY ts.Tensor, m ts.Module, opt nn.Optimizer, epoch int) {
	trainLogits := m.Forward(trainX)
	loss := trainLogits.CrossEntropyForLogits(trainY)
	opt.BackwardStep(loss)

	// The accuracy is read before the loss, in the same order as the values
	// are consumed by the report line below.
	testLogits := m.Forward(testX)
	accuracy := testLogits.AccuracyForLogits(testY)
	testAccuracy := accuracy.Values()[0]
	lossValue := loss.Values()[0]

	fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, lossValue, testAccuracy*100)
}
func runNN() {
var ds vision.Dataset
ds = vision.LoadMNISTDir(MnistDirNN)
vs := nn.NewVarStore(gotch.CPU)
net := netInit(vs.Root())
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
if err != nil {
log.Fatal(err)
@ -53,14 +63,8 @@ func runNN() {
for epoch := 0; epoch < epochsNN; epoch++ {
loss := net.Forward(ds.TrainImages).CrossEntropyForLogits(ds.TrainLabels)
train(ds.TrainImages, ds.TrainLabels, ds.TestImages, ds.TestLabels, net, opt, epoch)
opt.BackwardStep(loss)
lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0})
testAccuracy := net.Forward(ds.TestImages).AccuracyForLogits(ds.TestLabels).MustView([]int64{-1}).MustFloat64Value([]int64{0})
fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
fmt.Printf("Loss: %v\n", lossVal)
}
}

View File

@ -10,15 +10,15 @@ import (
// Optimizer is a struct object to run gradient descent.
type Optimizer struct {
opt ts.COptimizer
variables Variables // having embedded sync.Mutex
opt ts.COptimizer
// variables Variables // having embedded sync.Mutex
variablesInOptimizer uint8
config interface{}
}
// OptimizerConfig defines Optimizer configurations. These configs can be used to build optimizer.
type OptimizerConfig interface {
BuildCOpt(lr float64) (retVal ts.COptimizer, err error)
buildCOpt(lr float64) (retVal ts.COptimizer, err error)
// Build builds an optimizer with the specified learning rate handling variables stored in `vs`.
//
@ -35,29 +35,27 @@ type OptimizerConfig interface {
// defaultBuild is `default` Build method for OptimizerConfig interface
func defaultBuild(config OptimizerConfig, vs VarStore, lr float64) (retVal Optimizer, err error) {
opt, err := config.BuildCOpt(lr)
opt, err := config.buildCOpt(lr)
if err != nil {
return retVal, err
}
// vs.variables.mutex.Lock()
// defer vs.variables.mutex.Unlock()
// fmt.Printf("Trainable Variables: \n:%v", len(vs.Variables()))
var parameters []ts.Tensor
for _, v := range vs.Variables() {
parameters = append(parameters, v)
for _, v := range vs.Vars.TrainableVariables {
param := v.MustShallowClone()
parameters = append(parameters, param)
}
// if err = opt.AddParameters(vs.variables.TrainableVariables); err != nil {
if err = opt.AddParameters(parameters); err != nil {
if err = opt.AddParameters(vs.Vars.TrainableVariables); err != nil {
return retVal, err
}
// TODO: should we clone or copy?
return Optimizer{
opt: opt,
variables: vs.variables,
variablesInOptimizer: uint8(len(vs.variables.TrainableVariables)),
opt: opt,
// variables: vs.Vars,
variablesInOptimizer: uint8(len(vs.Vars.TrainableVariables)),
config: config,
}, nil
}
@ -94,7 +92,7 @@ func NewSGDConfig(momentum, dampening, wd float64, nesterov bool) (retVal SGDCon
}
// Implement OptimizerConfig interface for SGDConfig
func (c SGDConfig) BuildCOpt(lr float64) (retVal ts.COptimizer, err error) {
func (c SGDConfig) buildCOpt(lr float64) (retVal ts.COptimizer, err error) {
return ts.Sgd(lr, c.Momentum, c.Dampening, c.Wd, c.Nesterov)
}
@ -130,7 +128,7 @@ func NewAdamConfig(beta1, beta2, wd float64) AdamConfig {
}
// Implement OptimizerConfig interface for AdamConfig
func (c AdamConfig) BuildCOpt(lr float64) (retVal ts.COptimizer, err error) {
func (c AdamConfig) buildCOpt(lr float64) (retVal ts.COptimizer, err error) {
return ts.Adam(lr, c.Beta1, c.Beta2, c.Wd)
}
@ -172,7 +170,7 @@ func NewRMSPropConfig(alpha, eps, wd, momentum float64, centered bool) RMSPropCo
}
// Implement OptimizerConfig interface for RMSPropConfig
func (c RMSPropConfig) BuildCOpt(lr float64) (retVal ts.COptimizer, err error) {
func (c RMSPropConfig) buildCOpt(lr float64) (retVal ts.COptimizer, err error) {
return ts.RmsProp(lr, c.Alpha, c.Eps, c.Wd, c.Momentum, c.Centered)
}
@ -184,15 +182,19 @@ func (c RMSPropConfig) Build(vs VarStore, lr float64) (retVal Optimizer, err err
// ==================
func (opt *Optimizer) addMissingVariables() {
opt.variables.mutex.Lock()
defer opt.variables.mutex.Unlock()
missingVariables := len(opt.variables.TrainableVariables) - int(opt.variablesInOptimizer)
if missingVariables > 0 {
opt.opt.AddParameters(opt.variables.TrainableVariables[opt.variablesInOptimizer:])
opt.variablesInOptimizer = uint8(len(opt.variables.TrainableVariables))
}
// missingVariables := len(opt.variables.TrainableVariables) - int(opt.variablesInOptimizer)
//
// if missingVariables > 0 {
// var tensors []ts.Tensor
// for _, t := range opt.variables.TrainableVariables[opt.variablesInOptimizer:] {
// tensor := t.MustShallowClone()
// tensor.Detach_()
// tensors = append(tensors, tensor)
// }
//
// opt.opt.AddParameters(tensors)
// opt.variablesInOptimizer = uint8(len(opt.variables.TrainableVariables))
// }
}
@ -207,12 +209,12 @@ func (opt *Optimizer) ZeroGrad() {
// Clips gradient value at some specified maximum value.
func (opt *Optimizer) ClipGradValue(max float64) {
opt.variables.mutex.Lock()
defer opt.variables.mutex.Unlock()
// opt.variables.mutex.Lock()
// defer opt.variables.mutex.Unlock()
for _, tensor := range opt.variables.TrainableVariables {
tensor.MustGrad().Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
}
// for _, tensor := range opt.variables.TrainableVariables {
// tensor.MustGrad().Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
// }
}
// Step performs an optimization step, updating the tracked tensors based on their gradients.

View File

@ -26,8 +26,8 @@ type Variables struct {
// VarStore is used to store variables used by one or multiple layers.
// It specifies a SINGLE device where all variables are stored.
type VarStore struct {
device gotch.Device
variables Variables // TODO: should we export this field
device gotch.Device
Vars Variables
}
// Path is variable store with an associated path for variables naming.
@ -52,8 +52,8 @@ func NewVarStore(device gotch.Device) VarStore {
}
return VarStore{
device: device,
variables: variables,
device: device,
Vars: variables,
}
}
@ -70,36 +70,45 @@ func (vs *VarStore) Device() gotch.Device {
// Len returns the number of tensors currently stored on this var-store
func (vs *VarStore) Len() (retVal int) {
vs.variables.mutex.Lock()
defer vs.variables.mutex.Unlock()
retVal = len(vs.variables.NamedVariables)
vs.Vars.mutex.Lock()
defer vs.Vars.mutex.Unlock()
retVal = len(vs.Vars.NamedVariables)
return retVal
}
// IsEmpty returns true if no tensors are currently stored on this var-store
func (vs *VarStore) IsEmpty() (retVal bool) {
vs.variables.mutex.Lock()
defer vs.variables.mutex.Unlock()
retVal = (len(vs.variables.NamedVariables) == 0)
vs.Vars.mutex.Lock()
defer vs.Vars.mutex.Unlock()
retVal = (len(vs.Vars.NamedVariables) == 0)
return retVal
}
// TrainableVariables returns all trainable variables for this var-store.
//
// NOTE(review): this span comes from a diff rendering whose +/- markers were
// stripped. The `vs.variables` statements look like the pre-rename lines and
// the `vs.Vars` statements like their replacements - confirm against the
// repository before relying on either reading.
func (vs *VarStore) TrainableVariables() (retVal []ts.Tensor) {
vs.variables.mutex.Lock()
defer vs.variables.mutex.Unlock()
retVal = vs.variables.TrainableVariables
vs.Vars.mutex.Lock()
defer vs.Vars.mutex.Unlock()
// NOTE(review): retVal is assigned the stored slice here, and the loop below
// then appends the MustShallowClone() of every element to that same slice.
// The result therefore appears to hold each variable twice (the stored tensor
// followed later by its clone), and the append may write into the store's
// backing array if it has spare capacity - confirm whether this assignment
// was meant to be dropped so that only clones are returned.
retVal = vs.Vars.TrainableVariables
for _, t := range vs.Vars.TrainableVariables {
retVal = append(retVal, t.MustShallowClone())
}
return retVal
}
// Variables returns all variables and their names in a map[variable_name]Tensor.
//
// NOTE(review): this span comes from a diff rendering whose +/- markers were
// stripped. The `vs.variables` statements look like the pre-rename lines and
// the `vs.Vars` statements like their replacements - confirm against the
// repository before relying on either reading.
func (vs *VarStore) Variables() (retVal map[string]ts.Tensor) {
vs.variables.mutex.Lock()
defer vs.variables.mutex.Unlock()
// This assignment hands out the store's own NamedVariables map.
retVal = vs.variables.NamedVariables
vs.Vars.mutex.Lock()
defer vs.Vars.mutex.Unlock()
// These lines instead build a fresh map and fill it with the
// MustShallowClone() of every named variable, keyed by the same name, so the
// map returned by them is not the store's own map.
retVal = make(map[string]ts.Tensor, 0)
for k, v := range vs.Vars.NamedVariables {
retVal[k] = v.MustShallowClone()
}
return retVal
}
@ -121,12 +130,12 @@ func (vs *VarStore) Root() (retVal Path) {
// NOTE: Weight values for all the tensors currently stored in the
// var-store get saved in the given file.
func (vs *VarStore) Save(filepath string) (err error) {
vs.variables.mutex.Lock()
defer vs.variables.mutex.Unlock()
vs.Vars.mutex.Lock()
defer vs.Vars.mutex.Unlock()
// Convert map to []NamedTensor
var namedTensors []ts.NamedTensor
for k, v := range vs.variables.NamedVariables {
for k, v := range vs.Vars.NamedVariables {
namedTensors = append(namedTensors, ts.NamedTensor{
Name: k,
Tensor: v,
@ -152,13 +161,13 @@ func (vs *VarStore) Load(filepath string) (err error) {
// Match and in-place copy value (update) from newly loaded tensors
// to existing named tensors if name is matched. Throw error otherwise.
vs.variables.mutex.Lock()
defer vs.variables.mutex.Unlock()
vs.Vars.mutex.Lock()
defer vs.Vars.mutex.Unlock()
for _, namedTs := range namedTensors {
var currTs ts.Tensor
var ok bool
if currTs, ok = vs.variables.NamedVariables[namedTs.Name]; !ok {
if currTs, ok = vs.Vars.NamedVariables[namedTs.Name]; !ok {
err = fmt.Errorf("Cannot find tensor with name: %v in variable store. \n", namedTs.Name)
return err
}
@ -192,13 +201,13 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
// Match and in-place copy value (update) from newly loaded tensors
// to existing named tensors if name is matched. Throw error otherwise.
vs.variables.mutex.Lock()
defer vs.variables.mutex.Unlock()
vs.Vars.mutex.Lock()
defer vs.Vars.mutex.Unlock()
for _, namedTs := range namedTensors {
var currTs ts.Tensor
var ok bool
if currTs, ok = vs.variables.NamedVariables[namedTs.Name]; !ok {
if currTs, ok = vs.Vars.NamedVariables[namedTs.Name]; !ok {
// missing
missingVariables = append(missingVariables, namedTs.Name)
}
@ -217,10 +226,10 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
// Gradients for the variables in this store are not tracked
// anymore.
func (vs *VarStore) Freeze() {
vs.variables.mutex.Lock()
defer vs.variables.mutex.Unlock()
vs.Vars.mutex.Lock()
defer vs.Vars.mutex.Unlock()
for _, v := range vs.variables.TrainableVariables {
for _, v := range vs.Vars.TrainableVariables {
_, err := v.SetRequiresGrad(false)
if err != nil {
log.Fatalf("Freeze() Error: %v\n", err)
@ -232,10 +241,10 @@ func (vs *VarStore) Freeze() {
//
// Gradients for the variables in this store are tracked again.
func (vs *VarStore) Unfreeze() {
vs.variables.mutex.Lock()
defer vs.variables.mutex.Unlock()
vs.Vars.mutex.Lock()
defer vs.Vars.mutex.Unlock()
for _, v := range vs.variables.TrainableVariables {
for _, v := range vs.Vars.TrainableVariables {
_, err := v.SetRequiresGrad(true)
if err != nil {
log.Fatalf("Unfreeze() Error: %v\n", err)
@ -248,22 +257,22 @@ func (vs *VarStore) Unfreeze() {
// All the variables in this var store have to exist with the same
// name in the source var store, otherwise an error is returned.
func (vs *VarStore) Copy(src VarStore) (err error) {
vs.variables.mutex.Lock()
defer vs.variables.mutex.Unlock()
src.variables.mutex.Lock()
defer src.variables.mutex.Unlock()
vs.Vars.mutex.Lock()
defer vs.Vars.mutex.Unlock()
src.Vars.mutex.Lock()
defer src.Vars.mutex.Unlock()
srcNamedVariables := src.variables.NamedVariables
srcNamedVariables := src.Vars.NamedVariables
device := vs.device
for k, _ := range vs.variables.NamedVariables {
for k, _ := range vs.Vars.NamedVariables {
if _, ok := srcNamedVariables[k]; !ok {
err = fmt.Errorf("VarStore copy error: cannot find %v in the source var store.\n", k)
return err
}
}
for k, v := range vs.variables.NamedVariables {
for k, v := range vs.Vars.NamedVariables {
srcTs, _ := srcNamedVariables[k]
srcDevTs, err := srcTs.To(device)
if err != nil {
@ -319,11 +328,11 @@ func (p *Path) getpath(name string) (retVal string) {
func (p *Path) add(name string, newTs ts.Tensor, trainable bool) (retVal ts.Tensor) {
path := p.getpath(name)
p.varstore.variables.mutex.Lock()
defer p.varstore.variables.mutex.Unlock()
p.varstore.Vars.mutex.Lock()
defer p.varstore.Vars.mutex.Unlock()
if _, ok := p.varstore.variables.NamedVariables[path]; ok {
path = fmt.Sprintf("%v__%v", path, len(p.varstore.variables.NamedVariables))
if _, ok := p.varstore.Vars.NamedVariables[path]; ok {
path = fmt.Sprintf("%v__%v", path, len(p.varstore.Vars.NamedVariables))
}
var (
@ -331,19 +340,19 @@ func (p *Path) add(name string, newTs ts.Tensor, trainable bool) (retVal ts.Tens
err error
)
if trainable {
tensor, err = newTs.SetRequiresGrad(true)
tensor, err = newTs.MustShallowClone().SetRequiresGrad(true)
if err != nil {
log.Fatalf("Path 'add' method error: %v\n", err)
}
} else {
tensor = newTs
tensor = newTs.MustShallowClone()
}
if trainable {
p.varstore.variables.TrainableVariables = append(p.varstore.variables.TrainableVariables, tensor)
p.varstore.Vars.TrainableVariables = append(p.varstore.Vars.TrainableVariables, tensor)
}
p.varstore.variables.NamedVariables[path] = tensor
p.varstore.Vars.NamedVariables[path] = tensor
return tensor
}
@ -522,10 +531,10 @@ func (p *Path) VarCopy(name string, t ts.Tensor) (retVal ts.Tensor) {
// Get gets the tensor corresponding to a given name if present.
func (p *Path) Get(name string) (retVal ts.Tensor, err error) {
p.varstore.variables.mutex.Lock()
defer p.varstore.variables.mutex.Unlock()
p.varstore.Vars.mutex.Lock()
defer p.varstore.Vars.mutex.Unlock()
v, ok := p.varstore.variables.NamedVariables[name]
v, ok := p.varstore.Vars.NamedVariables[name]
if !ok {
err = fmt.Errorf("Path - Get method call error: Cannot find variable for name: %v\n", name)
return retVal, err
@ -536,12 +545,12 @@ func (p *Path) Get(name string) (retVal ts.Tensor, err error) {
// Entry gets the entry corresponding to a given name for in-place manipulation.
func (p *Path) Entry(name string) (retVal Entry) {
p.varstore.variables.mutex.Lock()
defer p.varstore.variables.mutex.Unlock()
p.varstore.Vars.mutex.Lock()
defer p.varstore.Vars.mutex.Unlock()
return Entry{
name: name,
variables: &p.varstore.variables,
variables: &p.varstore.Vars,
path: p,
}
}

View File

@ -80,42 +80,22 @@ func (ts Tensor) MustGrad() (retVal Tensor) {
return retVal
}
func (ts Tensor) Detach_() (retVal Tensor, err error) {
func (ts Tensor) Detach_() {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgDetach_(ptr, ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) MustDetach_() (retVal Tensor) {
retVal, err := ts.Detach_()
if err != nil {
if err := TorchErr(); err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) Zero_() (err error) {
func (ts Tensor) Zero_() {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgZero_(ptr, ts.ctensor)
if err = TorchErr(); err != nil {
return err
}
return nil
}
func (ts Tensor) MustZero_() {
err := ts.Zero_()
if err != nil {
if err := TorchErr(); err != nil {
log.Fatal(err)
}
}

View File

@ -10,7 +10,6 @@ import (
"fmt"
"log"
"reflect"
// "strings"
"unsafe"
gotch "github.com/sugarme/gotch"
@ -432,11 +431,8 @@ func (ts Tensor) IsSparse() (retVal bool, err error) {
func (ts Tensor) ZeroGrad() {
grad := ts.MustGrad()
if grad.MustDefined() {
// TODO: can we chain them?
// grad.MustDetach_().MustZero_()
// https://www.calhoun.io/using-functional-options-instead-of-method-chaining-in-go/
detach := grad.MustDetach_()
detach.MustZero_()
grad.Detach_()
grad.Zero_()
}
}
@ -989,3 +985,10 @@ func (r Reduction) ToInt() (retVal int) {
}
return
}
// Values reads one element of the tensor and returns it wrapped in a
// one-element []float64.
//
// The element is taken from a shallow clone of the receiver on which Detach_
// has been called; the clone is viewed with shape [-1] and read at index -1.
//
// NOTE(review): the name suggests that every element is returned, but only a
// single value is read. The clone and the view created here are also not
// released in this function - confirm whether they need to be dropped.
func (ts Tensor) Values() []float64 {
	detached := ts.MustShallowClone()
	detached.Detach_()

	flat := detached.MustView([]int64{-1})
	value := flat.MustFloat64Value([]int64{-1})

	return []float64{value}
}