diff --git a/example/mnist/linear.go b/example/mnist/linear.go
index cceabb8..3e60182 100644
--- a/example/mnist/linear.go
+++ b/example/mnist/linear.go
@@ -41,7 +41,7 @@ func runLinear() {
 		})
 
 		testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true)
-		testAccuracy := testLogits.MustArgmax(-1, false, true).MustEq1(ds.TestLabels).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true).MustView([]int64{-1}).MustFloat64Value([]int64{0})
+		testAccuracy := testLogits.MustArgmax(-1, false, true).MustEq1(ds.TestLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true).MustView([]int64{-1}).MustFloat64Value([]int64{0})
 
 		fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100)
 
diff --git a/example/mnist/nn.go b/example/mnist/nn.go
index 094ee76..09d386b 100644
--- a/example/mnist/nn.go
+++ b/example/mnist/nn.go
@@ -28,9 +28,9 @@ func netInit(vs nn.Path) ts.Module {
 
 	n.Add(nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, *nn.DefaultLinearConfig()))
 
-	n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
-		return xs.MustRelu(true)
-	}))
+	// n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
+	// 	return xs.MustRelu(true)
+	// }))
 
 	n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, *nn.DefaultLinearConfig()))
 	// n.Add(nn.NewLinear(vs, ImageDimNN, LabelNN, nn.DefaultLinearConfig()))
@@ -40,7 +40,8 @@ func netInit(vs nn.Path) ts.Module {
 
 func train(trainX, trainY, testX, testY ts.Tensor, m ts.Module, opt nn.Optimizer, epoch int) {
 
-	loss := m.Forward(trainX).CrossEntropyForLogits(trainY)
+	logits := m.Forward(trainX)
+	loss := logits.CrossEntropyForLogits(trainY)
 
 	opt.BackwardStep(loss)
diff --git a/example/seq/main.go b/example/seq/main.go
new file mode 100644
index 0000000..180618d
--- /dev/null
+++ b/example/seq/main.go
@@ -0,0 +1,205 @@
+package main
+
+import (
+	"fmt"
+	"log"
+	"math"
+
+	"github.com/sugarme/gotch"
+	"github.com/sugarme/gotch/nn"
+	ts "github.com/sugarme/gotch/tensor"
+	"github.com/sugarme/gotch/vision"
+)
+
+func main() {
+	// noSeq()
+	withSeq()
+	// noSeq2Layers()
+
+	// seqNoVarStore()
+}
+
+func noSeq() {
+	ds := vision.LoadMNISTDir("../../data/mnist")
+
+	wsInit := nn.NewKaimingUniformInit()
+	ws := wsInit.InitTensor([]int64{10, 784}, gotch.CPU).MustT(true)
+
+	bound := 1.0 / math.Sqrt(float64(784))
+	bsInit := nn.NewUniformInit(-bound, bound)
+	bs := bsInit.InitTensor([]int64{10}, gotch.CPU)
+
+	for i := 0; i < 2000; i++ {
+		mul := ds.TrainImages.MustMatMul(ws, false)
+		logits := mul.MustAdd(bs, true)
+		loss := logits.AccuracyForLogits(ds.TrainLabels)
+
+		fmt.Printf("Epoch %v\t Loss: %.3f\n", i, loss.Values()[0])
+		loss.MustDrop()
+	}
+
+}
+
+func withSeq() {
+	seq := nn.Seq()
+	vs := nn.NewVarStore(gotch.CPU)
+	// seq.Add(nn.NewLinear(vs.Root(), 784, 10, *nn.DefaultLinearConfig()))
+	seq.Add(nn.NewLinear(vs.Root(), 784, 128, *nn.DefaultLinearConfig()))
+	seq.Add(nn.NewLinear(vs.Root(), 128, 10, *nn.DefaultLinearConfig()))
+
+	opt, err := nn.DefaultAdamConfig().Build(vs, 1e-2)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	ds := vision.LoadMNISTDir("../../data/mnist")
+
+	for i := 0; i < 2000; i++ {
+		logits := seq.Forward(ds.TrainImages)
+		loss := logits.CrossEntropyForLogits(ds.TrainLabels)
+		opt.BackwardStep(loss)
+
+		testLogits := seq.Forward(ds.TestImages)
+		testAccuracy := testLogits.AccuracyForLogits(ds.TestLabels)
+
+		fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", i, loss.Values()[0], testAccuracy.Values()[0]*100)
+
+		loss.MustDrop()
+		testAccuracy.MustDrop()
+	}
+
+}
+
+func noSeq2Layers() {
+	ds := vision.LoadMNISTDir("../../data/mnist")
+
+	wsInit := nn.NewKaimingUniformInit()
+	ws1 := wsInit.InitTensor([]int64{1024, 784}, gotch.CPU).MustT(true)
+	ws2 := wsInit.InitTensor([]int64{10, 1024}, gotch.CPU).MustT(true)
+
+	bound1 := 1.0 / math.Sqrt(float64(784))
+	bsInit1 := nn.NewUniformInit(-bound1, bound1)
+	bs1 := bsInit1.InitTensor([]int64{1024}, gotch.CPU)
+
+	bound2 := 1.0 / math.Sqrt(float64(1024))
+	bsInit2 := nn.NewUniformInit(-bound2, bound2)
+	bs2 := bsInit2.InitTensor([]int64{10}, gotch.CPU)
+
+	for i := 0; i < 2000; i++ {
+		mul1 := ds.TrainImages.MustMatMul(ws1, false)
+		out1 := mul1.MustAdd(bs1, true)
+
+		mul2 := out1.MustMatMul(ws2, true)
+		logits := mul2.MustAdd(bs2, true)
+
+		loss := logits.AccuracyForLogits(ds.TrainLabels)
+
+		fmt.Printf("Epoch %v\t Loss: %.3f\n", i, loss.Values()[0])
+		loss.MustDrop()
+	}
+}
+
+func seqNoVarStore() {
+
+	ds := vision.LoadMNISTDir("../../data/mnist")
+
+	wsInit := nn.NewKaimingUniformInit()
+	ws1 := wsInit.InitTensor([]int64{1024, 784}, gotch.CPU).MustT(true)
+	ws2 := wsInit.InitTensor([]int64{10, 1024}, gotch.CPU).MustT(true)
+
+	bound1 := 1.0 / math.Sqrt(float64(784))
+	bsInit1 := nn.NewUniformInit(-bound1, bound1)
+	bs1 := bsInit1.InitTensor([]int64{1024}, gotch.CPU)
+
+	bound2 := 1.0 / math.Sqrt(float64(1024))
+	bsInit2 := nn.NewUniformInit(-bound2, bound2)
+	bs2 := bsInit2.InitTensor([]int64{10}, gotch.CPU)
+
+	l1 := Linear{&ws1, &bs1}
+	l2 := Linear{&ws2, &bs2}
+
+	seq := Seq()
+	seq.Add(l1)
+	seq.Add(l2)
+	// seq.Add1(l1)
+	// seq.Add2(l2)
+
+	for i := 0; i < 2000; i++ {
+		logits := seq.Forward(ds.TrainImages)
+
+		logits.MustDrop()
+	}
+
+}
+
+type Linear struct {
+	Ws *ts.Tensor
+	Bs *ts.Tensor
+}
+
+func (l Linear) Forward(xs ts.Tensor) ts.Tensor {
+	mul := xs.MustMatMul(*l.Ws, false)
+	return mul.MustAdd(*l.Bs, true)
+}
+
+type Sequential struct {
+	layers []ts.Module
+	l1     ts.Module
+	l2     ts.Module
+}
+
+func Seq() Sequential {
+	return Sequential{layers: make([]ts.Module, 0)}
+}
+
+// Len returns number of sub-layers embedded in this layer
+func (s *Sequential) Len() (retVal int64) {
+	return int64(len(s.layers))
+}
+
+// IsEmpty returns true if this layer does not have any sub-layers.
+func (s *Sequential) IsEmpty() (retVal bool) {
+	return len(s.layers) == 0
+}
+
+// Add appends a layer after all the current layers.
+func (s *Sequential) Add(l ts.Module) {
+	s.layers = append(s.layers, l)
+}
+
+func (s *Sequential) Add1(l ts.Module) {
+	s.l1 = l
+}
+
+func (s *Sequential) Add2(l ts.Module) {
+	s.l2 = l
+}
+
+func (s Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
+	if s.IsEmpty() {
+		return xs.MustShallowClone()
+	}
+
+	// forward sequentially, dropping intermediate outputs and
+	// returning the last layer's output without dropping it
+	outs := make([]ts.Tensor, len(s.layers))
+	for i := 0; i < len(s.layers); i++ {
+		if i == 0 {
+			outs[0] = s.layers[i].Forward(xs)
+			defer outs[0].MustDrop()
+		} else if i == len(s.layers)-1 {
+			return s.layers[i].Forward(outs[i-1])
+		} else {
+			outs[i] = s.layers[i].Forward(outs[i-1])
+			defer outs[i].MustDrop()
+		}
+	}
+
+	return
+
+	// out1 := s.l1.Forward(xs)
+	// defer out1.MustDrop()
+	//
+	// return s.l2.Forward(out1)
+
+}
diff --git a/nn/linear.go b/nn/linear.go
index 3664529..f8a2358 100644
--- a/nn/linear.go
+++ b/nn/linear.go
@@ -56,7 +56,7 @@ func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) *Linear {
 	}
 
 	return &Linear{
-		Ws: vs.NewVar("weight", []int64{outDim, inDim}, c.WsInit),
+		Ws: vs.NewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false),
 		Bs: bs,
 	}
 }
@@ -90,7 +90,7 @@ func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) *Linear {
 // 1 1 1
 // 1 1 1 ]
 func (l *Linear) Forward(xs ts.Tensor) (retVal ts.Tensor) {
-	// TODO: measure memory leak here.
-	mul := xs.MustMatMul(l.Ws.MustT(false), false)
+
+	mul := xs.MustMatMul(l.Ws, false)
 	return mul.MustAdd(l.Bs, true)
 }
diff --git a/nn/sequential.go b/nn/sequential.go
index e21901c..220d0a6 100644
--- a/nn/sequential.go
+++ b/nn/sequential.go
@@ -72,18 +72,28 @@ func WithUint8(n uint8) func() uint8 {
 // Implement Module interface for Sequential:
 // ==========================================
 
-func (s *Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
+
+// Forward implements Module interface for Sequential
+func (s Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
 	if s.IsEmpty() {
 		return xs.MustShallowClone()
 	}
 
 	// forward sequentially
+	outs := make([]ts.Tensor, len(s.layers))
 	for i := 0; i < len(s.layers); i++ {
-		// xs = s.layers[i].Forward(xs)
-		xs = xs.Apply(s.layers[i])
+		if i == 0 {
+			outs[0] = s.layers[i].Forward(xs)
+			defer outs[0].MustDrop()
+		} else if i == len(s.layers)-1 {
+			return s.layers[i].Forward(outs[i-1])
+		} else {
+			outs[i] = s.layers[i].Forward(outs[i-1])
+			defer outs[i].MustDrop()
+		}
 	}
 
-	return xs
+	return
 }
 
 // SequentialT is a sequential layer combining new layers with support for a training mode.
diff --git a/tensor/other.go b/tensor/other.go
index 5153091..b1845ff 100644
--- a/tensor/other.go
+++ b/tensor/other.go
@@ -8,13 +8,18 @@ import (
 
 // CrossEntropyForLogits computes the cross-entropy loss based on some logits and targets.
 func (ts Tensor) CrossEntropyForLogits(targets Tensor) (retVal Tensor) {
-	return ts.MustLogSoftmax(-1, gotch.Float.CInt(), true).MustNllLoss(targets, true)
+	// return ts.MustLogSoftmax(-1, gotch.Float.CInt(), true).MustNllLoss(targets, true)
+
+	logSm := ts.MustLogSoftmax(-1, gotch.Float.CInt(), true)
+	return logSm.MustNllLoss(targets, true)
 }
 
 // AccuracyForLogits returns the average accuracy for some given logits assuming that
 // targets represent ground-truth.
 func (ts Tensor) AccuracyForLogits(targets Tensor) (retVal Tensor) {
-	return ts.MustArgmax(-1, false, true).MustEq1(targets).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true)
+	argmax := ts.MustArgmax(-1, false, true)
+	eq1 := argmax.MustEq1(targets, true)
+	return eq1.MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true)
 }
 
 func (ts Tensor) MaxPool2DDefault(ksize int64, del bool) (retVal Tensor) {
diff --git a/tensor/tensor.go b/tensor/tensor.go
index 05fe028..5faa2bc 100644
--- a/tensor/tensor.go
+++ b/tensor/tensor.go
@@ -295,12 +295,14 @@ func (ts Tensor) Device() (retVal gotch.Device, err error) {
 	return device.OfCInt(int32(cInt)), nil
 }
 
-func (ts Tensor) Eq1(other Tensor) (retVal Tensor, err error) {
+func (ts Tensor) Eq1(other Tensor, del bool) (retVal Tensor, err error) {
 	// Get a C null pointer
 	// https://stackoverflow.com/a/2022369
 	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
-	defer C.free(unsafe.Pointer(ptr))
+	if del {
+		defer ts.MustDrop()
+	}
 
 	lib.AtgEq1(ptr, ts.ctensor, other.ctensor)
 	if err = TorchErr(); err != nil {
@@ -311,8 +313,8 @@
 }
 
-func (ts Tensor) MustEq1(other Tensor) (retVal Tensor) {
-	retVal, err := ts.Eq1(other)
+func (ts Tensor) MustEq1(other Tensor, del bool) (retVal Tensor) {
+	retVal, err := ts.Eq1(other, del)
 	if err != nil {
 		log.Fatal(err)
 	}