feat(nn/optimizer): completed
This commit is contained in:
parent
cbcc9fce65
commit
b5f112e030
|
@ -242,3 +242,8 @@ func AtgRandperm(ptr *Ctensor, n int64, optionKind int32, optionDevice int32) {
|
|||
|
||||
C.atg_randperm(ptr, cn, coptionKind, coptionDevice)
|
||||
}
|
||||
|
||||
// void atg_clamp_(tensor *, tensor self, scalar min, scalar max);
|
||||
func AtgClamp_(ptr *Ctensor, self Ctensor, min Cscalar, max Cscalar) {
|
||||
C.atg_clamp_(ptr, self, min, max)
|
||||
}
|
||||
|
|
|
@ -4,6 +4,8 @@ package nn
|
|||
|
||||
import (
|
||||
// "github.com/sugarme/gotch"
|
||||
"log"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
|
@ -175,5 +177,97 @@ func (c RMSPropConfig) Build(vs VarStore, lr float64) (retVal Optimizer, err err
|
|||
// Optimizer methods:
|
||||
// ==================
|
||||
func (opt *Optimizer) addMissingVariables() {
|
||||
// TODO: implement
|
||||
|
||||
opt.variables.mutex.Lock()
|
||||
defer opt.variables.mutex.Unlock()
|
||||
|
||||
missingVariables := len(opt.variables.TrainableVariable) - int(opt.variablesInOptimizer)
|
||||
|
||||
if missingVariables > 0 {
|
||||
opt.opt.AddParameters(opt.variables.TrainableVariable[opt.variablesInOptimizer:])
|
||||
opt.variablesInOptimizer = uint8(len(opt.variables.TrainableVariable))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// ZeroGrad zeroes the gradient for the tensors tracked by this optimizer.
|
||||
func (opt *Optimizer) ZeroGrad() {
|
||||
opt.addMissingVariables()
|
||||
if err := opt.opt.ZeroGrad(); err != nil {
|
||||
log.Fatalf("Optimizer - ZeroGrad method call error: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Clips gradient value at some specified maximum value.
|
||||
func (opt *Optimizer) ClipGradValue(max float64) {
|
||||
|
||||
opt.variables.mutex.Lock()
|
||||
defer opt.variables.mutex.Unlock()
|
||||
|
||||
for _, tensor := range opt.variables.TrainableVariable {
|
||||
tensor.MustGrad().Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
|
||||
}
|
||||
}
|
||||
|
||||
// Step performs an optimization step, updating the tracked tensors based on their gradients.
|
||||
func (opt *Optimizer) Step() {
|
||||
opt.addMissingVariables()
|
||||
err := opt.opt.Step()
|
||||
if err != nil {
|
||||
log.Fatalf("Optimizer - Step method call error: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// BackwardStep applies a backward step pass, update the gradients, and performs an optimization step.
|
||||
func (opt *Optimizer) BackwardStep(loss ts.Tensor) {
|
||||
opt.addMissingVariables()
|
||||
err := opt.opt.ZeroGrad()
|
||||
if err != nil {
|
||||
log.Fatalf("Optimizer - BackwardStep method call - ZeroGrad error: %v\n", err)
|
||||
}
|
||||
|
||||
loss.MustBackward()
|
||||
|
||||
err = opt.opt.Step()
|
||||
if err != nil {
|
||||
log.Fatalf("Optimizer - BackwardStep method call - Step() error: %v\n", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// BackwardStepClip applies a backward step pass, update the gradients, and performs an optimization step.
|
||||
//
|
||||
// The gradients are clipped based on `max` before being applied.
|
||||
func (opt *Optimizer) BackwardStepClip(loss ts.Tensor, max float64) {
|
||||
opt.addMissingVariables()
|
||||
|
||||
err := opt.opt.ZeroGrad()
|
||||
if err != nil {
|
||||
log.Fatalf("Optimizer - BackwardStepClip method call - ZeroGrad error: %v\n", err)
|
||||
}
|
||||
|
||||
loss.MustBackward()
|
||||
|
||||
opt.ClipGradValue(max)
|
||||
|
||||
err = opt.opt.Step()
|
||||
if err != nil {
|
||||
log.Fatalf("Optimizer - BackwardStepClip method call - Step() error: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// SetLR sets the optimizer learning rate.
|
||||
func (opt *Optimizer) SetLR(lr float64) {
|
||||
err := opt.opt.SetLearningRate(lr)
|
||||
if err != nil {
|
||||
log.Fatalf("Optimizer - SetLR method call error: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// SetMomentum sets the optimizer momentum.
|
||||
func (opt *Optimizer) SetMomentum(m float64) {
|
||||
err := opt.opt.SetMomentum(m)
|
||||
if err != nil {
|
||||
log.Fatalf("Optimizer - SetMomentum method call error: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -730,3 +730,13 @@ func MustRandperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (r
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Clamp_(min Scalar, max Scalar) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgClamp_(ptr, ts.ctensor, min.cscalar, max.cscalar)
|
||||
if err = TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user