WIP(example/mnist): nn
This commit is contained in:
parent
a0faf0799d
commit
e0d2e0ca7e
|
@ -43,9 +43,9 @@ func netInit(vs nn.Path) ts.Module {
|
|||
func runNN() {
|
||||
var ds vision.Dataset
|
||||
ds = vision.LoadMNISTDir(MnistDirNN)
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
net := netInit(vs.Root())
|
||||
|
||||
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -58,7 +58,7 @@ func runNN() {
|
|||
opt.BackwardStep(loss)
|
||||
lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
testAccuracy := net.Forward(ds.TestImages).AccuracyForLogits(ds.TestLabels).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
|
||||
fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
|
||||
|
||||
fmt.Printf("Loss: %v\n", lossVal)
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ func defaultBuild(config OptimizerConfig, vs VarStore, lr float64) (retVal Optim
|
|||
parameters = append(parameters, v)
|
||||
}
|
||||
|
||||
// if err = opt.AddParameters(vs.variables.TrainableVariable); err != nil {
|
||||
// if err = opt.AddParameters(vs.variables.TrainableVariables); err != nil {
|
||||
if err = opt.AddParameters(parameters); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
@ -57,7 +57,7 @@ func defaultBuild(config OptimizerConfig, vs VarStore, lr float64) (retVal Optim
|
|||
return Optimizer{
|
||||
opt: opt,
|
||||
variables: vs.variables,
|
||||
variablesInOptimizer: uint8(len(vs.variables.TrainableVariable)),
|
||||
variablesInOptimizer: uint8(len(vs.variables.TrainableVariables)),
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
@ -187,11 +187,11 @@ func (opt *Optimizer) addMissingVariables() {
|
|||
opt.variables.mutex.Lock()
|
||||
defer opt.variables.mutex.Unlock()
|
||||
|
||||
missingVariables := len(opt.variables.TrainableVariable) - int(opt.variablesInOptimizer)
|
||||
missingVariables := len(opt.variables.TrainableVariables) - int(opt.variablesInOptimizer)
|
||||
|
||||
if missingVariables > 0 {
|
||||
opt.opt.AddParameters(opt.variables.TrainableVariable[opt.variablesInOptimizer:])
|
||||
opt.variablesInOptimizer = uint8(len(opt.variables.TrainableVariable))
|
||||
opt.opt.AddParameters(opt.variables.TrainableVariables[opt.variablesInOptimizer:])
|
||||
opt.variablesInOptimizer = uint8(len(opt.variables.TrainableVariables))
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -210,7 +210,7 @@ func (opt *Optimizer) ClipGradValue(max float64) {
|
|||
opt.variables.mutex.Lock()
|
||||
defer opt.variables.mutex.Unlock()
|
||||
|
||||
for _, tensor := range opt.variables.TrainableVariable {
|
||||
for _, tensor := range opt.variables.TrainableVariables {
|
||||
tensor.MustGrad().Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,9 +18,9 @@ const SEP = "."
|
|||
// NOTE: When the variable store is frozen, trainable still is set to tree,
|
||||
// however the tensor is not set to require gradients.
|
||||
type Variables struct {
|
||||
mutex *sync.Mutex
|
||||
NamedVariables map[string]ts.Tensor
|
||||
TrainableVariable []ts.Tensor
|
||||
mutex *sync.Mutex
|
||||
NamedVariables map[string]ts.Tensor
|
||||
TrainableVariables []ts.Tensor
|
||||
}
|
||||
|
||||
// VarStore is used to store variables used by one or multiple layers.
|
||||
|
@ -33,22 +33,22 @@ type VarStore struct {
|
|||
// Path is variable store with an associated path for variables naming.
|
||||
type Path struct {
|
||||
path []string
|
||||
varstore VarStore
|
||||
varstore *VarStore
|
||||
}
|
||||
|
||||
// Entry holds an entry corresponding to a given name in Path.
|
||||
type Entry struct {
|
||||
name string
|
||||
variables Variables // MutexGuard
|
||||
path Path
|
||||
variables *Variables // MutexGuard
|
||||
path *Path
|
||||
}
|
||||
|
||||
// NewVarStore creates a new variable store located on the specified device
|
||||
func NewVarStore(device gotch.Device) VarStore {
|
||||
variables := Variables{
|
||||
mutex: &sync.Mutex{},
|
||||
NamedVariables: make(map[string]ts.Tensor, 0),
|
||||
TrainableVariable: make([]ts.Tensor, 0),
|
||||
mutex: &sync.Mutex{},
|
||||
NamedVariables: make(map[string]ts.Tensor, 0),
|
||||
TrainableVariables: make([]ts.Tensor, 0),
|
||||
}
|
||||
|
||||
return VarStore{
|
||||
|
@ -86,11 +86,11 @@ func (vs *VarStore) IsEmpty() (retVal bool) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
// TrainableVariables returns all trainable variables for this var-store
|
||||
func (vs *VarStore) TrainableVariable() (retVal []ts.Tensor) {
|
||||
// TrainableVariabless returns all trainable variables for this var-store
|
||||
func (vs *VarStore) TrainableVariables() (retVal []ts.Tensor) {
|
||||
vs.variables.mutex.Lock()
|
||||
defer vs.variables.mutex.Unlock()
|
||||
retVal = vs.variables.TrainableVariable
|
||||
retVal = vs.variables.TrainableVariables
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
@ -112,7 +112,7 @@ func (vs *VarStore) Variables() (retVal map[string]ts.Tensor) {
|
|||
func (vs *VarStore) Root() (retVal Path) {
|
||||
return Path{
|
||||
path: []string{},
|
||||
varstore: *vs,
|
||||
varstore: vs,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -220,7 +220,7 @@ func (vs *VarStore) Freeze() {
|
|||
vs.variables.mutex.Lock()
|
||||
defer vs.variables.mutex.Unlock()
|
||||
|
||||
for _, v := range vs.variables.TrainableVariable {
|
||||
for _, v := range vs.variables.TrainableVariables {
|
||||
_, err := v.SetRequiresGrad(false)
|
||||
if err != nil {
|
||||
log.Fatalf("Freeze() Error: %v\n", err)
|
||||
|
@ -235,7 +235,7 @@ func (vs *VarStore) Unfreeze() {
|
|||
vs.variables.mutex.Lock()
|
||||
defer vs.variables.mutex.Unlock()
|
||||
|
||||
for _, v := range vs.variables.TrainableVariable {
|
||||
for _, v := range vs.variables.TrainableVariables {
|
||||
_, err := v.SetRequiresGrad(true)
|
||||
if err != nil {
|
||||
log.Fatalf("Unfreeze() Error: %v\n", err)
|
||||
|
@ -340,7 +340,7 @@ func (p *Path) add(name string, newTs ts.Tensor, trainable bool) (retVal ts.Tens
|
|||
}
|
||||
|
||||
if trainable {
|
||||
p.varstore.variables.TrainableVariable = append(p.varstore.variables.TrainableVariable, tensor)
|
||||
p.varstore.variables.TrainableVariables = append(p.varstore.variables.TrainableVariables, tensor)
|
||||
}
|
||||
|
||||
p.varstore.variables.NamedVariables[path] = tensor
|
||||
|
@ -369,7 +369,7 @@ func (p *Path) getOrAddWithLock(name string, tensor ts.Tensor, trainable bool, v
|
|||
}
|
||||
|
||||
if trainable {
|
||||
variables.TrainableVariable = append(variables.TrainableVariable, ttensor)
|
||||
variables.TrainableVariables = append(variables.TrainableVariables, ttensor)
|
||||
}
|
||||
|
||||
variables.NamedVariables[path] = ttensor
|
||||
|
@ -541,8 +541,8 @@ func (p *Path) Entry(name string) (retVal Entry) {
|
|||
|
||||
return Entry{
|
||||
name: name,
|
||||
variables: p.varstore.variables,
|
||||
path: *p,
|
||||
variables: &p.varstore.variables,
|
||||
path: p,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -558,7 +558,7 @@ func (p *Path) Entry(name string) (retVal Entry) {
|
|||
func (e *Entry) OrVar(dims []int64, init Init) (retVal ts.Tensor) {
|
||||
|
||||
v := init.InitTensor(dims, e.path.varstore.device)
|
||||
return e.path.getOrAddWithLock(e.name, v, true, e.variables)
|
||||
return e.path.getOrAddWithLock(e.name, v, true, *e.variables)
|
||||
}
|
||||
|
||||
// Returns the existing entry if, otherwise create a new variable.
|
||||
|
@ -593,7 +593,7 @@ func (e *Entry) OrOnes(dims []int64) (retVal ts.Tensor) {
|
|||
func (e *Entry) OrOnesNoTrain(dims []int64) (retVal ts.Tensor) {
|
||||
|
||||
o := ts.MustOnes(dims, gotch.Float.CInt(), e.path.Device().CInt())
|
||||
return e.path.getOrAddWithLock(e.name, o, true, e.variables)
|
||||
return e.path.getOrAddWithLock(e.name, o, true, *e.variables)
|
||||
}
|
||||
|
||||
// OrRandn returns the existing entry if, otherwise create a new variable.
|
||||
|
@ -624,7 +624,7 @@ func (e *Entry) OrZeros(dims []int64) (retVal ts.Tensor) {
|
|||
func (e *Entry) OrZerosNoTrain(dims []int64) (retVal ts.Tensor) {
|
||||
|
||||
z := ts.MustZeros(dims, gotch.Float.CInt(), e.path.Device().CInt())
|
||||
return e.path.getOrAddWithLock(e.name, z, true, e.variables)
|
||||
return e.path.getOrAddWithLock(e.name, z, true, *e.variables)
|
||||
}
|
||||
|
||||
// TODO: can we implement `Div` operator in Go?
|
||||
|
|
Loading…
Reference in New Issue
Block a user