chore(example/mnist): cleanup. Still the memory blow-out issue
parent e0d2e0ca7e
commit 4ffe5feb7a
@@ -13,19 +13,13 @@ const (
     Label    int64  = 10
     MnistDir string = "../../data/mnist"
 
     epochs = 200
-    batchSize = 256
 )
 
 func runLinear() {
     var ds vision.Dataset
     ds = vision.LoadMNISTDir(MnistDir)
 
-    // fmt.Printf("Train image size: %v\n", ds.TrainImages.MustSize())
-    // fmt.Printf("Train label size: %v\n", ds.TrainLabels.MustSize())
-    // fmt.Printf("Test image size: %v\n", ds.TestImages.MustSize())
-    // fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
-
     device := (gotch.CPU).CInt()
     dtype := (gotch.Float).CInt()
 
@@ -33,51 +27,9 @@ func runLinear() {
     bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
 
     for epoch := 0; epoch < epochs; epoch++ {
-        /*
-         * totalSize := ds.TrainImages.MustSize()[0]
-         * samples := int(totalSize)
-         * index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
-         * imagesTs := ds.TrainImages.MustIndexSelect(0, index)
-         * labelsTs := ds.TrainLabels.MustIndexSelect(0, index)
-         *
-         * batches := samples / batchSize
-         * batchIndex := 0
-         * var loss ts.Tensor
-         * for i := 0; i < batches; i++ {
-         *   start := batchIndex * batchSize
-         *   size := batchSize
-         *   if samples-start < batchSize {
-         *     // size = samples - start
-         *     break
-         *   }
-         *   batchIndex += 1
-         *
-         *   // Indexing
-         *   narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
-         *   bImages := ds.TrainImages.Idx(narrowIndex)
-         *   bLabels := ds.TrainLabels.Idx(narrowIndex)
-         *   // bImages := imagesTs.Idx(narrowIndex)
-         *   // bLabels := labelsTs.Idx(narrowIndex)
-         *
-         *   logits := bImages.MustMm(ws).MustAdd(bs)
-         *   loss = logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels).MustSetRequiresGrad(true)
-         *
-         *   ws.ZeroGrad()
-         *   bs.ZeroGrad()
-         *   loss.MustBackward()
-         *
-         *   ts.NoGrad(func() {
-         *     ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
-         *     bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
-         *   })
-         * }
-         *
-         * imagesTs.MustDrop()
-         * labelsTs.MustDrop()
-         * */
 
         logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
-        loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
+        loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
 
         ws.ZeroGrad()
         bs.ZeroGrad()
@@ -91,7 +43,7 @@ func runLinear() {
         testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
         testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
 
-        lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0})
-        fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
+        fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100)
     }
 }
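Note: on the memory blow-out named in the commit title — every chain like MustMm(...).MustAdd(...) above allocates new C-side tensors that Go's garbage collector never reclaims, so a 200-epoch loop accumulates them. A minimal sketch of one step that frees its intermediates explicitly, using only calls already present in this diff (MustDrop appears on imagesTs/labelsTs in the deleted block); an illustration, not part of the commit:

    // Sketch: one full-batch gradient step that drops its intermediates.
    // Assumes the ws, bs, ds, dtype values defined in runLinear above.
    logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
    loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)

    ws.ZeroGrad()
    bs.ZeroGrad()
    loss.MustBackward()

    ts.NoGrad(func() {
        ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
        bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
    })

    // Free the per-iteration tensors once the step is done.
    logits.MustDrop()
    loss.MustDrop()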
@@ -40,12 +40,22 @@ func netInit(vs nn.Path) ts.Module {
     return n
 }
 
+func train(trainX, trainY, testX, testY ts.Tensor, m ts.Module, opt nn.Optimizer, epoch int) {
+    loss := m.Forward(trainX).CrossEntropyForLogits(trainY)
+
+    opt.BackwardStep(loss)
+
+    testAccuracy := m.Forward(testX).AccuracyForLogits(testY).Values()[0]
+    fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100)
+
+}
+
 func runNN() {
     var ds vision.Dataset
     ds = vision.LoadMNISTDir(MnistDirNN)
     vs := nn.NewVarStore(gotch.CPU)
     net := netInit(vs.Root())
 
     opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
     if err != nil {
         log.Fatal(err)
@@ -53,14 +63,8 @@ func runNN() {
 
     for epoch := 0; epoch < epochsNN; epoch++ {
 
-        loss := net.Forward(ds.TrainImages).CrossEntropyForLogits(ds.TrainLabels)
+        train(ds.TrainImages, ds.TrainLabels, ds.TestImages, ds.TestLabels, net, opt, epoch)
 
-        opt.BackwardStep(loss)
-        lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0})
-        testAccuracy := net.Forward(ds.TestImages).AccuracyForLogits(ds.TestLabels).MustView([]int64{-1}).MustFloat64Value([]int64{0})
-        fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
-
-        fmt.Printf("Loss: %v\n", lossVal)
     }
 
 }
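Note: the per-epoch body is now factored into train, but each call still allocates forward-pass tensors on the C heap that nothing releases, which is presumably the remaining blow-out. A hedged sketch of a variant that drops its intermediates; trainWithDrop is a hypothetical name, and it assumes MustDrop (used in the deleted mini-batch loop above) frees the underlying C tensor:

    // Hypothetical variant of the new train helper that explicitly frees
    // the per-epoch intermediates. Not part of this commit.
    func trainWithDrop(trainX, trainY, testX, testY ts.Tensor, m ts.Module, opt nn.Optimizer, epoch int) {
        logits := m.Forward(trainX)
        loss := logits.CrossEntropyForLogits(trainY)

        opt.BackwardStep(loss)

        testLogits := m.Forward(testX)
        acc := testLogits.AccuracyForLogits(testY)

        fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n",
            epoch, loss.Values()[0], acc.Values()[0]*100)

        // Free the C-side tensors created this epoch.
        logits.MustDrop()
        loss.MustDrop()
        testLogits.MustDrop()
        acc.MustDrop()
    }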
@@ -10,15 +10,15 @@ import (
 
 // Optimizer is a struct object to run gradient descent.
 type Optimizer struct {
     opt ts.COptimizer
-    variables Variables // having embedded sync.Mutex
+    // variables Variables // having embedded sync.Mutex
     variablesInOptimizer uint8
     config interface{}
 }
 
 // OptimizerConfig defines Optimizer configurations. These configs can be used to build optimizer.
 type OptimizerConfig interface {
-    BuildCOpt(lr float64) (retVal ts.COptimizer, err error)
+    buildCOpt(lr float64) (retVal ts.COptimizer, err error)
 
     // Build builds an optimizer with the specified learning rate handling variables stored in `vs`.
     //
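Note: renaming BuildCOpt to buildCOpt makes the raw C-optimizer constructor package-private; consumers go through the exported Build, as the MNIST example already does. A minimal usage sketch (the 0.001 learning rate is an arbitrary example value):

    vs := nn.NewVarStore(gotch.CPU)
    opt, err := nn.DefaultAdamConfig().Build(vs, 0.001)
    if err != nil {
        log.Fatal(err)
    }
    // loss: some scalar ts.Tensor produced by a forward pass.
    opt.BackwardStep(loss)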
@@ -35,29 +35,27 @@ type OptimizerConfig interface {
 // defaultBuild is `default` Build method for OptimizerConfig interface
 func defaultBuild(config OptimizerConfig, vs VarStore, lr float64) (retVal Optimizer, err error) {
 
-    opt, err := config.BuildCOpt(lr)
+    opt, err := config.buildCOpt(lr)
     if err != nil {
         return retVal, err
     }
 
-    // vs.variables.mutex.Lock()
-    // defer vs.variables.mutex.Unlock()
 
-    // fmt.Printf("Trainable Variables: \n:%v", len(vs.Variables()))
     var parameters []ts.Tensor
-    for _, v := range vs.Variables() {
-        parameters = append(parameters, v)
+    for _, v := range vs.Vars.TrainableVariables {
+        param := v.MustShallowClone()
+        parameters = append(parameters, param)
     }
 
-    // if err = opt.AddParameters(vs.variables.TrainableVariables); err != nil {
-    if err = opt.AddParameters(parameters); err != nil {
+    if err = opt.AddParameters(vs.Vars.TrainableVariables); err != nil {
         return retVal, err
     }
 
+    // TODO: should we clone or copy?
+
     return Optimizer{
         opt: opt,
-        variables: vs.variables,
-        variablesInOptimizer: uint8(len(vs.variables.TrainableVariables)),
+        // variables: vs.Vars,
+        variablesInOptimizer: uint8(len(vs.Vars.TrainableVariables)),
         config: config,
     }, nil
 }
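Note: defaultBuild now shallow-clones each trainable variable, yet the clones collected in parameters end up unused because AddParameters is handed vs.Vars.TrainableVariables directly; the new TODO records the open clone-or-copy question. Assuming MustShallowClone behaves like libtorch's shallow clone (a new tensor handle over the same storage), the clone-and-register variant would look like this sketch:

    // Sketch only: register the clones instead of the stored tensors.
    var parameters []ts.Tensor
    for _, v := range vs.Vars.TrainableVariables {
        parameters = append(parameters, v.MustShallowClone())
    }
    if err = opt.AddParameters(parameters); err != nil {
        return retVal, err
    }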
@@ -94,7 +92,7 @@ func NewSGDConfig(momentum, dampening, wd float64, nesterov bool) (retVal SGDCon
 }
 
 // Implement OptimizerConfig interface for SGDConfig
-func (c SGDConfig) BuildCOpt(lr float64) (retVal ts.COptimizer, err error) {
+func (c SGDConfig) buildCOpt(lr float64) (retVal ts.COptimizer, err error) {
     return ts.Sgd(lr, c.Momentum, c.Dampening, c.Wd, c.Nesterov)
 }
 
@@ -130,7 +128,7 @@ func NewAdamConfig(beta1, beta2, wd float64) AdamConfig {
 }
 
 // Implement OptimizerConfig interface for AdamConfig
-func (c AdamConfig) BuildCOpt(lr float64) (retVal ts.COptimizer, err error) {
+func (c AdamConfig) buildCOpt(lr float64) (retVal ts.COptimizer, err error) {
     return ts.Adam(lr, c.Beta1, c.Beta2, c.Wd)
 }
 
@@ -172,7 +170,7 @@ func NewRMSPropConfig(alpha, eps, wd, momentum float64, centered bool) RMSPropCo
 }
 
 // Implement OptimizerConfig interface for RMSPropConfig
-func (c RMSPropConfig) BuildCOpt(lr float64) (retVal ts.COptimizer, err error) {
+func (c RMSPropConfig) buildCOpt(lr float64) (retVal ts.COptimizer, err error) {
     return ts.RmsProp(lr, c.Alpha, c.Eps, c.Wd, c.Momentum, c.Centered)
 }
 
@@ -184,15 +182,19 @@ func (c RMSPropConfig) Build(vs VarStore, lr float64) (retVal Optimizer, err err
 // ==================
 func (opt *Optimizer) addMissingVariables() {
 
-    opt.variables.mutex.Lock()
-    defer opt.variables.mutex.Unlock()
-
-    missingVariables := len(opt.variables.TrainableVariables) - int(opt.variablesInOptimizer)
-
-    if missingVariables > 0 {
-        opt.opt.AddParameters(opt.variables.TrainableVariables[opt.variablesInOptimizer:])
-        opt.variablesInOptimizer = uint8(len(opt.variables.TrainableVariables))
-    }
+    // missingVariables := len(opt.variables.TrainableVariables) - int(opt.variablesInOptimizer)
+    //
+    // if missingVariables > 0 {
+    //     var tensors []ts.Tensor
+    //     for _, t := range opt.variables.TrainableVariables[opt.variablesInOptimizer:] {
+    //         tensor := t.MustShallowClone()
+    //         tensor.Detach_()
+    //         tensors = append(tensors, tensor)
+    //     }
+    //
+    //     opt.opt.AddParameters(tensors)
+    //     opt.variablesInOptimizer = uint8(len(opt.variables.TrainableVariables))
+    // }
 
 }
@@ -207,12 +209,12 @@ func (opt *Optimizer) ZeroGrad() {
 // Clips gradient value at some specified maximum value.
 func (opt *Optimizer) ClipGradValue(max float64) {
 
-    opt.variables.mutex.Lock()
-    defer opt.variables.mutex.Unlock()
+    // opt.variables.mutex.Lock()
+    // defer opt.variables.mutex.Unlock()
 
-    for _, tensor := range opt.variables.TrainableVariables {
-        tensor.MustGrad().Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
-    }
+    // for _, tensor := range opt.variables.TrainableVariables {
+    //     tensor.MustGrad().Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
+    // }
 }
 
 // Step performs an optimization step, updating the tracked tensors based on their gradients.
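Note: with the variables field commented out of Optimizer, both addMissingVariables and ClipGradValue become no-ops, so any caller relying on gradient clipping silently loses it. If the var store is available at the call site, a hypothetical replacement (clipGradValue is not part of this commit) could clip directly against the newly exported Vars:

    // Hypothetical helper, sketched from the commented-out body above.
    func clipGradValue(vs nn.VarStore, max float64) {
        for _, tensor := range vs.Vars.TrainableVariables {
            tensor.MustGrad().Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
        }
    }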
nn/varstore.go | 113
@@ -26,8 +26,8 @@ type Variables struct {
 // VarStore is used to store variables used by one or multiple layers.
 // It specifies a SINGLE device where all variables are stored.
 type VarStore struct {
     device gotch.Device
-    variables Variables // TODO: should we export this field
+    Vars Variables
 }
 
 // Path is variable store with an associated path for variables naming.
 
@@ -52,8 +52,8 @@ func NewVarStore(device gotch.Device) VarStore {
     }
 
     return VarStore{
         device: device,
-        variables: variables,
+        Vars: variables,
     }
 }
 
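Note: exporting the field as Vars lets code outside the nn package (and defaultBuild above) reach the raw slices, at the cost of encapsulation: the mutex inside Variables stays unexported, so external callers read TrainableVariables without the lock that VarStore's own methods take. A sketch of the access pattern this commit enables, from a package importing nn:

    vs := nn.NewVarStore(gotch.CPU)
    // Direct, unsynchronized access to the stored slices:
    fmt.Printf("trainable tensors: %v\n", len(vs.Vars.TrainableVariables))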
@@ -70,36 +70,45 @@ func (vs *VarStore) Device() gotch.Device {
 
 // Len returns the number of tensors currently stored on this var-store
 func (vs *VarStore) Len() (retVal int) {
-    vs.variables.mutex.Lock()
-    defer vs.variables.mutex.Unlock()
-    retVal = len(vs.variables.NamedVariables)
+    vs.Vars.mutex.Lock()
+    defer vs.Vars.mutex.Unlock()
+    retVal = len(vs.Vars.NamedVariables)
 
     return retVal
 }
 
 // IsEmpty returns true if no tensors are currently stored on this var-store
 func (vs *VarStore) IsEmpty() (retVal bool) {
-    vs.variables.mutex.Lock()
-    defer vs.variables.mutex.Unlock()
-    retVal = (len(vs.variables.NamedVariables) == 0)
+    vs.Vars.mutex.Lock()
+    defer vs.Vars.mutex.Unlock()
+    retVal = (len(vs.Vars.NamedVariables) == 0)
 
     return retVal
 }
 
 // TrainableVariabless returns all trainable variables for this var-store
 func (vs *VarStore) TrainableVariables() (retVal []ts.Tensor) {
-    vs.variables.mutex.Lock()
-    defer vs.variables.mutex.Unlock()
-    retVal = vs.variables.TrainableVariables
+    vs.Vars.mutex.Lock()
+    defer vs.Vars.mutex.Unlock()
+
+    retVal = vs.Vars.TrainableVariables
+    for _, t := range vs.Vars.TrainableVariables {
+        retVal = append(retVal, t.MustShallowClone())
+    }
 
     return retVal
 }
 
 // Variables returns all variables and their names in a map[variable_name]Tensor
 func (vs *VarStore) Variables() (retVal map[string]ts.Tensor) {
-    vs.variables.mutex.Lock()
-    defer vs.variables.mutex.Unlock()
-    retVal = vs.variables.NamedVariables
+    vs.Vars.mutex.Lock()
+    defer vs.Vars.mutex.Unlock()
+
+    retVal = make(map[string]ts.Tensor, 0)
+
+    for k, v := range vs.Vars.NamedVariables {
+        retVal[k] = v.MustShallowClone()
+    }
 
     return retVal
 }
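Note: one thing to flag in the new TrainableVariables — retVal starts as the live vs.Vars.TrainableVariables slice and then appends a clone of every element, so callers receive the originals followed by clones (2N tensors), and the appends may write into spare capacity shared with the store's slice. If the intent is to hand out clones only, the likely fix is this sketch (an assumption about intent, not part of this commit):

    // Sketch: return only shallow clones, leaving the stored slice untouched.
    retVal = make([]ts.Tensor, 0, len(vs.Vars.TrainableVariables))
    for _, t := range vs.Vars.TrainableVariables {
        retVal = append(retVal, t.MustShallowClone())
    }
    return retVal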
@@ -121,12 +130,12 @@ func (vs *VarStore) Root() (retVal Path) {
 // NOTE: Weight values for all the tensors currently stored in the
 // var-store gets saved in the given file.
 func (vs *VarStore) Save(filepath string) (err error) {
-    vs.variables.mutex.Lock()
-    defer vs.variables.mutex.Unlock()
+    vs.Vars.mutex.Lock()
+    defer vs.Vars.mutex.Unlock()
 
     // Convert map to []NamedTensor
     var namedTensors []ts.NamedTensor
-    for k, v := range vs.variables.NamedVariables {
+    for k, v := range vs.Vars.NamedVariables {
         namedTensors = append(namedTensors, ts.NamedTensor{
             Name: k,
             Tensor: v,
 
@@ -152,13 +161,13 @@ func (vs *VarStore) Load(filepath string) (err error) {
 
     // Match and in-place copy value (update) from newly loaded tensors
     // to existing named tensors if name is matched. Throw error otherwise.
-    vs.variables.mutex.Lock()
-    defer vs.variables.mutex.Unlock()
+    vs.Vars.mutex.Lock()
+    defer vs.Vars.mutex.Unlock()
 
     for _, namedTs := range namedTensors {
         var currTs ts.Tensor
         var ok bool
-        if currTs, ok = vs.variables.NamedVariables[namedTs.Name]; !ok {
+        if currTs, ok = vs.Vars.NamedVariables[namedTs.Name]; !ok {
             err = fmt.Errorf("Cannot find tensor with name: %v in variable store. \n", namedTs.Name)
             return err
         }
 
@@ -192,13 +201,13 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
 
     // Match and in-place copy value (update) from newly loaded tensors
     // to existing named tensors if name is matched. Throw error otherwise.
-    vs.variables.mutex.Lock()
-    defer vs.variables.mutex.Unlock()
+    vs.Vars.mutex.Lock()
+    defer vs.Vars.mutex.Unlock()
 
     for _, namedTs := range namedTensors {
         var currTs ts.Tensor
         var ok bool
-        if currTs, ok = vs.variables.NamedVariables[namedTs.Name]; !ok {
+        if currTs, ok = vs.Vars.NamedVariables[namedTs.Name]; !ok {
             // missing
             missingVariables = append(missingVariables, namedTs.Name)
         }
 
@@ -217,10 +226,10 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
 // Gradients for the variables in this store are not tracked
 // anymore.
 func (vs *VarStore) Freeze() {
-    vs.variables.mutex.Lock()
-    defer vs.variables.mutex.Unlock()
+    vs.Vars.mutex.Lock()
+    defer vs.Vars.mutex.Unlock()
 
-    for _, v := range vs.variables.TrainableVariables {
+    for _, v := range vs.Vars.TrainableVariables {
         _, err := v.SetRequiresGrad(false)
         if err != nil {
             log.Fatalf("Freeze() Error: %v\n", err)
 
@@ -232,10 +241,10 @@ func (vs *VarStore) Freeze() {
 //
 // Gradients for the variables in this store are tracked again.
 func (vs *VarStore) Unfreeze() {
-    vs.variables.mutex.Lock()
-    defer vs.variables.mutex.Unlock()
+    vs.Vars.mutex.Lock()
+    defer vs.Vars.mutex.Unlock()
 
-    for _, v := range vs.variables.TrainableVariables {
+    for _, v := range vs.Vars.TrainableVariables {
         _, err := v.SetRequiresGrad(true)
         if err != nil {
             log.Fatalf("Unfreeze() Error: %v\n", err)
 
@@ -248,22 +257,22 @@ func (vs *VarStore) Unfreeze() {
 // All the variables in this var store have to exist with the same
 // name in the source var store, otherwise an error is returned.
 func (vs *VarStore) Copy(src VarStore) (err error) {
-    vs.variables.mutex.Lock()
-    defer vs.variables.mutex.Unlock()
-    src.variables.mutex.Lock()
-    defer src.variables.mutex.Unlock()
+    vs.Vars.mutex.Lock()
+    defer vs.Vars.mutex.Unlock()
+    src.Vars.mutex.Lock()
+    defer src.Vars.mutex.Unlock()
 
-    srcNamedVariables := src.variables.NamedVariables
+    srcNamedVariables := src.Vars.NamedVariables
     device := vs.device
 
-    for k, _ := range vs.variables.NamedVariables {
+    for k, _ := range vs.Vars.NamedVariables {
         if _, ok := srcNamedVariables[k]; !ok {
             err = fmt.Errorf("VarStore copy error: cannot find %v in the source var store.\n", k)
             return err
         }
     }
 
-    for k, v := range vs.variables.NamedVariables {
+    for k, v := range vs.Vars.NamedVariables {
         srcTs, _ := srcNamedVariables[k]
         srcDevTs, err := srcTs.To(device)
         if err != nil {
@@ -319,11 +328,11 @@ func (p *Path) getpath(name string) (retVal string) {
 func (p *Path) add(name string, newTs ts.Tensor, trainable bool) (retVal ts.Tensor) {
     path := p.getpath(name)
 
-    p.varstore.variables.mutex.Lock()
-    defer p.varstore.variables.mutex.Unlock()
+    p.varstore.Vars.mutex.Lock()
+    defer p.varstore.Vars.mutex.Unlock()
 
-    if _, ok := p.varstore.variables.NamedVariables[path]; ok {
-        path = fmt.Sprintf("%v__%v", path, len(p.varstore.variables.NamedVariables))
+    if _, ok := p.varstore.Vars.NamedVariables[path]; ok {
+        path = fmt.Sprintf("%v__%v", path, len(p.varstore.Vars.NamedVariables))
     }
 
     var (
 
@@ -331,19 +340,19 @@ func (p *Path) add(name string, newTs ts.Tensor, trainable bool) (retVal ts.Tens
         err error
     )
     if trainable {
-        tensor, err = newTs.SetRequiresGrad(true)
+        tensor, err = newTs.MustShallowClone().SetRequiresGrad(true)
         if err != nil {
             log.Fatalf("Path 'add' method error: %v\n", err)
         }
     } else {
-        tensor = newTs
+        tensor = newTs.MustShallowClone()
     }
 
     if trainable {
-        p.varstore.variables.TrainableVariables = append(p.varstore.variables.TrainableVariables, tensor)
+        p.varstore.Vars.TrainableVariables = append(p.varstore.Vars.TrainableVariables, tensor)
     }
 
-    p.varstore.variables.NamedVariables[path] = tensor
+    p.varstore.Vars.NamedVariables[path] = tensor
 
     return tensor
 }
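Note: Path.add now stores a shallow clone rather than the caller's handle, so the store and the caller hold distinct C-side handles over (presumably) the same storage: in-place updates through one remain visible through the other, but every handle is an extra C allocation that must eventually be freed — another candidate for the lingering blow-out. Sketch of the semantics, inside package nn; delta is a hypothetical tensor:

    // "stored" and "newTs" are distinct handles after this commit.
    stored := p.add("weight", newTs, true)
    // An in-place update through the stored handle...
    stored.MustAdd_(delta)
    // ...should be observable through newTs too, since storage is shared.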
@@ -522,10 +531,10 @@ func (p *Path) VarCopy(name string, t ts.Tensor) (retVal ts.Tensor) {
 // Get gets the tensor corresponding to a given name if present.
 func (p *Path) Get(name string) (retVal ts.Tensor, err error) {
 
-    p.varstore.variables.mutex.Lock()
-    defer p.varstore.variables.mutex.Unlock()
+    p.varstore.Vars.mutex.Lock()
+    defer p.varstore.Vars.mutex.Unlock()
 
-    v, ok := p.varstore.variables.NamedVariables[name]
+    v, ok := p.varstore.Vars.NamedVariables[name]
     if !ok {
         err = fmt.Errorf("Path - Get method call error: Cannot find variable for name: %v\n", name)
         return retVal, err
 
@@ -536,12 +545,12 @@ func (p *Path) Get(name string) (retVal ts.Tensor, err error) {
 
 // Entry gets the entry corresponding to a given name for in-place manipulation.
 func (p *Path) Entry(name string) (retVal Entry) {
-    p.varstore.variables.mutex.Lock()
-    defer p.varstore.variables.mutex.Unlock()
+    p.varstore.Vars.mutex.Lock()
+    defer p.varstore.Vars.mutex.Unlock()
 
     return Entry{
         name: name,
-        variables: &p.varstore.variables,
+        variables: &p.varstore.Vars,
         path: p,
     }
 }
@@ -80,42 +80,22 @@ func (ts Tensor) MustGrad() (retVal Tensor) {
     return retVal
 }
 
-func (ts Tensor) Detach_() (retVal Tensor, err error) {
+func (ts Tensor) Detach_() {
     ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
     defer C.free(unsafe.Pointer(ptr))
     lib.AtgDetach_(ptr, ts.ctensor)
 
-    if err = TorchErr(); err != nil {
-        return retVal, err
-    }
-
-    return Tensor{ctensor: *ptr}, nil
-}
-
-func (ts Tensor) MustDetach_() (retVal Tensor) {
-    retVal, err := ts.Detach_()
-    if err != nil {
+    if err := TorchErr(); err != nil {
         log.Fatal(err)
     }
-
-    return retVal
 }
 
-func (ts Tensor) Zero_() (err error) {
+func (ts Tensor) Zero_() {
     ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
     defer C.free(unsafe.Pointer(ptr))
     lib.AtgZero_(ptr, ts.ctensor)
 
-    if err = TorchErr(); err != nil {
-        return err
-    }
-
-    return nil
-}
-
-func (ts Tensor) MustZero_() {
-    err := ts.Zero_()
-    if err != nil {
+    if err := TorchErr(); err != nil {
         log.Fatal(err)
     }
 }
 
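Note: Detach_ and Zero_ change from error-returning methods (each with a Must* wrapper) into in-place methods that log.Fatal on a libtorch error; MustDetach_ and MustZero_ are deleted outright. Callers migrate along these lines (mirroring the ZeroGrad change below):

    // Before: a detached copy was returned, then zeroed.
    // detach := grad.MustDetach_()
    // detach.MustZero_()

    // After: detach and zero the tensor in place.
    grad.Detach_()
    grad.Zero_()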
@@ -10,7 +10,6 @@ import (
     "fmt"
     "log"
     "reflect"
-    // "strings"
     "unsafe"
 
     gotch "github.com/sugarme/gotch"
@@ -432,11 +431,8 @@ func (ts Tensor) IsSparse() (retVal bool, err error) {
 func (ts Tensor) ZeroGrad() {
     grad := ts.MustGrad()
     if grad.MustDefined() {
-        // TODO: can we chain them?
-        // grad.MustDetach_().MustZero_()
-        // https://www.calhoun.io/using-functional-options-instead-of-method-chaining-in-go/
-        detach := grad.MustDetach_()
-        detach.MustZero_()
+        grad.Detach_()
+        grad.Zero_()
     }
 }
 
@@ -989,3 +985,10 @@ func (r Reduction) ToInt() (retVal int) {
     }
     return
 }
 
+// Values returns values of tensor in a slice of float64.
+func (ts Tensor) Values() []float64 {
+    clone := ts.MustShallowClone()
+    clone.Detach_()
+
+    return []float64{clone.MustView([]int64{-1}).MustFloat64Value([]int64{-1})}
+}
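Note: Values backs the new loss.Values()[0] calls in the examples. It flattens via MustView([]int64{-1}) and reads a single element, so it returns a one-element slice and only makes sense for scalar (or effectively scalar) tensors; whether the -1 element index is intentional (last element) or a typo for 0 is not clear from this commit alone. The shallow clone created here is also never dropped, which may feed the same leak the title mentions. Usage as seen above:

    loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
    fmt.Printf("Loss: %.3f\n", loss.Values()[0]) // one-element slice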