Add a `del` flag to Tensor.View/MustView and rework the MNIST CNN example.

Review notes on this revision of the patch:
- runCNN records the loss of the last batch as a float64 instead of keeping a
  shallow clone per batch (every clone but the last was never dropped), and no
  longer calls Values() on a zero-value tensor when there are no full batches.
- runCNN uses the existing batchCNN constant instead of a duplicate batchSize.
- View keeps `defer C.free(...)` for the malloc'd out-pointer; the del handling
  is added next to it instead of replacing it.
- BatchAccuracyForLogits passes del=true to MustView, like the other call
  sites that view a temporary tensor.
- Line breaks were restored from a whitespace-collapsed copy; the position of
  blank context lines is inferred from the hunk line counts.

diff --git a/example/mnist/cnn.go b/example/mnist/cnn.go
index baf72bd..88ab38f 100644
--- a/example/mnist/cnn.go
+++ b/example/mnist/cnn.go
@@ -3,6 +3,7 @@ package main
 import (
 	"fmt"
 	"log"
+	"time"
 
 	"github.com/sugarme/gotch"
 	"github.com/sugarme/gotch/nn"
@@ -13,8 +14,8 @@ import (
 const (
 	MnistDirCNN string = "../../data/mnist"
 
-	epochsCNN = 10
+	epochsCNN = 100
 	batchCNN  = 256
 
 	LrCNN = 1e-4
 )
@@ -39,20 +40,54 @@ func newNet(vs *nn.Path) Net {
 		*fc2}
 }
 
-func (n Net) ForwardT(xs ts.Tensor, train bool) ts.Tensor {
-	out := xs.MustView([]int64{-1, 1, 28, 28}).Apply(n.conv1).MaxPool2DDefault(2, true)
-	out = out.Apply(n.conv2).MaxPool2DDefault(2, true)
-	out = out.MustView([]int64{-1, 1024}).Apply(&n.fc1).MustRelu(true)
-	out.Dropout_(0.5, train)
-	return out.Apply(&n.fc2)
+// ForwardT runs the network on xs: view as [-1, 1, 28, 28] -> conv1 ->
+// max-pool(2) -> conv2 -> max-pool(2) -> view as [-1, 1024] -> fc1 -> ReLU ->
+// dropout(0.5, train) -> fc2.
+//
+// Intermediates are released either by a deferred MustDrop or by the `true`
+// flag handed to the op that consumes them (assumed to mean "delete the
+// receiver", as View's del does). xs is viewed with del=false and is not dropped.
+func (n Net) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
+	outView1 := xs.MustView([]int64{-1, 1, 28, 28}, false)
+	defer outView1.MustDrop()
+
+	outC1 := outView1.Apply(n.conv1)
+	// defer outC1.MustDrop()
+
+	outMP1 := outC1.MaxPool2DDefault(2, true)
+	defer outMP1.MustDrop()
+
+	outC2 := outMP1.Apply(n.conv2)
+	// defer outC2.MustDrop()
+
+	outMP2 := outC2.MaxPool2DDefault(2, true)
+	// defer outMP2.MustDrop()
+
+	outView2 := outMP2.MustView([]int64{-1, 1024}, true)
+	defer outView2.MustDrop()
+
+	outFC1 := outView2.Apply(&n.fc1)
+	// defer outFC1.MustDrop()
+
+	outRelu := outFC1.MustRelu(true)
+	defer outRelu.MustDrop()
+	// outRelu.Dropout_(0.5, train)
+	outDropout := ts.MustDropout(outRelu, 0.5, train)
+	defer outDropout.MustDrop()
+
+	return outDropout.Apply(&n.fc2)
+
 }
 
 func runCNN() {
 	var ds vision.Dataset
 	ds = vision.LoadMNISTDir(MnistDirNN)
+	// testImages := ds.TestImages
+	// testLabels := ds.TestLabels
 
 	cuda := gotch.CudaBuilder(0)
 	vs := nn.NewVarStore(cuda.CudaIfAvailable())
+	// vs := nn.NewVarStore(gotch.CPU)
 	path := vs.Root()
 	net := newNet(&path)
 	opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
@@ -60,28 +95,65 @@ func runCNN() {
 		log.Fatal(err)
 	}
 
-	for epoch := 0; epoch < epochsCNN; epoch++ {
-		var count = 0
-		for {
-			iter := ds.TrainIter(batchCNN).Shuffle()
-			item, ok := iter.Next()
-			if !ok {
-				break
-			}
+	startTime := time.Now()
 
-			loss := net.ForwardT(item.Data.MustTo(vs.Device(), true), true).CrossEntropyForLogits(item.Label.MustTo(vs.Device(), true))
-			opt.BackwardStep(loss)
-			loss.MustDrop()
-			count++
-			if count == 50 {
+	for epoch := 0; epoch < epochsCNN; epoch++ {
+
+		totalSize := ds.TrainImages.MustSize()[0]
+		samples := int(totalSize)
+		// index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
+		// imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
+		// labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
+
+		batches := samples / batchCNN
+		batchIndex := 0
+		// epochLoss is the scalar loss of the epoch's last batch; as a plain
+		// float64 it needs no MustDrop and is 0 when there are no full batches.
+		var epochLoss float64
+		// var loss ts.Tensor
+		for i := 0; i < batches; i++ {
+			start := batchIndex * batchCNN
+			size := batchCNN
+			if samples-start < batchCNN {
+				// size = samples - start
 				break
 			}
-			fmt.Printf("completed \t %v batches\n", count)
+			batchIndex += 1
+
+			// Indexing
+			narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
+			bImages := ds.TrainImages.Idx(narrowIndex)
+			bLabels := ds.TrainLabels.Idx(narrowIndex)
+			// bImages := imagesTs.Idx(narrowIndex)
+			// bLabels := labelsTs.Idx(narrowIndex)
+
+			bImages = bImages.MustTo(vs.Device(), true)
+			bLabels = bLabels.MustTo(vs.Device(), true)
+
+			logits := net.ForwardT(bImages, true)
+			loss := logits.CrossEntropyForLogits(bLabels)
+
+			opt.BackwardStep(loss)
+
+			// Read the scalar on the final batch only, before loss is dropped,
+			// so no tensor handle has to outlive the iteration.
+			if i == batches-1 {
+				epochLoss = loss.Values()[0]
+			}
+
+			// fmt.Printf("completed \t %v batches\t %.2f\n", i, loss.Values()[0])
+
+			bImages.MustDrop()
+			bLabels.MustDrop()
+			// logits.MustDrop()
+			loss.MustDrop()
 		}
 
-		// testAccuracy := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 1024)
-		//
+		// testAccuracy := ts.BatchAccuracyForLogits(net, testImages, testLabels, vs.Device(), 1024)
 		// fmt.Printf("Epoch: %v \t Test accuracy: %.2f%%\n", epoch, testAccuracy*100)
+
+		fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, epochLoss)
 	}
+	fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
 
 }
diff --git a/example/mnist/linear.go b/example/mnist/linear.go
index 3e60182..86be1a0 100644
--- a/example/mnist/linear.go
+++ b/example/mnist/linear.go
@@ -41,7 +41,7 @@ func runLinear() {
 		})
 
 		testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true)
-		testAccuracy := testLogits.MustArgmax(-1, false, true).MustEq1(ds.TestLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true).MustView([]int64{-1}).MustFloat64Value([]int64{0})
+		testAccuracy := testLogits.MustArgmax(-1, false, true).MustEq1(ds.TestLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
 
 		fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100)
 
diff --git a/tensor/module.go b/tensor/module.go
index 2ff2f60..4bb457d 100644
--- a/tensor/module.go
+++ b/tensor/module.go
@@ -72,7 +72,7 @@ func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize
 			break
 		}
 
-		acc := m.ForwardT(item.Data.MustTo(d, true), false).AccuracyForLogits(item.Label.MustTo(d, true)).MustView([]int64{-1}).MustFloat64Value([]int64{0})
+		acc := m.ForwardT(item.Data.MustTo(d, true), false).AccuracyForLogits(item.Label.MustTo(d, true)).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
 		size := float64(item.Data.MustSize()[0])
 		sumAccuracy += acc * size
 		sampleCount += size
diff --git a/tensor/tensor-generated-sample.go b/tensor/tensor-generated-sample.go
index 7c88a56..d5d0c1d 100644
--- a/tensor/tensor-generated-sample.go
+++ b/tensor/tensor-generated-sample.go
@@ -678,9 +678,15 @@ func (ts Tensor) MustMean(dtype int32, del bool) (retVal Tensor) {
 	return retVal
 }
 
-func (ts Tensor) View(sizeData []int64) (retVal Tensor, err error) {
+// View returns the tensor produced by AtgView over ts with the shape given in
+// sizeData. When del is true, ts is dropped once View returns (also on the
+// error path), so the caller must not use ts afterwards.
+func (ts Tensor) View(sizeData []int64, del bool) (retVal Tensor, err error) {
 	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
 	defer C.free(unsafe.Pointer(ptr))
+	if del {
+		defer ts.MustDrop()
+	}
 
 	lib.AtgView(ptr, ts.ctensor, sizeData, len(sizeData))
 	if err = TorchErr(); err != nil {
@@ -692,8 +698,9 @@ func (ts Tensor) View(sizeData []int64) (retVal Tensor, err error) {
 	return retVal, nil
 }
 
-func (ts Tensor) MustView(sizeData []int64) (retVal Tensor) {
-	retVal, err := ts.View(sizeData)
+// MustView is like View but calls log.Fatal on error instead of returning it.
+func (ts Tensor) MustView(sizeData []int64, del bool) (retVal Tensor) {
+	retVal, err := ts.View(sizeData, del)
 	if err != nil {
 		log.Fatal(err)
 	}
diff --git a/tensor/tensor.go b/tensor/tensor.go
index 5faa2bc..7c66b87 100644
--- a/tensor/tensor.go
+++ b/tensor/tensor.go
@@ -993,5 +993,5 @@ func (r Reduction) ToInt() (retVal int) {
 func (ts Tensor) Values() []float64 {
 	clone := ts.MustShallowClone()
 	clone.Detach_()
-	return []float64{clone.MustView([]int64{-1}).MustFloat64Value([]int64{-1})}
+	return []float64{clone.MustView([]int64{-1}, true).MustFloat64Value([]int64{-1})}
 }
diff --git a/vision/mnist.go b/vision/mnist.go
index c326a6c..a81b22a 100644
--- a/vision/mnist.go
+++ b/vision/mnist.go
@@ -125,7 +125,7 @@ func readImages(filename string) (retVal ts.Tensor) {
 		err = fmt.Errorf("create images tensor err.")
 		log.Fatal(err)
 	}
-	retVal = imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}).MustTotype(gotch.Float, true).MustDiv1(ts.FloatScalar(255.0), true)
+	retVal = imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}, true).MustTotype(gotch.Float, true).MustDiv1(ts.FloatScalar(255.0), true)
 
 	return retVal
 }