WIP: varstore rework
This commit is contained in:
parent
d5634fc397
commit
d95eaba5b3
|
@ -11,8 +11,7 @@ import (
|
|||
|
||||
// Optimizer is a struct object to run gradient descent.
|
||||
type Optimizer struct {
|
||||
opt *ts.COptimizer
|
||||
// variables Variables // having embedded sync.Mutex
|
||||
opt *ts.COptimizer
|
||||
variablesInOptimizer uint8
|
||||
config interface{}
|
||||
stepCount int
|
||||
|
@ -36,7 +35,6 @@ type OptimizerConfig interface {
|
|||
|
||||
// defaultBuild is `default` Build method for OptimizerConfig interface
|
||||
func defaultBuild(config OptimizerConfig, vs *VarStore, lr float64) (retVal *Optimizer, err error) {
|
||||
|
||||
opt, err := config.buildCOpt(lr)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
|
|
708
nn/varstore.go
708
nn/varstore.go
File diff suppressed because it is too large
Load Diff
|
@ -15,9 +15,9 @@ func TestVarStoreEntry(t *testing.T) {
|
|||
vs := nn.NewVarStore(gotch.CPU)
|
||||
root := vs.Root()
|
||||
e1 := root.Entry("key")
|
||||
t1 := e1.OrZeros([]int64{3, 1, 4})
|
||||
t1 := e1.MustOrZeros([]int64{3, 1, 4})
|
||||
e2 := root.Entry("key")
|
||||
t2 := e2.OrZeros([]int64{1, 5, 9})
|
||||
t2 := e2.MustOrZeros([]int64{1, 5, 9})
|
||||
|
||||
wantT1 := []int64{3, 1, 4}
|
||||
wantT2 := []int64{3, 1, 4}
|
||||
|
@ -49,14 +49,14 @@ func TestSaveLoad(t *testing.T) {
|
|||
add := func(vs *nn.Path) (*ts.Tensor, *ts.Tensor) {
|
||||
subA := vs.Sub("a")
|
||||
subB := subA.Sub("b")
|
||||
v := subB.Ones("t2", []int64{3})
|
||||
u := vs.Zeros("t1", []int64{4})
|
||||
v := subB.MustOnes("t2", []int64{3})
|
||||
u := vs.MustZeros("t1", []int64{4})
|
||||
|
||||
wa := vs.Sub("a")
|
||||
wb := wa.Sub("b")
|
||||
wc := wb.Sub("ccc")
|
||||
_ = wc.Ones("t123", []int64{3})
|
||||
_ = wc.Ones("t123", []int64{3})
|
||||
_ = wc.MustOnes("t123", []int64{3})
|
||||
_ = wc.MustOnes("t123", []int64{3})
|
||||
|
||||
return u, v
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user