WIP(example/mnist): nn

commit a636372144 (parent 9be19702d1)
example/linear/main.go | 17 (new file)
@@ -0,0 +1,17 @@
+package main
+
+import (
+	"github.com/sugarme/gotch"
+	"github.com/sugarme/gotch/nn"
+)
+
+func main() {
+
+	vs := nn.NewVarStore(gotch.CPU)
+
+	path := vs.Root()
+
+	l := nn.NewLinear(path, 4, 3, nn.DefaultLinearConfig())
+
+	l.Bs.Print()
+}
Binary file not shown.
@@ -16,19 +16,24 @@ const (
 	LabelNN int64 = 10
 	MnistDirNN string = "../../data/mnist"
 
-	epochsNN = 200
+	epochsNN = 3
 	batchSizeNN = 256
 
 	LrNN = 1e-3
 )
 
+var l nn.Linear
+
 func netInit(vs nn.Path) ts.Module {
 	n := nn.Seq()
 
-	n.Add(nn.NewLinear(vs.Sub("layer1"), ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig()))
-	n.AddFn(func(xs ts.Tensor) ts.Tensor {
+	l = nn.NewLinear(vs.Sub("layer1"), ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig())
+
+	n.Add(l)
+
+	n.AddFn(nn.ForwardWith(func(xs ts.Tensor) ts.Tensor {
 		return xs.MustRelu()
-	})
+	}))
 
 	n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, nn.DefaultLinearConfig()))
 
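Note: the ReLU closure can no longer be passed to AddFn directly, since AddFn's parameter is now typed ts.Module (see the nn/sequential.go hunks below). Wrapping it in nn.ForwardWith, a function type added at the end of the sequential.go changes, gives the closure the Forward method that the Module interface requires.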
@@ -46,13 +51,19 @@ func runNN() {
 		log.Fatal(err)
 	}
 
+	bsClone := l.Bs.MustShallowClone()
+
 	for epoch := 0; epoch < epochsNN; epoch++ {
 		loss := net.Forward(ds.TrainImages).CrossEntropyForLogits(ds.TrainLabels)
 		opt.BackwardStep(loss)
 
+		fmt.Printf("Bs vals: %v\n", bsClone.MustToString(int64(1)))
+
+		lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0})
+
 		testAccuracy := net.Forward(ds.TestImages).AccuracyForLogits(ds.TestLabels).MustView([]int64{-1}).MustFloat64Value([]int64{0})
 
-		fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss, testAccuracy*100)
+		fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
 	}
 
 }
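The bsClone printout is a debugging aid: MustShallowClone shares the underlying storage with l.Bs, so printing the clone each epoch shows the bias being updated in place by the optimizer. A minimal standalone sketch of the idea, reusing only calls that appear in this commit (the bias shape and variable names are illustrative):

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
)

func main() {
	vs := nn.NewVarStore(gotch.CPU)
	bias := vs.Root().NewVar("bias", []int64{3}, nn.NewKaimingUniformInit())

	// The shallow clone shares storage with `bias`, so any in-place
	// update (e.g. an optimizer step) is visible through the clone.
	bsClone := bias.MustShallowClone()
	fmt.Printf("Bs vals: %v\n", bsClone.MustToString(int64(1)))
}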
example/varstore/main.go | 26 (new file)
@@ -0,0 +1,26 @@
+package main
+
+import (
+	"fmt"
+
+	"github.com/sugarme/gotch"
+	"github.com/sugarme/gotch/nn"
+)
+
+func main() {
+
+	vs := nn.NewVarStore(gotch.CPU)
+
+	fmt.Printf("Is VarStore empty? %v\n", vs.IsEmpty())
+
+	path := vs.Root()
+
+	init := nn.NewKaimingUniformInit()
+
+	init.InitTensor([]int64{1, 4}, gotch.CPU).Print()
+
+	path.NewVar("layer1", []int64{1, 10}, nn.NewKaimingUniformInit())
+
+	fmt.Printf("Is VarStore empty? %v\n", vs.IsEmpty())
+
+}
@@ -257,3 +257,13 @@ func AtgRelu(ptr *Ctensor, self Ctensor) {
 func AtgRelu_(ptr *Ctensor, self Ctensor) {
 	C.atg_relu_(ptr, self)
 }
+
+// void atg_t(tensor *, tensor self);
+func AtgT(ptr *Ctensor, self Ctensor) {
+	C.atg_t(ptr, self)
+}
+
+// void atg_t_(tensor *, tensor self);
+func AtgT_(ptr *Ctensor, self Ctensor) {
+	C.atg_t_(ptr, self)
+}
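These are thin cgo wrappers over the corresponding libtorch C functions; the Tensor.T, MustT, and T_ methods added at the end of this commit are their Go-side callers.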
nn/linear.go | 23
@@ -3,6 +3,8 @@ package nn
 // linear is a fully-connected layer
 
 import (
+	"math"
+
 	"github.com/sugarme/gotch"
 	ts "github.com/sugarme/gotch/tensor"
 )
@@ -16,8 +18,8 @@ type LinearConfig struct {
 
 // DefaultLinearConfig creates default LinearConfig with
 // weights initiated using KaimingUniform and Bias is set to true
-func DefaultLinearConfig() *LinearConfig {
-	return &LinearConfig{
+func DefaultLinearConfig() LinearConfig {
+	return LinearConfig{
 		WsInit: NewKaimingUniformInit(),
 		BsInit: nil,
 		Bias:   true,
@@ -35,7 +37,7 @@ type Linear struct {
 // inDim - input dimension (x) [input features - columns]
 // outDim - output dimension (y) [output features - columns]
 // NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
-func NewLinear(vs Path, inDim, outDim int64, c *LinearConfig) *Linear {
+func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) Linear {
 
 	var bs ts.Tensor
 	// bs has size of output dimension
@@ -43,10 +45,17 @@ func NewLinear(vs Path, inDim, outDim int64, c *LinearConfig) *Linear {
 	case false:
 		bs = ts.MustZeros([]int64{outDim}, gotch.Float.CInt(), vs.Device().CInt())
 	case true:
-		bs = vs.NewVar("bias", []int64{outDim}, c.BsInit)
+		switch {
+		case c.BsInit == nil:
+			bound := 1.0 / math.Sqrt(float64(inDim))
+			bsInit := NewUniformInit(-bound, bound)
+			bs = vs.NewVar("bias", []int64{outDim}, bsInit)
+		case c.BsInit != nil:
+			bs = vs.NewVar("bias", []int64{outDim}, c.BsInit)
+		}
 	}
 
-	return &Linear{
+	return Linear{
 		Ws: vs.NewVar("weight", []int64{outDim, inDim}, c.WsInit),
 		Bs: bs,
 	}
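With BsInit left nil (the DefaultLinearConfig case above), the bias is now drawn from U(-bound, bound) with bound = 1/sqrt(inDim), which is also PyTorch's nn.Linear default. A quick worked check, not part of the commit:

package main

import (
	"fmt"
	"math"
)

func main() {
	// For the 4 -> 3 layer in example/linear/main.go:
	// bound = 1/sqrt(4) = 0.5, so the bias is drawn from U(-0.5, 0.5).
	inDim := int64(4)
	bound := 1.0 / math.Sqrt(float64(inDim))
	fmt.Println(bound) // 0.5
}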
@@ -80,7 +89,7 @@ func NewLinear(vs Path, inDim, outDim int64, c *LinearConfig) *Linear {
 //  1 1 1
 //  1 1 1
 //  1 1 1 ]
-func (l *Linear) Forward(xs ts.Tensor) (retVal ts.Tensor) {
+func (l Linear) Forward(xs ts.Tensor) (retVal ts.Tensor) {
 
-	return xs.MustMatMul(l.Ws).MustAdd(l.Bs)
+	return xs.MustMatMul(l.Ws.MustT()).MustAdd(l.Bs)
 }
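The added transpose fixes a shape mismatch: per the NOTE above, Ws is stored as {outDim, inDim}, so the forward pass must multiply by the transpose of Ws. For xs of shape {batch, inDim}, multiplying by the transposed weights yields {batch, outDim}, which then broadcasts with Bs of shape {outDim}; without MustT the matmul only lines up when inDim equals outDim. A sketch of the dimension check, using only calls that appear in this commit:

package main

import (
	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
)

func main() {
	vs := nn.NewVarStore(gotch.CPU)
	path := vs.Root()
	l := nn.NewLinear(path, 4, 3, nn.DefaultLinearConfig())

	// xs: {2, 4}; l.Ws: {3, 4}; l.Ws.MustT(): {4, 3}
	// so Forward returns a {2, 3} tensor.
	xs := ts.MustZeros([]int64{2, 4}, gotch.Float.CInt(), path.Device().CInt())
	l.Forward(xs).Print()
}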
@@ -44,8 +44,10 @@ func defaultBuild(config OptimizerConfig, vs VarStore, lr float64) (retVal Optim
 	vs.variables.mutex.Lock()
 	defer vs.variables.mutex.Unlock()
 
-	if err = opt.AddParameters(vs.variables.TrainableVariable); err != nil {
-		return retVal, err
+	if len(vs.variables.TrainableVariable) > 0 {
+		if err = opt.AddParameters(vs.variables.TrainableVariable); err != nil {
+			return retVal, err
+		}
 	}
 
 	return Optimizer{
@@ -220,6 +222,7 @@ func (opt *Optimizer) Step() {
 
 // BackwardStep applies a backward step pass, update the gradients, and performs an optimization step.
 func (opt *Optimizer) BackwardStep(loss ts.Tensor) {
+
 	opt.addMissingVariables()
 	err := opt.opt.ZeroGrad()
 	if err != nil {
@@ -39,9 +39,9 @@ func (s *Sequential) Add(l ts.Module) {
 //
 // NOTE: fn should have signature `func(t ts.Tensor) ts.Tensor`
 // and it implements Module interface
-func (s *Sequential) AddFn(fn interface{}) {
+func (s *Sequential) AddFn(fn ts.Module) {
 
-	s.Add(fn.(ts.Module))
+	s.Add(fn)
 }
 
 // ForwardAll applies the forward pass and returns the output for each layer.
@@ -144,18 +144,18 @@ func (s *SequentialT) Add(l ts.ModuleT) {
 //
 // NOTE: fn should have signature `func(t ts.Tensor) ts.Tensor`
 // and it implements Module interface
-func (s *SequentialT) AddFn(fn interface{}) {
+func (s *SequentialT) AddFn(fn ts.ModuleT) {
 
-	s.Add(fn.(ts.ModuleT))
+	s.Add(fn)
 }
 
 // AddFn appends a closure after all the current layers.
 //
 // NOTE: fn should have signature `func(t ts.Tensor, train bool) ts.Tensor`
 // and it implements Module interface
-func (s *SequentialT) AddFnT(fn interface{}) {
+func (s *SequentialT) AddFnT(fn ts.ModuleT) {
 
-	s.Add(fn.(ts.ModuleT))
+	s.Add(fn)
 }
 
 // ForwardAll applies the forward pass and returns the output for each layer.
@@ -176,3 +176,21 @@ func (s *SequentialT) ForwardAllT(xs ts.Tensor, train bool, opts ...uint8) (retV
 
 	return retVal
 }
+
+// ForwardWith is a handler function to implement Module interface for
+// any (anonymous) function it wraps.
+//
+// Ref. https://stackoverflow.com/a/42182987
+// NOTE: Specifically, `ForwardWith` is used to wrap anonymous function
+// as input parameter of `AddFn` Sequential method.
+type ForwardWith func(ts.Tensor) ts.Tensor
+
+func (fw ForwardWith) Forward(xs ts.Tensor) ts.Tensor {
+	return fw(xs)
+}
+
+type ForwardTWith func(ts.Tensor, bool) ts.Tensor
+
+func (fw ForwardTWith) ForwardT(xs ts.Tensor, train bool) ts.Tensor {
+	return fw(xs, train)
+}
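ForwardWith relies on a standard Go idiom (see the StackOverflow reference above): a named function type can have methods, so a plain closure, once converted to that type, satisfies an interface. A self-contained toy illustration of the pattern; none of these identifiers come from the library:

package main

import "fmt"

// Toy stand-ins for ts.Module and nn.ForwardWith.
type module interface {
	forward(x float64) float64
}

type forwardWith func(float64) float64

// The receiver is the function value itself, so the method
// simply invokes the wrapped closure.
func (fw forwardWith) forward(x float64) float64 { return fw(x) }

func main() {
	var m module = forwardWith(func(x float64) float64 { return x * 2 })
	fmt.Println(m.forward(3)) // 6
}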
@@ -57,6 +57,9 @@ func NewVarStore(device gotch.Device) VarStore {
 	}
 }
 
+// NOTE:
+// To get (initiate) a path, call vs.Root()
+
 // VarStore methods:
 // =================
 
@@ -417,9 +420,10 @@ func (p *Path) OnesNoTrain(name string, dims []int64) (retVal ts.Tensor) {
 // will be tracked.
 // The variable uses a float tensor initialized as per the
 // related argument.
-func (p *Path) NewVar(name string, dims []int64, init Init) (retVal ts.Tensor) {
+func (p *Path) NewVar(name string, dims []int64, ini Init) (retVal ts.Tensor) {
 
-	v := init.InitTensor(dims, p.varstore.device)
+	v := ini.InitTensor(dims, p.varstore.device)
+
 	return p.add(name, v, true)
 }
 
@@ -774,3 +774,38 @@ func (ts Tensor) MustRelu() (retVal Tensor) {
 
 	return retVal
 }
+
+func (ts Tensor) T() (retVal Tensor, err error) {
+	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	defer C.free(unsafe.Pointer(ptr))
+
+	lib.AtgT(ptr, ts.ctensor)
+	err = TorchErr()
+	if err != nil {
+		return retVal, err
+	}
+
+	retVal = Tensor{ctensor: *ptr}
+
+	return retVal, nil
+}
+
+func (ts Tensor) MustT() (retVal Tensor) {
+	retVal, err := ts.T()
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
+
+func (ts Tensor) T_() {
+	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	defer C.free(unsafe.Pointer(ptr))
+
+	lib.AtgT_(ptr, ts.ctensor)
+	err := TorchErr()
+	if err != nil {
+		log.Fatal(err)
+	}
+}
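For context, the new MustT wrapper can be exercised the same way example/varstore/main.go above uses InitTensor; a minimal sketch, assuming those calls behave as shown earlier in this commit:

package main

import (
	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
)

func main() {
	// A {1, 4} tensor, transposed to {4, 1} via the new T binding.
	t := nn.NewKaimingUniformInit().InitTensor([]int64{1, 4}, gotch.CPU)
	t.MustT().Print()
}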