From 613cd934436b55217a73711f3b479fd39df3bd61 Mon Sep 17 00:00:00 2001
From: sugarme
Date: Wed, 17 Jun 2020 11:23:00 +1000
Subject: [PATCH] feat(example/mnist): get works when printing out tensor

---
 example/linear-regression/main.go | 48 ++++++++++++++++
 example/mnist/linear.go           | 92 ++++++++++++++++++++-----------
 nn/data.go                        | 10 +---
 tensor/tensor.go                  |  1 +
 4 files changed, 111 insertions(+), 40 deletions(-)
 create mode 100644 example/linear-regression/main.go

diff --git a/example/linear-regression/main.go b/example/linear-regression/main.go
new file mode 100644
index 0000000..b9c1cc6
--- /dev/null
+++ b/example/linear-regression/main.go
@@ -0,0 +1,48 @@
+package main
+
+import (
+	"fmt"
+	"log"
+
+	"github.com/sugarme/gotch"
+	ts "github.com/sugarme/gotch/tensor"
+)
+
+func main() {
+
+	// mockup data
+	var (
+		n int = 20
+		xvals []float32
+		yvals []float32
+		epochs = 10
+	)
+
+	for i := 0; i < n; i++ {
+		xvals = append(xvals, float32(i))
+		yvals = append(yvals, float32(2*i+1))
+	}
+
+	xtrain, err := ts.NewTensorFromData(xvals, []int64{int64(n), 1})
+	if err != nil {
+		log.Fatal(err)
+	}
+	ytrain, err := ts.NewTensorFromData(yvals, []int64{int64(n), 1})
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	ws := ts.MustZeros([]int64{1, int64(n)}, gotch.Float.CInt(), gotch.CPU.CInt())
+	bs := ts.MustZeros([]int64{1, int64(n)}, gotch.Float.CInt(), gotch.CPU.CInt())
+
+	for epoch := 0; epoch < epochs; epoch++ {
+
+		logit := ws.MustMatMul(xtrain).MustAdd(bs)
+		loss := ts.NewTensor().MustLogSoftmax(-1, gotch.Float.CInt())
+
+		ws.MustGrad()
+		bs.MustGrad()
+		loss.MustBackward()
+
+	}
+}
diff --git a/example/mnist/linear.go b/example/mnist/linear.go
index 5bcd0f9..b29e792 100644
--- a/example/mnist/linear.go
+++ b/example/mnist/linear.go
@@ -13,8 +13,10 @@ const (
 	Label int64 = 10
 	MnistDir string = "../../data/mnist"
 
-	epochs = 100
-	batchSize = 256
+	// epochs = 500
+	// batchSize = 256
+	epochs = 200
+	batchSize = 60000
 )
 
 func runLinear() {
@@ -33,51 +35,77 @@ func runLinear() {
 	bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
 
 	for epoch := 0; epoch < epochs; epoch++ {
-
-		var loss ts.Tensor
-		trainIter := ds.TrainIter(batchSize)
-		trainIter.Shuffle().ToDevice(gotch.CPU)
-		// item a pair of images and labels as 2 tensors
-		for {
-			batch, ok := trainIter.Next()
-			if !ok {
-				break
-			}
-
-			logits := batch.Images.MustMm(ws).MustAdd(bs)
-			loss = logits.MustLogSoftmax(-1, dtype).MustNllLoss(batch.Labels)
-			ws.ZeroGrad()
-			bs.ZeroGrad()
-			loss.Backward()
-
-			ts.NoGrad(func() {
-				ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
-				bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
-			})
-		}
-
 		/*
-		 * logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
-		 * loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
+		 * totalSize := ds.TrainImages.MustSize()[0]
+		 * samples := int(totalSize)
+		 * index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
+		 * imagesTs := ds.TrainImages.MustIndexSelect(0, index)
+		 * labelsTs := ds.TrainLabels.MustIndexSelect(0, index)
+		 *
+		 * batches := samples / batchSize
+		 * batchIndex := 0
+		 * for i := 0; i < batches; i++ {
+		 * start := batchIndex * batchSize
+		 * size := batchSize
+		 * if samples-start < batchSize {
+		 * // size = samples - start
+		 * break
+		 * }
+		 * batchIndex += 1
+		 *
+		 * // Indexing
+		 * narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
+		 * // bImages := ds.TrainImages.Idx(narrowIndex)
+		 * // bLabels := ds.TrainLabels.Idx(narrowIndex)
+		 * bImages := imagesTs.Idx(narrowIndex)
+		 * bLabels := labelsTs.Idx(narrowIndex)
+		 *
+		 * logits := bImages.MustMm(ws).MustAdd(bs)
+		 * // loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels)
+		 * loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels)
 		 *
 		 * ws.ZeroGrad()
 		 * bs.ZeroGrad()
 		 * loss.Backward()
 		 *
+		 * bs.MustGrad().Print()
+		 *
 		 * ts.NoGrad(func() {
 		 * ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
 		 * bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
 		 * })
-		 * loss.Print()
+		 * }
+		 *
+		 * imagesTs.MustDrop()
+		 * labelsTs.MustDrop()
 		 * */
-		// bs.MustGrad().Print()
 
+		logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
+		// loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
+		loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
+		// loss := ds.TrainImages.MustMm(ws).MustAdd(bs).MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
+
+		ws.ZeroGrad()
+		bs.ZeroGrad()
+		// loss.MustBackward()
+		loss.Backward()
+
+		// TODO: why `loss` need to print out to get updated?
+		fmt.Printf("loss (epoch %v): %v\n", epoch, loss.MustToString(0))
+		// fmt.Printf("bs grad (epoch %v): %v\n", epoch, bs.MustGrad().MustToString(1))
+
+		ts.NoGrad(func() {
+			ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
+			bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
+		})
+
+		// fmt.Printf("bs(epoch %v): \n%v\n", epoch, bs.MustToString(1))
+		// fmt.Printf("ws mean(epoch %v): \n%v\n", epoch, ws.MustMean(gotch.Float.CInt()).MustToString(1))

 		testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
 		testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
-
+		// testAccuracy := ds.TestImages.MustMm(ws).MustAdd(bs).MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
+		// fmt.Printf("Epoch: %v - Test accuracy: %v\n", epoch, testAccuracy*100)

-
-		// fmt.Printf("Epoch: %v - Train loss: %v - Test accuracy: %v\n", epoch, loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}), testAccuracy*100)
 	}
 }
diff --git a/nn/data.go b/nn/data.go
index 49af887..96a362f 100644
--- a/nn/data.go
+++ b/nn/data.go
@@ -124,15 +124,9 @@ func (it *Iter2) Next() (item Iter2Item, ok bool) {
 		// Indexing
 		narrowIndex := ts.NewNarrow(start, start+size)
 
-		// ts1 := it.xs.Idx(narrowIndex).MustTo(it.device)
-		// ts2 := it.ys.Idx(narrowIndex).MustTo(it.device)
-
-		ts1 := it.xs.Idx(narrowIndex)
-		ts2 := it.ys.Idx(narrowIndex)
-
 		return Iter2Item{
-			Images: ts1,
-			Labels: ts2,
+			Images: it.xs.Idx(narrowIndex),
+			Labels: it.ys.Idx(narrowIndex),
 		}, true
 	}
 }
diff --git a/tensor/tensor.go b/tensor/tensor.go
index 2f2359c..8ac6790 100644
--- a/tensor/tensor.go
+++ b/tensor/tensor.go
@@ -825,6 +825,7 @@ func (ts Tensor) ToString(lw int64) (retVal string, err error) {
 
 // MustToString returns a string representation for the tensor. It will be panic
 // if error.
+// lw : line width (size)
 func (ts Tensor) MustToString(lw int64) (retVal string) {
 	retVal, err := ts.ToString(lw)
 	if err != nil {
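
Note on the manual update used in example/mnist/linear.go above: gradients are produced by loss.Backward(), and the parameter step is applied inside ts.NoGrad so the in-place MustAdd_ calls are not themselves recorded by autograd. The following is a minimal, self-contained sketch of that single step, built only from calls that already appear in this patch. The 4x784 zero-filled "images", the 10-class shapes, and the mean-of-log-softmax stand-in loss are illustrative assumptions, not part of the commit; the real example uses the MNIST tensors with MustNllLoss.

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	ts "github.com/sugarme/gotch/tensor"
)

func main() {
	dtype := gotch.Float.CInt()
	device := gotch.CPU.CInt()

	// Illustrative stand-ins for a batch of MNIST images and the model parameters.
	xs := ts.MustZeros([]int64{4, 784}, dtype, device)
	ws := ts.MustZeros([]int64{784, 10}, dtype, device).MustSetRequiresGrad(true)
	bs := ts.MustZeros([]int64{10}, dtype, device).MustSetRequiresGrad(true)

	// Forward pass. The mean of the log-softmax is used here only to obtain a
	// scalar to backpropagate; the example in the patch uses MustNllLoss with labels.
	logits := xs.MustMm(ws).MustAdd(bs)
	loss := logits.MustLogSoftmax(-1, dtype).MustMean(dtype)

	ws.ZeroGrad()
	bs.ZeroGrad()
	loss.Backward()

	// Parameter step outside the autograd tape: w <- w + (-1.0)*dL/dw,
	// i.e. plain gradient descent with a learning rate of 1.0.
	ts.NoGrad(func() {
		ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
		bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
	})

	fmt.Printf("loss: %v\n", loss.MustToString(80))
}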