WIP(example/mnist): nn

sugarme 2020-06-19 21:39:34 +10:00
parent a0faf0799d
commit e0d2e0ca7e
3 changed files with 30 additions and 30 deletions

View File

@@ -43,9 +43,9 @@ func netInit(vs nn.Path) ts.Module {
func runNN() {
var ds vision.Dataset
ds = vision.LoadMNISTDir(MnistDirNN)
vs := nn.NewVarStore(gotch.CPU)
net := netInit(vs.Root())
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
if err != nil {
log.Fatal(err)
@@ -58,7 +58,7 @@ func runNN() {
opt.BackwardStep(loss)
lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0})
testAccuracy := net.Forward(ds.TestImages).AccuracyForLogits(ds.TestLabels).MustView([]int64{-1}).MustFloat64Value([]int64{0})
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
fmt.Printf("Loss: %v\n", lossVal)
}
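The only visible change in this file is the progress line, where the dash separators become tabs. For reference, a self-contained sketch of the new output format; the values for epoch, lossVal, and testAccuracy below are made up:

    package main

    import "fmt"

    func main() {
        // Stand-in values for the epoch counter, lossVal, and testAccuracy
        // computed in the training loop above.
        epoch, lossVal, testAccuracy := 1, 0.342, 0.9512
        fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n",
            epoch, lossVal, testAccuracy*100)
    }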

View File

@@ -49,7 +49,7 @@ func defaultBuild(config OptimizerConfig, vs VarStore, lr float64) (retVal Optim
parameters = append(parameters, v)
}
- // if err = opt.AddParameters(vs.variables.TrainableVariable); err != nil {
+ // if err = opt.AddParameters(vs.variables.TrainableVariables); err != nil {
if err = opt.AddParameters(parameters); err != nil {
return retVal, err
}
@@ -57,7 +57,7 @@ func defaultBuild(config OptimizerConfig, vs VarStore, lr float64) (retVal Optim
return Optimizer{
opt: opt,
variables: vs.variables,
- variablesInOptimizer: uint8(len(vs.variables.TrainableVariable)),
+ variablesInOptimizer: uint8(len(vs.variables.TrainableVariables)),
config: config,
}, nil
}
@@ -187,11 +187,11 @@ func (opt *Optimizer) addMissingVariables() {
opt.variables.mutex.Lock()
defer opt.variables.mutex.Unlock()
- missingVariables := len(opt.variables.TrainableVariable) - int(opt.variablesInOptimizer)
+ missingVariables := len(opt.variables.TrainableVariables) - int(opt.variablesInOptimizer)
if missingVariables > 0 {
- opt.opt.AddParameters(opt.variables.TrainableVariable[opt.variablesInOptimizer:])
- opt.variablesInOptimizer = uint8(len(opt.variables.TrainableVariable))
+ opt.opt.AddParameters(opt.variables.TrainableVariables[opt.variablesInOptimizer:])
+ opt.variablesInOptimizer = uint8(len(opt.variables.TrainableVariables))
}
}
@@ -210,7 +210,7 @@ func (opt *Optimizer) ClipGradValue(max float64) {
opt.variables.mutex.Lock()
defer opt.variables.mutex.Unlock()
- for _, tensor := range opt.variables.TrainableVariable {
+ for _, tensor := range opt.variables.TrainableVariables {
tensor.MustGrad().Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
}
}
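Aside from the TrainableVariable -> TrainableVariables rename, the pattern in addMissingVariables is worth spelling out: the Optimizer remembers how many trainable variables it has already handed to the underlying optimizer (variablesInOptimizer) and, on each call, registers only the tail of the slice. A minimal, gotch-independent sketch of that bookkeeping; store, optimizer, and addMissing are illustrative names, and []string stands in for []ts.Tensor:

    package main

    import (
        "fmt"
        "sync"
    )

    // store stands in for the mutex-guarded Variables above.
    type store struct {
        mu        sync.Mutex
        trainable []string
    }

    // registered plays the role of variablesInOptimizer. (Note the original
    // uses uint8, which silently caps the count at 255 variables.)
    type optimizer struct {
        vars       *store
        registered int
    }

    // addMissing registers only the variables created since the last call,
    // the same slice-tail trick as addMissingVariables above.
    func (o *optimizer) addMissing() {
        o.vars.mu.Lock()
        defer o.vars.mu.Unlock()
        if missing := len(o.vars.trainable) - o.registered; missing > 0 {
            fmt.Println("registering:", o.vars.trainable[o.registered:])
            o.registered = len(o.vars.trainable)
        }
    }

    func main() {
        s := &store{trainable: []string{"w1", "b1"}}
        o := &optimizer{vars: s}
        o.addMissing() // registering: [w1 b1]
        s.trainable = append(s.trainable, "w2")
        o.addMissing() // registering: [w2]
    }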

View File

@@ -18,9 +18,9 @@ const SEP = "."
// NOTE: When the variable store is frozen, trainable is still set to true,
// however the tensor is not set to require gradients.
type Variables struct {
- mutex             *sync.Mutex
- NamedVariables    map[string]ts.Tensor
- TrainableVariable []ts.Tensor
+ mutex              *sync.Mutex
+ NamedVariables     map[string]ts.Tensor
+ TrainableVariables []ts.Tensor
}
// VarStore is used to store variables used by one or multiple layers.
@@ -33,22 +33,22 @@ type VarStore struct {
// Path is variable store with an associated path for variables naming.
type Path struct {
path []string
- varstore VarStore
+ varstore *VarStore
}
// Entry holds an entry corresponding to a given name in Path.
type Entry struct {
name string
- variables Variables // MutexGuard
- path Path
+ variables *Variables // MutexGuard
+ path *Path
}
// NewVarStore creates a new variable store located on the specified device
func NewVarStore(device gotch.Device) VarStore {
variables := Variables{
mutex: &sync.Mutex{},
NamedVariables: make(map[string]ts.Tensor, 0),
TrainableVariable: make([]ts.Tensor, 0),
mutex: &sync.Mutex{},
NamedVariables: make(map[string]ts.Tensor, 0),
TrainableVariables: make([]ts.Tensor, 0),
}
return VarStore{
@@ -86,11 +86,11 @@ func (vs *VarStore) IsEmpty() (retVal bool) {
return retVal
}
// TrainableVariables returns all trainable variables for this var-store
- func (vs *VarStore) TrainableVariable() (retVal []ts.Tensor) {
+ func (vs *VarStore) TrainableVariables() (retVal []ts.Tensor) {
vs.variables.mutex.Lock()
defer vs.variables.mutex.Unlock()
- retVal = vs.variables.TrainableVariable
+ retVal = vs.variables.TrainableVariables
return retVal
}
@@ -112,7 +112,7 @@ func (vs *VarStore) Variables() (retVal map[string]ts.Tensor) {
func (vs *VarStore) Root() (retVal Path) {
return Path{
path: []string{},
- varstore: *vs,
+ varstore: vs,
}
}
@@ -220,7 +220,7 @@ func (vs *VarStore) Freeze() {
vs.variables.mutex.Lock()
defer vs.variables.mutex.Unlock()
- for _, v := range vs.variables.TrainableVariable {
+ for _, v := range vs.variables.TrainableVariables {
_, err := v.SetRequiresGrad(false)
if err != nil {
log.Fatalf("Freeze() Error: %v\n", err)
@@ -235,7 +235,7 @@ func (vs *VarStore) Unfreeze() {
vs.variables.mutex.Lock()
defer vs.variables.mutex.Unlock()
- for _, v := range vs.variables.TrainableVariable {
+ for _, v := range vs.variables.TrainableVariables {
_, err := v.SetRequiresGrad(true)
if err != nil {
log.Fatalf("Unfreeze() Error: %v\n", err)
@@ -340,7 +340,7 @@ func (p *Path) add(name string, newTs ts.Tensor, trainable bool) (retVal ts.Tens
}
if trainable {
- p.varstore.variables.TrainableVariable = append(p.varstore.variables.TrainableVariable, tensor)
+ p.varstore.variables.TrainableVariables = append(p.varstore.variables.TrainableVariables, tensor)
}
p.varstore.variables.NamedVariables[path] = tensor
@@ -369,7 +369,7 @@ func (p *Path) getOrAddWithLock(name string, tensor ts.Tensor, trainable bool, v
}
if trainable {
- variables.TrainableVariable = append(variables.TrainableVariable, ttensor)
+ variables.TrainableVariables = append(variables.TrainableVariables, ttensor)
}
variables.NamedVariables[path] = ttensor
@@ -541,8 +541,8 @@ func (p *Path) Entry(name string) (retVal Entry) {
return Entry{
name: name,
- variables: p.varstore.variables,
- path: *p,
+ variables: &p.varstore.variables,
+ path: p,
}
}
@@ -558,7 +558,7 @@ func (p *Path) Entry(name string) (retVal Entry) {
func (e *Entry) OrVar(dims []int64, init Init) (retVal ts.Tensor) {
v := init.InitTensor(dims, e.path.varstore.device)
- return e.path.getOrAddWithLock(e.name, v, true, e.variables)
+ return e.path.getOrAddWithLock(e.name, v, true, *e.variables)
}
// Returns the existing entry if it exists, otherwise creates a new variable.
@@ -593,7 +593,7 @@ func (e *Entry) OrOnes(dims []int64) (retVal ts.Tensor) {
func (e *Entry) OrOnesNoTrain(dims []int64) (retVal ts.Tensor) {
o := ts.MustOnes(dims, gotch.Float.CInt(), e.path.Device().CInt())
- return e.path.getOrAddWithLock(e.name, o, true, e.variables)
+ return e.path.getOrAddWithLock(e.name, o, true, *e.variables)
}
// OrRandn returns the existing entry if it exists, otherwise creates a new variable.
@@ -624,7 +624,7 @@ func (e *Entry) OrZeros(dims []int64) (retVal ts.Tensor) {
func (e *Entry) OrZerosNoTrain(dims []int64) (retVal ts.Tensor) {
z := ts.MustZeros(dims, gotch.Float.CInt(), e.path.Device().CInt())
- return e.path.getOrAddWithLock(e.name, z, true, e.variables)
+ return e.path.getOrAddWithLock(e.name, z, true, *e.variables)
}
// TODO: can we implement `Div` operator in Go?
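The common thread in this file's hunks is moving Path and Entry from holding VarStore and Variables by value to holding pointers. A gotch-independent sketch of why the pointer matters: copying the struct snapshots its slice header, so trainable variables appended later through the shared store never show up in the copy. All type names below are illustrative, and []string stands in for []ts.Tensor:

    package main

    import "fmt"

    // variables stands in for the mutex-guarded Variables above.
    type variables struct {
        trainable []string
    }

    type pathByValue struct{ vars variables }    // the old layout: a private copy
    type pathByPointer struct{ vars *variables } // the new layout: shared state

    func main() {
        v := variables{}

        byVal := pathByValue{vars: v}    // snapshots the (empty) slice header
        byPtr := pathByPointer{vars: &v} // shares the backing store

        v.trainable = append(v.trainable, "w")

        fmt.Println(len(byVal.vars.trainable)) // 0: the copy never sees "w"
        fmt.Println(len(byPtr.vars.trainable)) // 1: the pointer does
    }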