chore(example/mnist): cleanup. The memory blow-out issue still remains
This commit is contained in:
parent
e0d2e0ca7e
commit
4ffe5feb7a
|
@ -13,19 +13,13 @@ const (
|
|||
Label int64 = 10
|
||||
MnistDir string = "../../data/mnist"
|
||||
|
||||
epochs = 200
|
||||
batchSize = 256
|
||||
epochs = 200
|
||||
)
|
||||
|
||||
func runLinear() {
|
||||
var ds vision.Dataset
|
||||
ds = vision.LoadMNISTDir(MnistDir)
|
||||
|
||||
// fmt.Printf("Train image size: %v\n", ds.TrainImages.MustSize())
|
||||
// fmt.Printf("Train label size: %v\n", ds.TrainLabels.MustSize())
|
||||
// fmt.Printf("Test image size: %v\n", ds.TestImages.MustSize())
|
||||
// fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
|
||||
|
||||
device := (gotch.CPU).CInt()
|
||||
dtype := (gotch.Float).CInt()
|
||||
|
||||
|
@ -33,51 +27,9 @@ func runLinear() {
|
|||
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
|
||||
|
||||
for epoch := 0; epoch < epochs; epoch++ {
|
||||
/*
|
||||
* totalSize := ds.TrainImages.MustSize()[0]
|
||||
* samples := int(totalSize)
|
||||
* index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
|
||||
* imagesTs := ds.TrainImages.MustIndexSelect(0, index)
|
||||
* labelsTs := ds.TrainLabels.MustIndexSelect(0, index)
|
||||
*
|
||||
* batches := samples / batchSize
|
||||
* batchIndex := 0
|
||||
* var loss ts.Tensor
|
||||
* for i := 0; i < batches; i++ {
|
||||
* start := batchIndex * batchSize
|
||||
* size := batchSize
|
||||
* if samples-start < batchSize {
|
||||
* // size = samples - start
|
||||
* break
|
||||
* }
|
||||
* batchIndex += 1
|
||||
*
|
||||
* // Indexing
|
||||
* narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
|
||||
* bImages := ds.TrainImages.Idx(narrowIndex)
|
||||
* bLabels := ds.TrainLabels.Idx(narrowIndex)
|
||||
* // bImages := imagesTs.Idx(narrowIndex)
|
||||
* // bLabels := labelsTs.Idx(narrowIndex)
|
||||
*
|
||||
* logits := bImages.MustMm(ws).MustAdd(bs)
|
||||
* loss = logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels).MustSetRequiresGrad(true)
|
||||
*
|
||||
* ws.ZeroGrad()
|
||||
* bs.ZeroGrad()
|
||||
* loss.MustBackward()
|
||||
*
|
||||
* ts.NoGrad(func() {
|
||||
* ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
* bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
* })
|
||||
* }
|
||||
*
|
||||
* imagesTs.MustDrop()
|
||||
* labelsTs.MustDrop()
|
||||
* */
|
||||
|
||||
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
|
||||
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
|
||||
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
|
||||
|
||||
ws.ZeroGrad()
|
||||
bs.ZeroGrad()
|
||||
|
@ -91,7 +43,7 @@ func runLinear() {
|
|||
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
|
||||
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
|
||||
lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
|
||||
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100)
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,12 +40,22 @@ func netInit(vs nn.Path) ts.Module {
|
|||
return n
|
||||
}
|
||||
|
||||
func train(trainX, trainY, testX, testY ts.Tensor, m ts.Module, opt nn.Optimizer, epoch int) {
|
||||
loss := m.Forward(trainX).CrossEntropyForLogits(trainY)
|
||||
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
testAccuracy := m.Forward(testX).AccuracyForLogits(testY).Values()[0]
|
||||
fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100)
|
||||
|
||||
}
|
||||
|
||||
func runNN() {
|
||||
|
||||
var ds vision.Dataset
|
||||
ds = vision.LoadMNISTDir(MnistDirNN)
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
net := netInit(vs.Root())
|
||||
|
||||
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -53,14 +63,8 @@ func runNN() {
|
|||
|
||||
for epoch := 0; epoch < epochsNN; epoch++ {
|
||||
|
||||
loss := net.Forward(ds.TrainImages).CrossEntropyForLogits(ds.TrainLabels)
|
||||
train(ds.TrainImages, ds.TrainLabels, ds.TestImages, ds.TestLabels, net, opt, epoch)
|
||||
|
||||
opt.BackwardStep(loss)
|
||||
lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
testAccuracy := net.Forward(ds.TestImages).AccuracyForLogits(ds.TestLabels).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
|
||||
|
||||
fmt.Printf("Loss: %v\n", lossVal)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -10,15 +10,15 @@ import (
|
|||
|
||||
// Optimizer is a struct object to run gradient descent.
|
||||
type Optimizer struct {
|
||||
opt ts.COptimizer
|
||||
variables Variables // having embedded sync.Mutex
|
||||
opt ts.COptimizer
|
||||
// variables Variables // having embedded sync.Mutex
|
||||
variablesInOptimizer uint8
|
||||
config interface{}
|
||||
}
|
||||
|
||||
// OptimizerConfig defines Optimizer configurations. These configs can be used to build optimizer.
|
||||
type OptimizerConfig interface {
|
||||
BuildCOpt(lr float64) (retVal ts.COptimizer, err error)
|
||||
buildCOpt(lr float64) (retVal ts.COptimizer, err error)
|
||||
|
||||
// Build builds an optimizer with the specified learning rate handling variables stored in `vs`.
|
||||
//
|
||||
|
@ -35,29 +35,27 @@ type OptimizerConfig interface {
|
|||
// defaultBuild is `default` Build method for OptimizerConfig interface
|
||||
func defaultBuild(config OptimizerConfig, vs VarStore, lr float64) (retVal Optimizer, err error) {
|
||||
|
||||
opt, err := config.BuildCOpt(lr)
|
||||
opt, err := config.buildCOpt(lr)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
// vs.variables.mutex.Lock()
|
||||
// defer vs.variables.mutex.Unlock()
|
||||
|
||||
// fmt.Printf("Trainable Variables: \n:%v", len(vs.Variables()))
|
||||
var parameters []ts.Tensor
|
||||
for _, v := range vs.Variables() {
|
||||
parameters = append(parameters, v)
|
||||
for _, v := range vs.Vars.TrainableVariables {
|
||||
param := v.MustShallowClone()
|
||||
parameters = append(parameters, param)
|
||||
}
|
||||
|
||||
// if err = opt.AddParameters(vs.variables.TrainableVariables); err != nil {
|
||||
if err = opt.AddParameters(parameters); err != nil {
|
||||
if err = opt.AddParameters(vs.Vars.TrainableVariables); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
// TODO: should we clone or copy?
|
||||
|
||||
return Optimizer{
|
||||
opt: opt,
|
||||
variables: vs.variables,
|
||||
variablesInOptimizer: uint8(len(vs.variables.TrainableVariables)),
|
||||
opt: opt,
|
||||
// variables: vs.Vars,
|
||||
variablesInOptimizer: uint8(len(vs.Vars.TrainableVariables)),
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
@ -94,7 +92,7 @@ func NewSGDConfig(momentum, dampening, wd float64, nesterov bool) (retVal SGDCon
|
|||
}
|
||||
|
||||
// Implement OptimizerConfig interface for SGDConfig
|
||||
func (c SGDConfig) BuildCOpt(lr float64) (retVal ts.COptimizer, err error) {
|
||||
func (c SGDConfig) buildCOpt(lr float64) (retVal ts.COptimizer, err error) {
|
||||
return ts.Sgd(lr, c.Momentum, c.Dampening, c.Wd, c.Nesterov)
|
||||
}
|
||||
|
||||
|
@ -130,7 +128,7 @@ func NewAdamConfig(beta1, beta2, wd float64) AdamConfig {
|
|||
}
|
||||
|
||||
// Implement OptimizerConfig interface for AdamConfig
|
||||
func (c AdamConfig) BuildCOpt(lr float64) (retVal ts.COptimizer, err error) {
|
||||
func (c AdamConfig) buildCOpt(lr float64) (retVal ts.COptimizer, err error) {
|
||||
return ts.Adam(lr, c.Beta1, c.Beta2, c.Wd)
|
||||
}
|
||||
|
||||
|
@ -172,7 +170,7 @@ func NewRMSPropConfig(alpha, eps, wd, momentum float64, centered bool) RMSPropCo
|
|||
}
|
||||
|
||||
// Implement OptimizerConfig interface for RMSPropConfig
|
||||
func (c RMSPropConfig) BuildCOpt(lr float64) (retVal ts.COptimizer, err error) {
|
||||
func (c RMSPropConfig) buildCOpt(lr float64) (retVal ts.COptimizer, err error) {
|
||||
return ts.RmsProp(lr, c.Alpha, c.Eps, c.Wd, c.Momentum, c.Centered)
|
||||
}
|
||||
|
||||
|
@ -184,15 +182,19 @@ func (c RMSPropConfig) Build(vs VarStore, lr float64) (retVal Optimizer, err err
|
|||
// ==================
|
||||
func (opt *Optimizer) addMissingVariables() {
|
||||
|
||||
opt.variables.mutex.Lock()
|
||||
defer opt.variables.mutex.Unlock()
|
||||
|
||||
missingVariables := len(opt.variables.TrainableVariables) - int(opt.variablesInOptimizer)
|
||||
|
||||
if missingVariables > 0 {
|
||||
opt.opt.AddParameters(opt.variables.TrainableVariables[opt.variablesInOptimizer:])
|
||||
opt.variablesInOptimizer = uint8(len(opt.variables.TrainableVariables))
|
||||
}
|
||||
// missingVariables := len(opt.variables.TrainableVariables) - int(opt.variablesInOptimizer)
|
||||
//
|
||||
// if missingVariables > 0 {
|
||||
// var tensors []ts.Tensor
|
||||
// for _, t := range opt.variables.TrainableVariables[opt.variablesInOptimizer:] {
|
||||
// tensor := t.MustShallowClone()
|
||||
// tensor.Detach_()
|
||||
// tensors = append(tensors, tensor)
|
||||
// }
|
||||
//
|
||||
// opt.opt.AddParameters(tensors)
|
||||
// opt.variablesInOptimizer = uint8(len(opt.variables.TrainableVariables))
|
||||
// }
|
||||
|
||||
}
|
||||
|
||||
|
@ -207,12 +209,12 @@ func (opt *Optimizer) ZeroGrad() {
|
|||
// Clips gradient value at some specified maximum value.
|
||||
func (opt *Optimizer) ClipGradValue(max float64) {
|
||||
|
||||
opt.variables.mutex.Lock()
|
||||
defer opt.variables.mutex.Unlock()
|
||||
// opt.variables.mutex.Lock()
|
||||
// defer opt.variables.mutex.Unlock()
|
||||
|
||||
for _, tensor := range opt.variables.TrainableVariables {
|
||||
tensor.MustGrad().Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
|
||||
}
|
||||
// for _, tensor := range opt.variables.TrainableVariables {
|
||||
// tensor.MustGrad().Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
|
||||
// }
|
||||
}
|
||||
|
||||
// Step performs an optimization step, updating the tracked tensors based on their gradients.
|
||||
|
|
113
nn/varstore.go
113
nn/varstore.go
|
@ -26,8 +26,8 @@ type Variables struct {
|
|||
// VarStore is used to store variables used by one or multiple layers.
|
||||
// It specifies a SINGLE device where all variables are stored.
|
||||
type VarStore struct {
|
||||
device gotch.Device
|
||||
variables Variables // TODO: should we export this field
|
||||
device gotch.Device
|
||||
Vars Variables
|
||||
}
|
||||
|
||||
// Path is variable store with an associated path for variables naming.
|
||||
|
@ -52,8 +52,8 @@ func NewVarStore(device gotch.Device) VarStore {
|
|||
}
|
||||
|
||||
return VarStore{
|
||||
device: device,
|
||||
variables: variables,
|
||||
device: device,
|
||||
Vars: variables,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -70,36 +70,45 @@ func (vs *VarStore) Device() gotch.Device {
|
|||
|
||||
// Len returns the number of tensors currently stored on this var-store
|
||||
func (vs *VarStore) Len() (retVal int) {
|
||||
vs.variables.mutex.Lock()
|
||||
defer vs.variables.mutex.Unlock()
|
||||
retVal = len(vs.variables.NamedVariables)
|
||||
vs.Vars.mutex.Lock()
|
||||
defer vs.Vars.mutex.Unlock()
|
||||
retVal = len(vs.Vars.NamedVariables)
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// IsEmpty returns true if no tensors are currently stored on this var-store
|
||||
func (vs *VarStore) IsEmpty() (retVal bool) {
|
||||
vs.variables.mutex.Lock()
|
||||
defer vs.variables.mutex.Unlock()
|
||||
retVal = (len(vs.variables.NamedVariables) == 0)
|
||||
vs.Vars.mutex.Lock()
|
||||
defer vs.Vars.mutex.Unlock()
|
||||
retVal = (len(vs.Vars.NamedVariables) == 0)
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// TrainableVariables returns all trainable variables for this var-store
|
||||
func (vs *VarStore) TrainableVariables() (retVal []ts.Tensor) {
|
||||
vs.variables.mutex.Lock()
|
||||
defer vs.variables.mutex.Unlock()
|
||||
retVal = vs.variables.TrainableVariables
|
||||
vs.Vars.mutex.Lock()
|
||||
defer vs.Vars.mutex.Unlock()
|
||||
|
||||
retVal = vs.Vars.TrainableVariables
|
||||
for _, t := range vs.Vars.TrainableVariables {
|
||||
retVal = append(retVal, t.MustShallowClone())
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Variables returns all variables and their names in a map[variable_name]Tensor
|
||||
func (vs *VarStore) Variables() (retVal map[string]ts.Tensor) {
|
||||
vs.variables.mutex.Lock()
|
||||
defer vs.variables.mutex.Unlock()
|
||||
retVal = vs.variables.NamedVariables
|
||||
vs.Vars.mutex.Lock()
|
||||
defer vs.Vars.mutex.Unlock()
|
||||
|
||||
retVal = make(map[string]ts.Tensor, 0)
|
||||
|
||||
for k, v := range vs.Vars.NamedVariables {
|
||||
retVal[k] = v.MustShallowClone()
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
@ -121,12 +130,12 @@ func (vs *VarStore) Root() (retVal Path) {
|
|||
// NOTE: Weight values for all the tensors currently stored in the
|
||||
// var-store get saved in the given file.
|
||||
func (vs *VarStore) Save(filepath string) (err error) {
|
||||
vs.variables.mutex.Lock()
|
||||
defer vs.variables.mutex.Unlock()
|
||||
vs.Vars.mutex.Lock()
|
||||
defer vs.Vars.mutex.Unlock()
|
||||
|
||||
// Convert map to []NamedTensor
|
||||
var namedTensors []ts.NamedTensor
|
||||
for k, v := range vs.variables.NamedVariables {
|
||||
for k, v := range vs.Vars.NamedVariables {
|
||||
namedTensors = append(namedTensors, ts.NamedTensor{
|
||||
Name: k,
|
||||
Tensor: v,
|
||||
|
@ -152,13 +161,13 @@ func (vs *VarStore) Load(filepath string) (err error) {
|
|||
|
||||
// Match and in-place copy value (update) from newly loaded tensors
|
||||
// to existing named tensors if name is matched. Throw error otherwise.
|
||||
vs.variables.mutex.Lock()
|
||||
defer vs.variables.mutex.Unlock()
|
||||
vs.Vars.mutex.Lock()
|
||||
defer vs.Vars.mutex.Unlock()
|
||||
|
||||
for _, namedTs := range namedTensors {
|
||||
var currTs ts.Tensor
|
||||
var ok bool
|
||||
if currTs, ok = vs.variables.NamedVariables[namedTs.Name]; !ok {
|
||||
if currTs, ok = vs.Vars.NamedVariables[namedTs.Name]; !ok {
|
||||
err = fmt.Errorf("Cannot find tensor with name: %v in variable store. \n", namedTs.Name)
|
||||
return err
|
||||
}
|
||||
|
@ -192,13 +201,13 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
|
|||
|
||||
// Match and in-place copy value (update) from newly loaded tensors
|
||||
// to existing named tensors if name is matched. Throw error otherwise.
|
||||
vs.variables.mutex.Lock()
|
||||
defer vs.variables.mutex.Unlock()
|
||||
vs.Vars.mutex.Lock()
|
||||
defer vs.Vars.mutex.Unlock()
|
||||
|
||||
for _, namedTs := range namedTensors {
|
||||
var currTs ts.Tensor
|
||||
var ok bool
|
||||
if currTs, ok = vs.variables.NamedVariables[namedTs.Name]; !ok {
|
||||
if currTs, ok = vs.Vars.NamedVariables[namedTs.Name]; !ok {
|
||||
// missing
|
||||
missingVariables = append(missingVariables, namedTs.Name)
|
||||
}
|
||||
|
@ -217,10 +226,10 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
|
|||
// Gradients for the variables in this store are not tracked
|
||||
// anymore.
|
||||
func (vs *VarStore) Freeze() {
|
||||
vs.variables.mutex.Lock()
|
||||
defer vs.variables.mutex.Unlock()
|
||||
vs.Vars.mutex.Lock()
|
||||
defer vs.Vars.mutex.Unlock()
|
||||
|
||||
for _, v := range vs.variables.TrainableVariables {
|
||||
for _, v := range vs.Vars.TrainableVariables {
|
||||
_, err := v.SetRequiresGrad(false)
|
||||
if err != nil {
|
||||
log.Fatalf("Freeze() Error: %v\n", err)
|
||||
|
@ -232,10 +241,10 @@ func (vs *VarStore) Freeze() {
|
|||
//
|
||||
// Gradients for the variables in this store are tracked again.
|
||||
func (vs *VarStore) Unfreeze() {
|
||||
vs.variables.mutex.Lock()
|
||||
defer vs.variables.mutex.Unlock()
|
||||
vs.Vars.mutex.Lock()
|
||||
defer vs.Vars.mutex.Unlock()
|
||||
|
||||
for _, v := range vs.variables.TrainableVariables {
|
||||
for _, v := range vs.Vars.TrainableVariables {
|
||||
_, err := v.SetRequiresGrad(true)
|
||||
if err != nil {
|
||||
log.Fatalf("Unfreeze() Error: %v\n", err)
|
||||
|
@ -248,22 +257,22 @@ func (vs *VarStore) Unfreeze() {
|
|||
// All the variables in this var store have to exist with the same
|
||||
// name in the source var store, otherwise an error is returned.
|
||||
func (vs *VarStore) Copy(src VarStore) (err error) {
|
||||
vs.variables.mutex.Lock()
|
||||
defer vs.variables.mutex.Unlock()
|
||||
src.variables.mutex.Lock()
|
||||
defer src.variables.mutex.Unlock()
|
||||
vs.Vars.mutex.Lock()
|
||||
defer vs.Vars.mutex.Unlock()
|
||||
src.Vars.mutex.Lock()
|
||||
defer src.Vars.mutex.Unlock()
|
||||
|
||||
srcNamedVariables := src.variables.NamedVariables
|
||||
srcNamedVariables := src.Vars.NamedVariables
|
||||
device := vs.device
|
||||
|
||||
for k, _ := range vs.variables.NamedVariables {
|
||||
for k, _ := range vs.Vars.NamedVariables {
|
||||
if _, ok := srcNamedVariables[k]; !ok {
|
||||
err = fmt.Errorf("VarStore copy error: cannot find %v in the source var store.\n", k)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range vs.variables.NamedVariables {
|
||||
for k, v := range vs.Vars.NamedVariables {
|
||||
srcTs, _ := srcNamedVariables[k]
|
||||
srcDevTs, err := srcTs.To(device)
|
||||
if err != nil {
|
||||
|
@ -319,11 +328,11 @@ func (p *Path) getpath(name string) (retVal string) {
|
|||
func (p *Path) add(name string, newTs ts.Tensor, trainable bool) (retVal ts.Tensor) {
|
||||
path := p.getpath(name)
|
||||
|
||||
p.varstore.variables.mutex.Lock()
|
||||
defer p.varstore.variables.mutex.Unlock()
|
||||
p.varstore.Vars.mutex.Lock()
|
||||
defer p.varstore.Vars.mutex.Unlock()
|
||||
|
||||
if _, ok := p.varstore.variables.NamedVariables[path]; ok {
|
||||
path = fmt.Sprintf("%v__%v", path, len(p.varstore.variables.NamedVariables))
|
||||
if _, ok := p.varstore.Vars.NamedVariables[path]; ok {
|
||||
path = fmt.Sprintf("%v__%v", path, len(p.varstore.Vars.NamedVariables))
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -331,19 +340,19 @@ func (p *Path) add(name string, newTs ts.Tensor, trainable bool) (retVal ts.Tens
|
|||
err error
|
||||
)
|
||||
if trainable {
|
||||
tensor, err = newTs.SetRequiresGrad(true)
|
||||
tensor, err = newTs.MustShallowClone().SetRequiresGrad(true)
|
||||
if err != nil {
|
||||
log.Fatalf("Path 'add' method error: %v\n", err)
|
||||
}
|
||||
} else {
|
||||
tensor = newTs
|
||||
tensor = newTs.MustShallowClone()
|
||||
}
|
||||
|
||||
if trainable {
|
||||
p.varstore.variables.TrainableVariables = append(p.varstore.variables.TrainableVariables, tensor)
|
||||
p.varstore.Vars.TrainableVariables = append(p.varstore.Vars.TrainableVariables, tensor)
|
||||
}
|
||||
|
||||
p.varstore.variables.NamedVariables[path] = tensor
|
||||
p.varstore.Vars.NamedVariables[path] = tensor
|
||||
|
||||
return tensor
|
||||
}
|
||||
|
@ -522,10 +531,10 @@ func (p *Path) VarCopy(name string, t ts.Tensor) (retVal ts.Tensor) {
|
|||
// Get gets the tensor corresponding to a given name if present.
|
||||
func (p *Path) Get(name string) (retVal ts.Tensor, err error) {
|
||||
|
||||
p.varstore.variables.mutex.Lock()
|
||||
defer p.varstore.variables.mutex.Unlock()
|
||||
p.varstore.Vars.mutex.Lock()
|
||||
defer p.varstore.Vars.mutex.Unlock()
|
||||
|
||||
v, ok := p.varstore.variables.NamedVariables[name]
|
||||
v, ok := p.varstore.Vars.NamedVariables[name]
|
||||
if !ok {
|
||||
err = fmt.Errorf("Path - Get method call error: Cannot find variable for name: %v\n", name)
|
||||
return retVal, err
|
||||
|
@ -536,12 +545,12 @@ func (p *Path) Get(name string) (retVal ts.Tensor, err error) {
|
|||
|
||||
// Entry gets the entry corresponding to a given name for in-place manipulation.
|
||||
func (p *Path) Entry(name string) (retVal Entry) {
|
||||
p.varstore.variables.mutex.Lock()
|
||||
defer p.varstore.variables.mutex.Unlock()
|
||||
p.varstore.Vars.mutex.Lock()
|
||||
defer p.varstore.Vars.mutex.Unlock()
|
||||
|
||||
return Entry{
|
||||
name: name,
|
||||
variables: &p.varstore.variables,
|
||||
variables: &p.varstore.Vars,
|
||||
path: p,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -80,42 +80,22 @@ func (ts Tensor) MustGrad() (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Detach_() (retVal Tensor, err error) {
|
||||
func (ts Tensor) Detach_() {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgDetach_(ptr, ts.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustDetach_() (retVal Tensor) {
|
||||
retVal, err := ts.Detach_()
|
||||
if err != nil {
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Zero_() (err error) {
|
||||
func (ts Tensor) Zero_() {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgZero_(ptr, ts.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustZero_() {
|
||||
err := ts.Zero_()
|
||||
if err != nil {
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
// "strings"
|
||||
"unsafe"
|
||||
|
||||
gotch "github.com/sugarme/gotch"
|
||||
|
@ -432,11 +431,8 @@ func (ts Tensor) IsSparse() (retVal bool, err error) {
|
|||
func (ts Tensor) ZeroGrad() {
|
||||
grad := ts.MustGrad()
|
||||
if grad.MustDefined() {
|
||||
// TODO: can we chain them?
|
||||
// grad.MustDetach_().MustZero_()
|
||||
// https://www.calhoun.io/using-functional-options-instead-of-method-chaining-in-go/
|
||||
detach := grad.MustDetach_()
|
||||
detach.MustZero_()
|
||||
grad.Detach_()
|
||||
grad.Zero_()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -989,3 +985,10 @@ func (r Reduction) ToInt() (retVal int) {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Values returns values of tensor in a slice of float64.
|
||||
func (ts Tensor) Values() []float64 {
|
||||
clone := ts.MustShallowClone()
|
||||
clone.Detach_()
|
||||
return []float64{clone.MustView([]int64{-1}).MustFloat64Value([]int64{-1})}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user