feat(nn/optimizer): completed

sugarme 2020-06-18 15:31:30 +10:00
parent cbcc9fce65
commit b5f112e030
3 changed files with 110 additions and 1 deletions

View File

@@ -242,3 +242,8 @@ func AtgRandperm(ptr *Ctensor, n int64, optionKind int32, optionDevice int32) {
	C.atg_randperm(ptr, cn, coptionKind, coptionDevice)
}

// void atg_clamp_(tensor *, tensor self, scalar min, scalar max);
func AtgClamp_(ptr *Ctensor, self Ctensor, min Cscalar, max Cscalar) {
	C.atg_clamp_(ptr, self, min, max)
}

View File

@@ -4,6 +4,8 @@ package nn
import (
	// "github.com/sugarme/gotch"
	"log"

	ts "github.com/sugarme/gotch/tensor"
)
@@ -175,5 +177,97 @@ func (c RMSPropConfig) Build(vs VarStore, lr float64) (retVal Optimizer, err err
// Optimizer methods:
// ==================
// addMissingVariables registers with the underlying optimizer any trainable
// variables that were added to the var store after the optimizer was built.
func (opt *Optimizer) addMissingVariables() {
	opt.variables.mutex.Lock()
	defer opt.variables.mutex.Unlock()

	missingVariables := len(opt.variables.TrainableVariable) - int(opt.variablesInOptimizer)
	if missingVariables > 0 {
		opt.opt.AddParameters(opt.variables.TrainableVariable[opt.variablesInOptimizer:])
		opt.variablesInOptimizer = uint8(len(opt.variables.TrainableVariable))
	}
}
// ZeroGrad zeroes the gradient for the tensors tracked by this optimizer.
func (opt *Optimizer) ZeroGrad() {
	opt.addMissingVariables()

	if err := opt.opt.ZeroGrad(); err != nil {
		log.Fatalf("Optimizer - ZeroGrad method call error: %v\n", err)
	}
}
// ClipGradValue clips gradient values at some specified maximum value.
func (opt *Optimizer) ClipGradValue(max float64) {
	opt.variables.mutex.Lock()
	defer opt.variables.mutex.Unlock()

	for _, tensor := range opt.variables.TrainableVariable {
		tensor.MustGrad().Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
	}
}
// Step performs an optimization step, updating the tracked tensors based on their gradients.
func (opt *Optimizer) Step() {
	opt.addMissingVariables()

	err := opt.opt.Step()
	if err != nil {
		log.Fatalf("Optimizer - Step method call error: %v\n", err)
	}
}
// BackwardStep applies a backward pass, updates the gradients, and performs an optimization step.
func (opt *Optimizer) BackwardStep(loss ts.Tensor) {
	opt.addMissingVariables()

	err := opt.opt.ZeroGrad()
	if err != nil {
		log.Fatalf("Optimizer - BackwardStep method call - ZeroGrad error: %v\n", err)
	}

	loss.MustBackward()

	err = opt.opt.Step()
	if err != nil {
		log.Fatalf("Optimizer - BackwardStep method call - Step() error: %v\n", err)
	}
}
// BackwardStepClip applies a backward pass, updates the gradients, and performs an optimization step.
//
// The gradients are clipped based on `max` before being applied.
func (opt *Optimizer) BackwardStepClip(loss ts.Tensor, max float64) {
	opt.addMissingVariables()

	err := opt.opt.ZeroGrad()
	if err != nil {
		log.Fatalf("Optimizer - BackwardStepClip method call - ZeroGrad error: %v\n", err)
	}

	loss.MustBackward()
	opt.ClipGradValue(max)

	err = opt.opt.Step()
	if err != nil {
		log.Fatalf("Optimizer - BackwardStepClip method call - Step() error: %v\n", err)
	}
}
// SetLR sets the optimizer learning rate.
func (opt *Optimizer) SetLR(lr float64) {
	err := opt.opt.SetLearningRate(lr)
	if err != nil {
		log.Fatalf("Optimizer - SetLR method call error: %v\n", err)
	}
}

// SetMomentum sets the optimizer momentum.
func (opt *Optimizer) SetMomentum(m float64) {
	err := opt.opt.SetMomentum(m)
	if err != nil {
		log.Fatalf("Optimizer - SetMomentum method call error: %v\n", err)
	}
}
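
For orientation, below is a minimal sketch of how these new optimizer methods could be combined in a training loop. It is not part of this commit: the package name, the `train` helper, the `lossFn` closure, the learning rates, and the clipping threshold are all illustrative assumptions; only the Optimizer methods and the RMSPropConfig.Build signature come from this change.

package example

import (
	"log"

	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
)

// train is a hypothetical helper: it builds an optimizer from an assumed
// RMSProp config, then repeatedly backprops a loss with gradient clipping.
func train(vs nn.VarStore, cfg nn.RMSPropConfig, lossFn func() ts.Tensor, nEpochs int) {
	// Build the optimizer over the var store with an assumed learning rate.
	opt, err := cfg.Build(vs, 1e-3)
	if err != nil {
		log.Fatal(err)
	}

	for epoch := 0; epoch < nEpochs; epoch++ {
		// Forward pass producing a scalar loss tensor (assumed closure).
		loss := lossFn()

		// Zero the gradients, backprop the loss, clip gradient values to
		// [-0.5, 0.5], and apply one optimization step.
		opt.BackwardStepClip(loss, 0.5)

		// Illustrative schedule: decay the learning rate halfway through.
		if epoch == nEpochs/2 {
			opt.SetLR(1e-4)
		}
	}
}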

View File

@@ -730,3 +730,13 @@ func MustRandperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (r
	return retVal
}

// Clamp_ clamps all elements of the tensor, in place, into the range [min, max].
func (ts Tensor) Clamp_(min Scalar, max Scalar) {
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgClamp_(ptr, ts.ctensor, min.cscalar, max.cscalar)
	if err := TorchErr(); err != nil {
		log.Fatal(err)
	}
}
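
For reference, a small sketch of calling the new in-place Clamp_ method from outside the tensor package. The package name and the clampInPlace helper are assumptions for illustration; Clamp_ and FloatScalar come from this commit (the same pair is used by ClipGradValue above).

package example

import ts "github.com/sugarme/gotch/tensor"

// clampInPlace is a hypothetical helper: it clamps every element of x into
// the range [-max, max] in place, mirroring how the new ClipGradValue method
// applies Clamp_ to gradient tensors.
func clampInPlace(x ts.Tensor, max float64) {
	x.Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
}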