From 31a3f0e5879a60d24b6d915952340a1af15d1bdc Mon Sep 17 00:00:00 2001
From: sugarme
Date: Wed, 24 Jun 2020 01:37:33 +1000
Subject: [PATCH] feat(example/mnist): conv

---
 example/mnist/cnn.go  | 83 ++++++++++++++++++++++++++++++++++++-------
 example/mnist/main.go |  2 +-
 tensor/data.go        |  4 +--
 tensor/module.go      | 75 ++++++++++++++++++++++++++++++++++++--
 4 files changed, 147 insertions(+), 17 deletions(-)

diff --git a/example/mnist/cnn.go b/example/mnist/cnn.go
index 88ab38f..6db47c9 100644
--- a/example/mnist/cnn.go
+++ b/example/mnist/cnn.go
@@ -73,12 +73,13 @@ func (n Net) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
 }
 
-func runCNN() {
+func runCNN1() {
 
     var ds vision.Dataset
     ds = vision.LoadMNISTDir(MnistDirNN)
-    // testImages := ds.TestImages
-    // testLabels := ds.TestLabels
+    testImages := ds.TestImages
+    testLabels := ds.TestLabels
+
     cuda := gotch.CudaBuilder(0)
     vs := nn.NewVarStore(cuda.CudaIfAvailable())
     // vs := nn.NewVarStore(gotch.CPU)
@@ -95,9 +96,9 @@
 
     totalSize := ds.TrainImages.MustSize()[0]
     samples := int(totalSize)
-    // index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
-    // imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
-    // labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
+    index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
+    imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
+    labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
 
     batches := samples / batchSize
     batchIndex := 0
@@ -114,10 +115,10 @@
 
         // Indexing
         narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
-        bImages := ds.TrainImages.Idx(narrowIndex)
-        bLabels := ds.TrainLabels.Idx(narrowIndex)
-        // bImages := imagesTs.Idx(narrowIndex)
-        // bLabels := labelsTs.Idx(narrowIndex)
+        // bImages := ds.TrainImages.Idx(narrowIndex)
+        // bLabels := ds.TrainLabels.Idx(narrowIndex)
+        bImages := imagesTs.Idx(narrowIndex)
+        bLabels := labelsTs.Idx(narrowIndex)
 
         bImages = bImages.MustTo(vs.Device(), true)
         bLabels = bLabels.MustTo(vs.Device(), true)
@@ -138,11 +139,69 @@
             loss.MustDrop()
         }
 
-        // testAccuracy := ts.BatchAccuracyForLogits(net, testImages, testLabels, vs.Device(), 1024)
-        // fmt.Printf("Epoch: %v \t Test accuracy: %.2f%%\n", epoch, testAccuracy*100)
+        // testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
+        // fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Values()[0], testAccuracy*100)
         fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, epocLoss.Values()[0])
 
         epocLoss.MustDrop()
+        imagesTs.MustDrop()
+        labelsTs.MustDrop()
+    }
+
+    testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
+    fmt.Printf("Test accuracy: %.2f%%\n", testAccuracy*100)
+
+    fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
+}
+
+func runCNN2() {
+
+    var ds vision.Dataset
+    ds = vision.LoadMNISTDir(MnistDirNN)
+
+    cuda := gotch.CudaBuilder(0)
+    vs := nn.NewVarStore(cuda.CudaIfAvailable())
+    path := vs.Root()
+    net := newNet(&path)
+    opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
+    if err != nil {
+        log.Fatal(err)
+    }
+
+    startTime := time.Now()
+
+    var lossVal float64
+    for epoch := 0; epoch < epochsCNN; epoch++ {
+
+        iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchCNN)
+        // iter.Shuffle()
+
+        for {
+            item, ok := iter.Next()
+            if !ok {
+                break
+            }
+
+            bImages := item.Data.MustTo(vs.Device(), true)
+            bLabels := item.Label.MustTo(vs.Device(), true)
+
+            _ = ts.MustGradSetEnabled(true)
+
+            logits := net.ForwardT(bImages, true)
+            loss := logits.CrossEntropyForLogits(bLabels)
+
+            opt.BackwardStep(loss)
+
+            lossVal = loss.Values()[0]
+
+            bImages.MustDrop()
+            bLabels.MustDrop()
+            loss.MustDrop()
+        }
+
+        testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
+
+        fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\t Accuracy: %.2f%%\n", epoch, lossVal, testAcc*100)
     }
 
     fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
diff --git a/example/mnist/main.go b/example/mnist/main.go
index dd27f15..1064903 100644
--- a/example/mnist/main.go
+++ b/example/mnist/main.go
@@ -21,7 +21,7 @@ func main() {
     case "nn":
         runNN()
     case "cnn":
-        runCNN()
+        runCNN2()
     default:
         panic("No specified model to run")
     }
diff --git a/tensor/data.go b/tensor/data.go
index 173e35f..633b418 100644
--- a/tensor/data.go
+++ b/tensor/data.go
@@ -81,12 +81,12 @@ func MustNewIter2(xs, ys Tensor, batchSize int64) (retVal Iter2) {
 //
 // The iterator would still run over the whole dataset but the order in
 // which elements are grouped in mini-batches is randomized.
-func (it Iter2) Shuffle() (retVal Iter2) {
+func (it *Iter2) Shuffle() {
     index := MustRandperm(it.totalSize, gotch.Int64, gotch.CPU)
 
     it.xs = it.xs.MustIndexSelect(0, index, true)
     it.ys = it.ys.MustIndexSelect(0, index, true)
 
-    return it
+
 }
 
 // ToDevice transfers the mini-batches to a specified device.
diff --git a/tensor/module.go b/tensor/module.go
index 4bb457d..3e54c34 100644
--- a/tensor/module.go
+++ b/tensor/module.go
@@ -72,15 +72,86 @@ func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize
             break
         }
 
-        acc := m.ForwardT(item.Data.MustTo(d, true), false).AccuracyForLogits(item.Label.MustTo(d, true)).MustView([]int64{-1}, false).MustFloat64Value([]int64{0})
         size := float64(item.Data.MustSize()[0])
-        sumAccuracy += acc * size
+        bImages := item.Data.MustTo(d, true)
+        bLabels := item.Label.MustTo(d, true)
+
+        logits := m.ForwardT(bImages, false)
+        acc := logits.AccuracyForLogits(bLabels)
+        sumAccuracy += acc.Values()[0] * size
         sampleCount += size
+
+        bImages.MustDrop()
+        bLabels.MustDrop()
+        acc.MustDrop()
     }
 
     return sumAccuracy / sampleCount
 }
 
+// BatchAccuracyForLogitsIdx is an alternative to BatchAccuracyForLogits that
+// computes the accuracy of module weights over mini-batches sliced out by
+// tensor indexing (Narrow) instead of an Iter2 iterator.
+func BatchAccuracyForLogitsIdx(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize int) (retVal float64) {
+    var (
+        sumAccuracy float64 = 0.0
+        sampleCount float64 = 0.0
+    )
+
+    // Switch Grad off
+    _ = NewNoGradGuard()
+
+    totalSize := xs.MustSize()[0]
+    samples := int(totalSize)
+
+    index := MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
+    imagesTs := xs.MustIndexSelect(0, index, false)
+    labelsTs := ys.MustIndexSelect(0, index, false)
+
+    batches := samples / batchSize
+    batchIndex := 0
+
+    for i := 0; i < batches; i++ {
+        start := batchIndex * batchSize
+        size := batchSize
+        if samples-start < batchSize {
+            // size = samples - start
+            break
+        }
+        batchIndex += 1
+
+        // Indexing
+        narrowIndex := NewNarrow(int64(start), int64(start+size))
+        bImages := imagesTs.Idx(narrowIndex)
+        bLabels := labelsTs.Idx(narrowIndex)
+
+        bImages = bImages.MustTo(d, true)
+        bLabels = bLabels.MustTo(d, true)
+
+        logits := m.ForwardT(bImages, false)
+        bAccuracy := logits.AccuracyForLogits(bLabels)
+
+        accuVal := bAccuracy.Values()[0]
+        bSamples := float64(size)
+        sumAccuracy += accuVal * bSamples
+        sampleCount += bSamples
+
+        // Free up tensors on C memory
+        bImages.MustDrop()
+        bLabels.MustDrop()
+        // logits.MustDrop()
+        bAccuracy.MustDrop()
+    }
+
+    imagesTs.MustDrop()
+    labelsTs.MustDrop()
+
+    // Switch Grad on
+    // _ = MustGradSetEnabled(true)
+
+    return sumAccuracy / sampleCount
+}
+
 // Tensor methods for Module and ModuleT:
 // ======================================
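
Usage note (not part of the patch): Shuffle is now declared on a pointer receiver and shuffles the Iter2 in place, so callers no longer reassign its result. The sketch below shows how the commented-out iter.Shuffle() call in runCNN2 would be used. This is a minimal sketch, assuming the gotch import paths, an MNIST data directory of "./data/mnist", and a batch size of 256; none of those are fixed by this patch, and only APIs that appear in the diff are called.

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
	"github.com/sugarme/gotch/vision"
)

func main() {
	// Assumed data directory; use whatever MnistDirNN points to in the examples.
	ds := vision.LoadMNISTDir("./data/mnist")

	cuda := gotch.CudaBuilder(0)
	vs := nn.NewVarStore(cuda.CudaIfAvailable())

	// Build the iterator, then shuffle in place. With the old value-receiver
	// API this was written as iter = iter.Shuffle(); the new API returns nothing.
	iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, 256)
	iter.Shuffle()

	for {
		item, ok := iter.Next()
		if !ok {
			break
		}

		// Move each mini-batch to the var-store device and release the
		// C-allocated tensors when done, as runCNN2 does.
		bImages := item.Data.MustTo(vs.Device(), true)
		bLabels := item.Label.MustTo(vs.Device(), true)

		// ... forward/backward step as in runCNN2 ...

		bImages.MustDrop()
		bLabels.MustDrop()
	}

	fmt.Println("epoch done")
}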