gotch/example/mnist/nn.go

76 lines
1.7 KiB
Go
Raw Normal View History

2020-06-18 08:14:48 +01:00
package main
import (
"fmt"
"log"
2023-07-04 14:03:49 +01:00
"runtime"
2020-06-18 08:14:48 +01:00
2024-04-21 15:15:00 +01:00
"git.andr3h3nriqu3s.com/andr3/gotch"
"git.andr3h3nriqu3s.com/andr3/gotch/nn"
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
"git.andr3h3nriqu3s.com/andr3/gotch/vision"
2020-06-18 08:14:48 +01:00
)
// Layer sizes and training hyperparameters for the MNIST multilayer perceptron.
const (
	// ImageDimNN is the input feature width of the first Linear layer.
	// 784 = 28*28, presumably one flattened 28x28 MNIST image.
	ImageDimNN int64 = 28 * 28
	// HiddenNodesNN is the width of the hidden layer between the two Linear layers.
	HiddenNodesNN int64 = 128
	// LabelNN is the output width of the final Linear layer (number of classes).
	LabelNN int64 = 10

	// epochsNN is how many times runNN invokes train.
	epochsNN = 200
	// LrNN is the learning rate given to the Adam optimizer built in runNN.
	LrNN = 0.001
)
var MnistDirNN string = fmt.Sprintf("%s/%s", gotch.CachedDir, "mnist")
2020-06-18 16:37:13 +01:00
var l nn.Linear
func netInit(vs *nn.Path) ts.Module {
2020-06-18 08:14:48 +01:00
n := nn.Seq()
n.Add(nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig()))
2020-06-18 16:37:13 +01:00
n.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
return xs.MustRelu(false)
}))
2020-06-18 08:14:48 +01:00
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, nn.DefaultLinearConfig()))
2020-06-18 08:14:48 +01:00
return n
2020-06-18 08:14:48 +01:00
}
// train runs one optimisation step of m over trainX/trainY as a single batch,
// then prints that step's loss and m's accuracy on testX/testY.
//
// The loss is CrossEntropyForLogits of m's output against trainY, and
// opt.BackwardStep applies the update. Accuracy comes from AccuracyForLogits
// and is printed as a percentage. epoch is only used for the printed label.
func train(trainX, trainY, testX, testY *ts.Tensor, m ts.Module, opt *nn.Optimizer, epoch int) {
	// Training step.
	trainOut := m.Forward(trainX)
	trainLoss := trainOut.CrossEntropyForLogits(trainY)
	opt.BackwardStep(trainLoss)

	// Evaluation on the held-out set (after the update above).
	evalOut := m.Forward(testX)
	evalAcc := evalOut.AccuracyForLogits(testY)
	pct := evalAcc.Float64Values()[0] * 100
	lossScalar := trainLoss.Float64Values()[0]

	fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, lossScalar, pct)

	// NOTE(review): explicit GC presumably lets finalizers release the tensors
	// created above before the next epoch — confirm against gotch's memory model.
	runtime.GC()
}
2020-06-18 08:14:48 +01:00
func runNN() {
var ds *vision.Dataset
2020-06-18 08:14:48 +01:00
ds = vision.LoadMNISTDir(MnistDirNN)
2023-07-04 14:03:49 +01:00
vs := nn.NewVarStore(device)
2020-06-18 08:14:48 +01:00
net := netInit(vs.Root())
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
if err != nil {
log.Fatal(err)
}
2023-07-04 14:03:49 +01:00
trainImages := ds.TrainImages.MustTo(device, true)
trainLabels := ds.TrainLabels.MustTo(device, true)
testImages := ds.TestImages.MustTo(device, true)
testLabels := ds.TestLabels.MustTo(device, true)
2020-06-18 08:14:48 +01:00
for epoch := 0; epoch < epochsNN; epoch++ {
2023-07-04 14:03:49 +01:00
train(trainImages, trainLabels, testImages, testLabels, net, opt, epoch)
2020-06-18 08:14:48 +01:00
}
}