gotch/example/mnist/nn.go

76 lines
1.7 KiB
Go

package main
import (
"fmt"
"log"
"runtime"
"git.andr3h3nriqu3s.com/andr3/gotch"
"git.andr3h3nriqu3s.com/andr3/gotch/nn"
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
"git.andr3h3nriqu3s.com/andr3/gotch/vision"
)
const (
ImageDimNN int64 = 784
HiddenNodesNN int64 = 128
LabelNN int64 = 10
epochsNN = 200
LrNN = 1e-3
)
var MnistDirNN string = fmt.Sprintf("%s/%s", gotch.CachedDir, "mnist")
var l nn.Linear
func netInit(vs *nn.Path) ts.Module {
n := nn.Seq()
n.Add(nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig()))
n.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
return xs.MustRelu(false)
}))
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, nn.DefaultLinearConfig()))
return n
}
func train(trainX, trainY, testX, testY *ts.Tensor, m ts.Module, opt *nn.Optimizer, epoch int) {
logits := m.Forward(trainX)
loss := logits.CrossEntropyForLogits(trainY)
opt.BackwardStep(loss)
testLogits := m.Forward(testX)
testAccuracy := testLogits.AccuracyForLogits(testY)
accuracy := testAccuracy.Float64Values()[0] * 100
lossVal := loss.Float64Values()[0]
fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, lossVal, accuracy)
runtime.GC()
}
func runNN() {
var ds *vision.Dataset
ds = vision.LoadMNISTDir(MnistDirNN)
vs := nn.NewVarStore(device)
net := netInit(vs.Root())
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
if err != nil {
log.Fatal(err)
}
trainImages := ds.TrainImages.MustTo(device, true)
trainLabels := ds.TrainLabels.MustTo(device, true)
testImages := ds.TestImages.MustTo(device, true)
testLabels := ds.TestLabels.MustTo(device, true)
for epoch := 0; epoch < epochsNN; epoch++ {
train(trainImages, trainLabels, testImages, testLabels, net, opt, epoch)
}
}