// gotch/example/mnist/cnn.go
//
// MNIST CNN training example (two conv + two fully-connected layers).
package main

import (
	"fmt"
	"log"
	"runtime"
	"sync"
	"time"

	"git.andr3h3nriqu3s.com/andr3/gotch"
	"git.andr3h3nriqu3s.com/andr3/gotch/nn"
	"git.andr3h3nriqu3s.com/andr3/gotch/ts"
	"git.andr3h3nriqu3s.com/andr3/gotch/vision"
)
// Hyper-parameters for the CNN MNIST example.
const (
	// MnistDirCNN is the directory holding the raw MNIST data files.
	MnistDirCNN string = "../../data/mnist"

	// epochsCNN is the number of full passes over the training set.
	epochsCNN = 30

	batchCNN = 256

	// batchSize is the training mini-batch size.
	// batchSize = 256
	batchSize = 32

	// LrCNN is the Adam learning rate.
	LrCNN = 3 * 1e-4
)
// mu is unused within this file; presumably shared with sibling example
// files in package main — TODO confirm before removing.
var mu sync.Mutex
type Net struct {
conv1 *nn.Conv2D
conv2 *nn.Conv2D
fc1 *nn.Linear
fc2 *nn.Linear
2020-06-22 16:07:07 +01:00
}
func newNet(vs *nn.Path) *Net {
2020-06-22 16:07:07 +01:00
conv1 := nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig())
conv2 := nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig())
fc1 := nn.NewLinear(vs, 1024, 1024, nn.DefaultLinearConfig())
fc2 := nn.NewLinear(vs, 1024, 10, nn.DefaultLinearConfig())
2020-06-22 16:07:07 +01:00
return &Net{
2020-06-22 16:07:07 +01:00
conv1,
conv2,
fc1,
fc2}
2020-06-22 16:07:07 +01:00
}
func (n *Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
2020-06-23 10:14:08 +01:00
outView1 := xs.MustView([]int64{-1, 1, 28, 28}, false)
outC1 := outView1.Apply(n.conv1)
outMP1 := outC1.MaxPool2DDefault(2, true)
outC2 := outMP1.Apply(n.conv2)
outMP2 := outC2.MaxPool2DDefault(2, true)
outView2 := outMP2.MustView([]int64{-1, 1024}, true)
outFC1 := outView2.Apply(n.fc1)
2023-07-04 14:03:49 +01:00
outRelu := outFC1.MustRelu(false)
2020-06-23 10:14:08 +01:00
outDropout := ts.MustDropout(outRelu, 0.5, train)
return outDropout.Apply(n.fc2)
2020-06-22 16:07:07 +01:00
}
2020-06-23 16:37:33 +01:00
func runCNN1() {
var ds *vision.Dataset
2020-06-22 16:07:07 +01:00
ds = vision.LoadMNISTDir(MnistDirNN)
2023-07-04 14:03:49 +01:00
trainImages := ds.TrainImages.MustTo(device, false) //[60000, 784]
trainLabels := ds.TrainLabels.MustTo(device, false) // [60000, 784]
testImages := ds.TestImages.MustTo(device, false) // [10000, 784]
testLabels := ds.TestLabels.MustTo(device, false) // [10000, 784]
2022-03-12 04:47:15 +00:00
fmt.Printf("testImages: %v\n", testImages.MustSize())
fmt.Printf("testLabels: %v\n", testLabels.MustSize())
2020-06-23 16:37:33 +01:00
vs := nn.NewVarStore(device)
net := newNet(vs.Root())
opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
// opt, err := nn.DefaultSGDConfig().Build(vs, LrCNN)
2020-06-22 16:07:07 +01:00
if err != nil {
log.Fatal(err)
}
var bestAccuracy float64 = 0.0
2020-06-23 10:14:08 +01:00
startTime := time.Now()
2020-06-22 16:07:07 +01:00
for epoch := 0; epoch < epochsCNN; epoch++ {
2020-06-23 10:14:08 +01:00
totalSize := ds.TrainImages.MustSize()[0]
samples := int(totalSize)
2022-03-12 04:47:15 +00:00
// Shuffling
2023-07-04 14:03:49 +01:00
index := ts.MustRandperm(int64(totalSize), gotch.Int64, device)
imagesTs := trainImages.MustIndexSelect(0, index, false)
labelsTs := trainLabels.MustIndexSelect(0, index, false)
2020-06-23 10:14:08 +01:00
batches := samples / batchSize
batchIndex := 0
2022-03-12 04:47:15 +00:00
var epocLoss float64
2020-06-23 10:14:08 +01:00
for i := 0; i < batches; i++ {
start := batchIndex * batchSize
size := batchSize
if samples-start < batchSize {
2020-06-22 16:07:07 +01:00
break
}
2020-06-23 10:14:08 +01:00
batchIndex += 1
// Indexing
2022-03-12 04:47:15 +00:00
bImages := imagesTs.MustNarrow(0, int64(start), int64(size), false)
2020-06-23 10:14:08 +01:00
logits := net.ForwardT(bImages, true)
2023-07-04 14:03:49 +01:00
bLabels := labelsTs.MustNarrow(0, int64(start), int64(size), false)
2020-06-23 10:14:08 +01:00
loss := logits.CrossEntropyForLogits(bLabels)
2020-06-22 16:07:07 +01:00
2022-03-12 04:47:15 +00:00
loss = loss.MustSetRequiresGrad(true, true)
2020-06-22 16:07:07 +01:00
opt.BackwardStep(loss)
2022-03-12 04:47:15 +00:00
epocLoss = loss.Float64Values()[0]
2023-07-04 14:03:49 +01:00
runtime.GC()
2020-06-22 16:07:07 +01:00
}
2022-03-12 04:47:15 +00:00
ts.NoGrad(func() {
2023-07-04 14:03:49 +01:00
fmt.Printf("Start eval...")
testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1000)
2022-03-12 04:47:15 +00:00
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss, testAccuracy*100.0)
if testAccuracy > bestAccuracy {
bestAccuracy = testAccuracy
}
2023-07-05 15:20:11 +01:00
})
2020-06-23 16:37:33 +01:00
}
fmt.Printf("Best test accuracy: %.2f%%\n", bestAccuracy*100.0)
2020-06-23 10:14:08 +01:00
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
2023-07-04 14:03:49 +01:00
ts.CleanUp()
2020-06-18 08:14:48 +01:00
}