2020-06-18 08:14:48 +01:00
|
|
|
package main
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
2020-06-22 16:07:07 +01:00
|
|
|
"log"
|
2020-06-23 10:14:08 +01:00
|
|
|
"time"
|
2020-06-22 16:07:07 +01:00
|
|
|
|
|
|
|
"github.com/sugarme/gotch"
|
|
|
|
"github.com/sugarme/gotch/nn"
|
|
|
|
ts "github.com/sugarme/gotch/tensor"
|
|
|
|
"github.com/sugarme/gotch/vision"
|
|
|
|
)
|
|
|
|
|
|
|
|
const (
	// MnistDirCNN is the directory containing the raw MNIST data files.
	MnistDirCNN string = "../../data/mnist"

	// epochsCNN is the number of training epochs for both CNN runners.
	epochsCNN = 100

	// batchCNN is the mini-batch size used by the Iter2-based loop (runCNN2).
	batchCNN = 256

	// batchSize is the mini-batch size used by runCNN1's manual batching.
	// NOTE(review): same value as batchCNN — these could likely be one constant.
	batchSize = 256

	// LrCNN is the Adam learning rate for the CNN model.
	LrCNN = 1e-4
)
|
|
|
|
|
2020-06-22 16:07:07 +01:00
|
|
|
// Net is a small convolutional network for MNIST classification:
// two conv layers followed by two fully-connected layers.
// Layer dimensions below are those set up in newNet.
type Net struct {
	conv1 nn.Conv2D // 1 -> 32 channels, 5x5 kernel
	conv2 nn.Conv2D // 32 -> 64 channels, 5x5 kernel
	fc1   nn.Linear // 1024 -> 1024
	fc2   nn.Linear // 1024 -> 10 (one logit per digit class)
}
|
|
|
|
|
2020-07-07 01:40:05 +01:00
|
|
|
func newNet(vs nn.Path) Net {
|
2020-06-22 16:07:07 +01:00
|
|
|
conv1 := nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig())
|
|
|
|
conv2 := nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig())
|
2020-07-07 01:40:05 +01:00
|
|
|
fc1 := nn.NewLinear(vs, 1024, 1024, nn.DefaultLinearConfig())
|
|
|
|
fc2 := nn.NewLinear(vs, 1024, 10, nn.DefaultLinearConfig())
|
2020-06-22 16:07:07 +01:00
|
|
|
|
|
|
|
return Net{
|
|
|
|
conv1,
|
|
|
|
conv2,
|
2020-07-03 02:20:52 +01:00
|
|
|
fc1,
|
|
|
|
fc2}
|
2020-06-22 16:07:07 +01:00
|
|
|
}
|
|
|
|
|
2020-06-23 10:14:08 +01:00
|
|
|
// ForwardT runs the network's forward pass: reshape to NCHW, then
// conv1 -> maxpool -> conv2 -> maxpool, flatten to 1024 features, then
// fc1 -> relu -> dropout -> fc2, returning the class logits.
// `train` is forwarded to dropout, so dropout is active only in training.
//
// Memory note: some intermediates are freed via defer MustDrop, others pass
// a `true` flag to the next op — presumably that flag hands ownership to the
// callee, which frees the input (TODO(review): confirm against gotch's
// tensor-ownership conventions; the commented-out MustDrop calls suggest
// this was tuned by hand).
func (n Net) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
	// Reshape input to (batch, 1, 28, 28); `false` leaves xs alive for the caller.
	outView1 := xs.MustView([]int64{-1, 1, 28, 28}, false)
	defer outView1.MustDrop()

	outC1 := outView1.Apply(n.conv1)
	// defer outC1.MustDrop()

	// 2x2 max-pool; second arg `true` presumably consumes outC1 — confirm.
	outMP1 := outC1.MaxPool2DDefault(2, true)
	defer outMP1.MustDrop()

	outC2 := outMP1.Apply(n.conv2)
	// defer outC2.MustDrop()

	outMP2 := outC2.MaxPool2DDefault(2, true)
	// defer outMP2.MustDrop()

	// Flatten to (batch, 1024); `true` presumably consumes outMP2 — confirm.
	outView2 := outMP2.MustView([]int64{-1, 1024}, true)
	defer outView2.MustDrop()

	outFC1 := outView2.Apply(&n.fc1)
	// defer outFC1.MustDrop()

	outRelu := outFC1.MustRelu(true)
	defer outRelu.MustDrop()

	// outRelu.Dropout_(0.5, train)
	// Dropout with p=0.5, applied only when train == true.
	outDropout := ts.MustDropout(outRelu, 0.5, train)
	defer outDropout.MustDrop()

	// Final linear layer produces the logits returned to the caller.
	return outDropout.Apply(&n.fc2)
}
|
|
|
|
|
2020-06-23 16:37:33 +01:00
|
|
|
func runCNN1() {
|
2020-06-22 16:07:07 +01:00
|
|
|
|
|
|
|
var ds vision.Dataset
|
|
|
|
ds = vision.LoadMNISTDir(MnistDirNN)
|
2020-06-23 16:37:33 +01:00
|
|
|
testImages := ds.TestImages
|
|
|
|
testLabels := ds.TestLabels
|
|
|
|
|
2020-06-22 16:07:07 +01:00
|
|
|
cuda := gotch.CudaBuilder(0)
|
|
|
|
vs := nn.NewVarStore(cuda.CudaIfAvailable())
|
2020-06-23 10:14:08 +01:00
|
|
|
// vs := nn.NewVarStore(gotch.CPU)
|
2020-07-07 01:40:05 +01:00
|
|
|
net := newNet(vs.Root())
|
2020-06-24 03:47:10 +01:00
|
|
|
opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
|
2020-06-22 16:07:07 +01:00
|
|
|
if err != nil {
|
|
|
|
log.Fatal(err)
|
|
|
|
}
|
|
|
|
|
2020-06-23 10:14:08 +01:00
|
|
|
startTime := time.Now()
|
|
|
|
|
2020-06-22 16:07:07 +01:00
|
|
|
for epoch := 0; epoch < epochsCNN; epoch++ {
|
2020-06-23 10:14:08 +01:00
|
|
|
|
|
|
|
totalSize := ds.TrainImages.MustSize()[0]
|
|
|
|
samples := int(totalSize)
|
2020-06-23 16:37:33 +01:00
|
|
|
index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
|
|
|
|
imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
|
|
|
|
labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
|
2020-06-23 10:14:08 +01:00
|
|
|
|
|
|
|
batches := samples / batchSize
|
|
|
|
batchIndex := 0
|
|
|
|
var epocLoss ts.Tensor
|
|
|
|
// var loss ts.Tensor
|
|
|
|
for i := 0; i < batches; i++ {
|
|
|
|
start := batchIndex * batchSize
|
|
|
|
size := batchSize
|
|
|
|
if samples-start < batchSize {
|
|
|
|
// size = samples - start
|
2020-06-22 16:07:07 +01:00
|
|
|
break
|
|
|
|
}
|
2020-06-23 10:14:08 +01:00
|
|
|
batchIndex += 1
|
|
|
|
|
|
|
|
// Indexing
|
|
|
|
narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
|
2020-06-23 16:37:33 +01:00
|
|
|
// bImages := ds.TrainImages.Idx(narrowIndex)
|
|
|
|
// bLabels := ds.TrainLabels.Idx(narrowIndex)
|
|
|
|
bImages := imagesTs.Idx(narrowIndex)
|
|
|
|
bLabels := labelsTs.Idx(narrowIndex)
|
2020-06-23 10:14:08 +01:00
|
|
|
|
|
|
|
bImages = bImages.MustTo(vs.Device(), true)
|
|
|
|
bLabels = bLabels.MustTo(vs.Device(), true)
|
|
|
|
|
|
|
|
logits := net.ForwardT(bImages, true)
|
|
|
|
loss := logits.CrossEntropyForLogits(bLabels)
|
2020-06-22 16:07:07 +01:00
|
|
|
|
2020-07-10 22:10:50 +01:00
|
|
|
// loss = loss.MustSetRequiresGrad(true)
|
|
|
|
|
2020-06-22 16:07:07 +01:00
|
|
|
opt.BackwardStep(loss)
|
2020-06-23 10:14:08 +01:00
|
|
|
|
|
|
|
epocLoss = loss.MustShallowClone()
|
|
|
|
epocLoss.Detach_()
|
|
|
|
|
|
|
|
// fmt.Printf("completed \t %v batches\t %.2f\n", i, loss.Values()[0])
|
|
|
|
|
|
|
|
bImages.MustDrop()
|
|
|
|
bLabels.MustDrop()
|
|
|
|
// logits.MustDrop()
|
2020-06-24 03:47:10 +01:00
|
|
|
// loss.MustDrop()
|
2020-06-22 16:07:07 +01:00
|
|
|
}
|
|
|
|
|
2020-07-10 22:10:50 +01:00
|
|
|
vs.Freeze()
|
|
|
|
testAccuracy := batchAccuracyForLogits(net, testImages, testLabels, vs.Device(), 1024)
|
|
|
|
vs.Unfreeze()
|
|
|
|
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Values()[0], testAccuracy*100.0)
|
2020-06-23 10:14:08 +01:00
|
|
|
|
2020-07-10 22:10:50 +01:00
|
|
|
// fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, epocLoss.Values()[0])
|
2020-06-23 10:14:08 +01:00
|
|
|
epocLoss.MustDrop()
|
2020-06-23 16:37:33 +01:00
|
|
|
imagesTs.MustDrop()
|
|
|
|
labelsTs.MustDrop()
|
|
|
|
}
|
|
|
|
|
|
|
|
testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
|
|
|
|
fmt.Printf("Test accuracy: %.2f%%\n", testAccuracy*100)
|
|
|
|
|
|
|
|
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
|
|
|
}
|
|
|
|
|
|
|
|
func runCNN2() {
|
|
|
|
|
|
|
|
var ds vision.Dataset
|
|
|
|
ds = vision.LoadMNISTDir(MnistDirNN)
|
|
|
|
|
|
|
|
cuda := gotch.CudaBuilder(0)
|
|
|
|
vs := nn.NewVarStore(cuda.CudaIfAvailable())
|
2020-07-07 01:40:05 +01:00
|
|
|
net := newNet(vs.Root())
|
2020-06-23 16:37:33 +01:00
|
|
|
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
startTime := time.Now()
|
|
|
|
|
|
|
|
var lossVal float64
|
|
|
|
for epoch := 0; epoch < epochsCNN; epoch++ {
|
|
|
|
|
|
|
|
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchCNN)
|
|
|
|
// iter.Shuffle()
|
|
|
|
|
|
|
|
for {
|
|
|
|
item, ok := iter.Next()
|
|
|
|
if !ok {
|
|
|
|
break
|
|
|
|
}
|
|
|
|
|
|
|
|
bImages := item.Data.MustTo(vs.Device(), true)
|
|
|
|
bLabels := item.Label.MustTo(vs.Device(), true)
|
|
|
|
|
2020-06-24 03:47:10 +01:00
|
|
|
// _ = ts.MustGradSetEnabled(true)
|
2020-06-23 16:37:33 +01:00
|
|
|
|
|
|
|
logits := net.ForwardT(bImages, true)
|
|
|
|
loss := logits.CrossEntropyForLogits(bLabels)
|
|
|
|
|
|
|
|
opt.BackwardStep(loss)
|
|
|
|
|
|
|
|
lossVal = loss.Values()[0]
|
|
|
|
|
|
|
|
bImages.MustDrop()
|
|
|
|
bLabels.MustDrop()
|
|
|
|
loss.MustDrop()
|
|
|
|
}
|
|
|
|
|
2020-07-10 22:10:50 +01:00
|
|
|
// fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, lossVal)
|
2020-06-23 16:37:33 +01:00
|
|
|
|
2020-07-10 22:10:50 +01:00
|
|
|
vs.Freeze()
|
|
|
|
testAcc := batchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
|
|
|
|
vs.Unfreeze()
|
|
|
|
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\t Accuracy: %.2f\n", epoch, lossVal, testAcc*100.0)
|
2020-06-22 16:07:07 +01:00
|
|
|
}
|
|
|
|
|
2020-06-24 03:47:10 +01:00
|
|
|
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
|
|
|
|
fmt.Printf("Loss: \t %.2f\t Accuracy: %.2f\n", lossVal, testAcc*100)
|
2020-06-23 10:14:08 +01:00
|
|
|
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
2020-06-18 08:14:48 +01:00
|
|
|
}
|
2020-07-10 22:10:50 +01:00
|
|
|
|
|
|
|
// batchAccuracyForLogits computes the classification accuracy of module m
// over the dataset (xs = inputs, ys = labels), streaming it in batches of
// batchSize moved to device d. It returns the sample-weighted mean
// accuracy (weighting handles a short final batch correctly).
func batchAccuracyForLogits(m ts.ModuleT, xs, ys ts.Tensor, d gotch.Device, batchSize int) (retVal float64) {
	var (
		sumAccuracy float64 = 0.0 // per-batch accuracy weighted by batch size, summed
		sampleCount float64 = 0.0 // total number of samples evaluated
	)

	iter2 := ts.MustNewIter2(xs, ys, int64(batchSize))
	for {
		item, ok := iter2.Next()
		if !ok {
			break
		}

		// Actual batch size — the last batch may be smaller than batchSize.
		size := float64(item.Data.MustSize()[0])
		bImages := item.Data.MustTo(d, true)
		bLabels := item.Label.MustTo(d, true)

		// Forward pass in eval mode: train=false disables dropout in ForwardT.
		logits := m.ForwardT(bImages, false)
		acc := logits.AccuracyForLogits(bLabels)
		sumAccuracy += acc.Values()[0] * size
		sampleCount += size

		bImages.MustDrop()
		bLabels.MustDrop()
		acc.MustDrop()
		// NOTE(review): logits is not explicitly dropped — presumably
		// AccuracyForLogits consumes its receiver; confirm against gotch.
	}

	return sumAccuracy / sampleCount
}
|