// gotch/example/mnist/cnn.go
package main
import (
	"fmt"
	"log"
	"time"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
	"github.com/sugarme/gotch/vision"
)
const (
MnistDirCNN string = "../../data/mnist"
2020-06-23 10:14:08 +01:00
epochsCNN = 100
2020-06-22 16:07:07 +01:00
batchCNN = 256
2020-06-23 10:14:08 +01:00
batchSize = 256
2020-06-22 16:07:07 +01:00
LrCNN = 1e-4
2020-06-18 08:14:48 +01:00
)
2020-06-22 16:07:07 +01:00
type Net struct {
conv1 nn.Conv2D
conv2 nn.Conv2D
fc1 nn.Linear
fc2 nn.Linear
}
func newNet(vs *nn.Path) Net {
conv1 := nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig())
conv2 := nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig())
fc1 := nn.NewLinear(*vs, 1024, 1024, *nn.DefaultLinearConfig())
fc2 := nn.NewLinear(*vs, 1024, 10, *nn.DefaultLinearConfig())
return Net{
conv1,
conv2,
*fc1,
*fc2}
}
2020-06-23 10:14:08 +01:00
func (n Net) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
outView1 := xs.MustView([]int64{-1, 1, 28, 28}, false)
defer outView1.MustDrop()
outC1 := outView1.Apply(n.conv1)
// defer outC1.MustDrop()
outMP1 := outC1.MaxPool2DDefault(2, true)
defer outMP1.MustDrop()
outC2 := outMP1.Apply(n.conv2)
// defer outC2.MustDrop()
outMP2 := outC2.MaxPool2DDefault(2, true)
// defer outMP2.MustDrop()
outView2 := outMP2.MustView([]int64{-1, 1024}, true)
defer outView2.MustDrop()
outFC1 := outView2.Apply(&n.fc1)
// defer outFC1.MustDrop()
outRelu := outFC1.MustRelu(true)
defer outRelu.MustDrop()
// outRelu.Dropout_(0.5, train)
outDropout := ts.MustDropout(outRelu, 0.5, train)
defer outDropout.MustDrop()
return outDropout.Apply(&n.fc2)
2020-06-22 16:07:07 +01:00
}
2020-06-23 16:37:33 +01:00
func runCNN1() {
2020-06-22 16:07:07 +01:00
var ds vision.Dataset
ds = vision.LoadMNISTDir(MnistDirNN)
2020-06-23 16:37:33 +01:00
testImages := ds.TestImages
testLabels := ds.TestLabels
2020-06-22 16:07:07 +01:00
cuda := gotch.CudaBuilder(0)
vs := nn.NewVarStore(cuda.CudaIfAvailable())
2020-06-23 10:14:08 +01:00
// vs := nn.NewVarStore(gotch.CPU)
2020-06-22 16:07:07 +01:00
path := vs.Root()
net := newNet(&path)
opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
2020-06-22 16:07:07 +01:00
if err != nil {
log.Fatal(err)
}
2020-06-23 10:14:08 +01:00
startTime := time.Now()
2020-06-22 16:07:07 +01:00
for epoch := 0; epoch < epochsCNN; epoch++ {
2020-06-23 10:14:08 +01:00
totalSize := ds.TrainImages.MustSize()[0]
samples := int(totalSize)
2020-06-23 16:37:33 +01:00
index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
2020-06-23 10:14:08 +01:00
batches := samples / batchSize
batchIndex := 0
var epocLoss ts.Tensor
// var loss ts.Tensor
for i := 0; i < batches; i++ {
start := batchIndex * batchSize
size := batchSize
if samples-start < batchSize {
// size = samples - start
2020-06-22 16:07:07 +01:00
break
}
2020-06-23 10:14:08 +01:00
batchIndex += 1
// Indexing
narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
2020-06-23 16:37:33 +01:00
// bImages := ds.TrainImages.Idx(narrowIndex)
// bLabels := ds.TrainLabels.Idx(narrowIndex)
bImages := imagesTs.Idx(narrowIndex)
bLabels := labelsTs.Idx(narrowIndex)
2020-06-23 10:14:08 +01:00
bImages = bImages.MustTo(vs.Device(), true)
bLabels = bLabels.MustTo(vs.Device(), true)
logits := net.ForwardT(bImages, true)
loss := logits.CrossEntropyForLogits(bLabels)
2020-06-22 16:07:07 +01:00
opt.BackwardStep(loss)
2020-06-23 10:14:08 +01:00
epocLoss = loss.MustShallowClone()
epocLoss.Detach_()
// fmt.Printf("completed \t %v batches\t %.2f\n", i, loss.Values()[0])
bImages.MustDrop()
bLabels.MustDrop()
// logits.MustDrop()
// loss.MustDrop()
2020-06-22 16:07:07 +01:00
}
2020-06-23 16:37:33 +01:00
// testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
// fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Values()[0], testAccuracy*100)
2020-06-23 10:14:08 +01:00
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, epocLoss.Values()[0])
epocLoss.MustDrop()
2020-06-23 16:37:33 +01:00
imagesTs.MustDrop()
labelsTs.MustDrop()
}
testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
fmt.Printf("Test accuracy: %.2f%%\n", testAccuracy*100)
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
}
func runCNN2() {
var ds vision.Dataset
ds = vision.LoadMNISTDir(MnistDirNN)
cuda := gotch.CudaBuilder(0)
vs := nn.NewVarStore(cuda.CudaIfAvailable())
path := vs.Root()
net := newNet(&path)
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
if err != nil {
log.Fatal(err)
}
startTime := time.Now()
var lossVal float64
for epoch := 0; epoch < epochsCNN; epoch++ {
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchCNN)
// iter.Shuffle()
for {
item, ok := iter.Next()
if !ok {
break
}
bImages := item.Data.MustTo(vs.Device(), true)
bLabels := item.Label.MustTo(vs.Device(), true)
// _ = ts.MustGradSetEnabled(true)
2020-06-23 16:37:33 +01:00
logits := net.ForwardT(bImages, true)
loss := logits.CrossEntropyForLogits(bLabels)
opt.BackwardStep(loss)
lossVal = loss.Values()[0]
bImages.MustDrop()
bLabels.MustDrop()
loss.MustDrop()
}
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, lossVal)
2020-06-23 16:37:33 +01:00
// testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
// fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\t Accuracy: %.2f\n", epoch, lossVal, testAcc*100)
2020-06-22 16:07:07 +01:00
}
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
fmt.Printf("Loss: \t %.2f\t Accuracy: %.2f\n", lossVal, testAcc*100)
2020-06-23 10:14:08 +01:00
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
2020-06-18 08:14:48 +01:00
}