WIP(example/mnist): nn
This commit is contained in:
parent
9be19702d1
commit
a636372144
17
example/linear/main.go
Normal file
17
example/linear/main.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
|
||||
path := vs.Root()
|
||||
|
||||
l := nn.NewLinear(path, 4, 3, nn.DefaultLinearConfig())
|
||||
|
||||
l.Bs.Print()
|
||||
}
|
Binary file not shown.
|
@ -16,19 +16,24 @@ const (
|
|||
LabelNN int64 = 10
|
||||
MnistDirNN string = "../../data/mnist"
|
||||
|
||||
epochsNN = 200
|
||||
epochsNN = 3
|
||||
batchSizeNN = 256
|
||||
|
||||
LrNN = 1e-3
|
||||
)
|
||||
|
||||
var l nn.Linear
|
||||
|
||||
func netInit(vs nn.Path) ts.Module {
|
||||
n := nn.Seq()
|
||||
|
||||
n.Add(nn.NewLinear(vs.Sub("layer1"), ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig()))
|
||||
n.AddFn(func(xs ts.Tensor) ts.Tensor {
|
||||
l = nn.NewLinear(vs.Sub("layer1"), ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig())
|
||||
|
||||
n.Add(l)
|
||||
|
||||
n.AddFn(nn.ForwardWith(func(xs ts.Tensor) ts.Tensor {
|
||||
return xs.MustRelu()
|
||||
})
|
||||
}))
|
||||
|
||||
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, nn.DefaultLinearConfig()))
|
||||
|
||||
|
@ -46,13 +51,19 @@ func runNN() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
bsClone := l.Bs.MustShallowClone()
|
||||
|
||||
for epoch := 0; epoch < epochsNN; epoch++ {
|
||||
loss := net.Forward(ds.TrainImages).CrossEntropyForLogits(ds.TrainLabels)
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
fmt.Printf("Bs vals: %v\n", bsClone.MustToString(int64(1)))
|
||||
|
||||
lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
|
||||
testAccuracy := net.Forward(ds.TestImages).AccuracyForLogits(ds.TestLabels).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
|
||||
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss, testAccuracy*100)
|
||||
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
26
example/varstore/main.go
Normal file
26
example/varstore/main.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
|
||||
fmt.Printf("Is VarStore emptry? %v\n ", vs.IsEmpty())
|
||||
|
||||
path := vs.Root()
|
||||
|
||||
init := nn.NewKaimingUniformInit()
|
||||
|
||||
init.InitTensor([]int64{1, 4}, gotch.CPU).Print()
|
||||
|
||||
path.NewVar("layer1", []int64{1, 10}, nn.NewKaimingUniformInit())
|
||||
|
||||
fmt.Printf("Is VarStore emptry? %v\n ", vs.IsEmpty())
|
||||
|
||||
}
|
|
@ -257,3 +257,13 @@ func AtgRelu(ptr *Ctensor, self Ctensor) {
|
|||
func AtgRelu_(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_relu_(ptr, self)
|
||||
}
|
||||
|
||||
// void atg_t(tensor *, tensor self);
|
||||
func AtgT(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_t(ptr, self)
|
||||
}
|
||||
|
||||
// void atg_t_(tensor *, tensor self);
|
||||
func AtgT_(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_t_(ptr, self)
|
||||
}
|
||||
|
|
23
nn/linear.go
23
nn/linear.go
|
@ -3,6 +3,8 @@ package nn
|
|||
// linear is a fully-connected layer
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
@ -16,8 +18,8 @@ type LinearConfig struct {
|
|||
|
||||
// DefaultLinearConfig creates default LinearConfig with
|
||||
// weights initialized using KaimingUniform and Bias set to true
|
||||
func DefaultLinearConfig() *LinearConfig {
|
||||
return &LinearConfig{
|
||||
func DefaultLinearConfig() LinearConfig {
|
||||
return LinearConfig{
|
||||
WsInit: NewKaimingUniformInit(),
|
||||
BsInit: nil,
|
||||
Bias: true,
|
||||
|
@ -35,7 +37,7 @@ type Linear struct {
|
|||
// inDim - input dimension (x) [input features - columns]
|
||||
// outDim - output dimension (y) [output features - columns]
|
||||
// NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
|
||||
func NewLinear(vs Path, inDim, outDim int64, c *LinearConfig) *Linear {
|
||||
func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) Linear {
|
||||
|
||||
var bs ts.Tensor
|
||||
// bs has size of output dimension
|
||||
|
@ -43,10 +45,17 @@ func NewLinear(vs Path, inDim, outDim int64, c *LinearConfig) *Linear {
|
|||
case false:
|
||||
bs = ts.MustZeros([]int64{outDim}, gotch.Float.CInt(), vs.Device().CInt())
|
||||
case true:
|
||||
bs = vs.NewVar("bias", []int64{outDim}, c.BsInit)
|
||||
switch {
|
||||
case c.BsInit == nil:
|
||||
bound := 1.0 / math.Sqrt(float64(inDim))
|
||||
bsInit := NewUniformInit(-bound, bound)
|
||||
bs = vs.NewVar("bias", []int64{outDim}, bsInit)
|
||||
case c.BsInit != nil:
|
||||
bs = vs.NewVar("bias", []int64{outDim}, c.BsInit)
|
||||
}
|
||||
}
|
||||
|
||||
return &Linear{
|
||||
return Linear{
|
||||
Ws: vs.NewVar("weight", []int64{outDim, inDim}, c.WsInit),
|
||||
Bs: bs,
|
||||
}
|
||||
|
@ -80,7 +89,7 @@ func NewLinear(vs Path, inDim, outDim int64, c *LinearConfig) *Linear {
|
|||
// 1 1 1
|
||||
// 1 1 1
|
||||
// 1 1 1 ]
|
||||
func (l *Linear) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||
func (l Linear) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||
|
||||
return xs.MustMatMul(l.Ws).MustAdd(l.Bs)
|
||||
return xs.MustMatMul(l.Ws.MustT()).MustAdd(l.Bs)
|
||||
}
|
||||
|
|
|
@ -44,8 +44,10 @@ func defaultBuild(config OptimizerConfig, vs VarStore, lr float64) (retVal Optim
|
|||
vs.variables.mutex.Lock()
|
||||
defer vs.variables.mutex.Unlock()
|
||||
|
||||
if err = opt.AddParameters(vs.variables.TrainableVariable); err != nil {
|
||||
return retVal, err
|
||||
if len(vs.variables.TrainableVariable) > 0 {
|
||||
if err = opt.AddParameters(vs.variables.TrainableVariable); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
}
|
||||
|
||||
return Optimizer{
|
||||
|
@ -220,6 +222,7 @@ func (opt *Optimizer) Step() {
|
|||
|
||||
// BackwardStep applies a backward step pass, updates the gradients, and performs an optimization step.
|
||||
func (opt *Optimizer) BackwardStep(loss ts.Tensor) {
|
||||
|
||||
opt.addMissingVariables()
|
||||
err := opt.opt.ZeroGrad()
|
||||
if err != nil {
|
||||
|
|
|
@ -39,9 +39,9 @@ func (s *Sequential) Add(l ts.Module) {
|
|||
//
|
||||
// NOTE: fn should have signature `func(t ts.Tensor) ts.Tensor`
|
||||
// and it implements Module interface
|
||||
func (s *Sequential) AddFn(fn interface{}) {
|
||||
func (s *Sequential) AddFn(fn ts.Module) {
|
||||
|
||||
s.Add(fn.(ts.Module))
|
||||
s.Add(fn)
|
||||
}
|
||||
|
||||
// ForwardAll applies the forward pass and returns the output for each layer.
|
||||
|
@ -144,18 +144,18 @@ func (s *SequentialT) Add(l ts.ModuleT) {
|
|||
//
|
||||
// NOTE: fn should have signature `func(t ts.Tensor) ts.Tensor`
|
||||
// and it implements Module interface
|
||||
func (s *SequentialT) AddFn(fn interface{}) {
|
||||
func (s *SequentialT) AddFn(fn ts.ModuleT) {
|
||||
|
||||
s.Add(fn.(ts.ModuleT))
|
||||
s.Add(fn)
|
||||
}
|
||||
|
||||
// AddFnT appends a closure after all the current layers.
|
||||
//
|
||||
// NOTE: fn should have signature `func(t ts.Tensor, train bool) ts.Tensor`
|
||||
// and it implements Module interface
|
||||
func (s *SequentialT) AddFnT(fn interface{}) {
|
||||
func (s *SequentialT) AddFnT(fn ts.ModuleT) {
|
||||
|
||||
s.Add(fn.(ts.ModuleT))
|
||||
s.Add(fn)
|
||||
}
|
||||
|
||||
// ForwardAll applies the forward pass and returns the output for each layer.
|
||||
|
@ -176,3 +176,21 @@ func (s *SequentialT) ForwardAllT(xs ts.Tensor, train bool, opts ...uint8) (retV
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// ForwardWith is a handler function to implement Module interface for
|
||||
// any (anonymous) function it wraps.
|
||||
//
|
||||
// Ref. https://stackoverflow.com/a/42182987
|
||||
// NOTE: Specifically, `ForwardWith` is used to wrap anonymous function
|
||||
// as input parameter of `AddFn` Sequential method.
|
||||
type ForwardWith func(ts.Tensor) ts.Tensor
|
||||
|
||||
func (fw ForwardWith) Forward(xs ts.Tensor) ts.Tensor {
|
||||
return fw(xs)
|
||||
}
|
||||
|
||||
type ForwardTWith func(ts.Tensor, bool) ts.Tensor
|
||||
|
||||
func (fw ForwardTWith) ForwardT(xs ts.Tensor, train bool) ts.Tensor {
|
||||
return fw(xs, train)
|
||||
}
|
||||
|
|
|
@ -57,6 +57,9 @@ func NewVarStore(device gotch.Device) VarStore {
|
|||
}
|
||||
}
|
||||
|
||||
// NOTE:
|
||||
// To get (initiate) a path, call vs.Root()
|
||||
|
||||
// VarStore methods:
|
||||
// =================
|
||||
|
||||
|
@ -417,9 +420,10 @@ func (p *Path) OnesNoTrain(name string, dims []int64) (retVal ts.Tensor) {
|
|||
// will be tracked.
|
||||
// The variable uses a float tensor initialized as per the
|
||||
// related argument.
|
||||
func (p *Path) NewVar(name string, dims []int64, init Init) (retVal ts.Tensor) {
|
||||
func (p *Path) NewVar(name string, dims []int64, ini Init) (retVal ts.Tensor) {
|
||||
|
||||
v := ini.InitTensor(dims, p.varstore.device)
|
||||
|
||||
v := init.InitTensor(dims, p.varstore.device)
|
||||
return p.add(name, v, true)
|
||||
}
|
||||
|
||||
|
|
|
@ -774,3 +774,38 @@ func (ts Tensor) MustRelu() (retVal Tensor) {
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) T() (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgT(ptr, ts.ctensor)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustT() (retVal Tensor) {
|
||||
retVal, err := ts.T()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) T_() {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgT_(ptr, ts.ctensor)
|
||||
err := TorchErr()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user