// gotch/example/mnist/nn.go
package main
import (
"fmt"
"log"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
"github.com/sugarme/gotch/vision"
)
// Hyper-parameters for the MNIST multi-layer network example.
const (
	ImageDimNN    int64  = 784 // 28x28 input images, flattened
	HiddenNodesNN int64  = 128 // width of the single hidden layer
	LabelNN       int64  = 10  // one output logit per digit class
	MnistDirNN    string = "../../data/mnist"

	epochsNN = 200
	LrNN     = 1e-3 // Adam learning rate
)
// 2020-06-18 16:37:13 +01:00
var l nn.Linear
// 2020-06-18 08:14:48 +01:00
func netInit(vs nn.Path) ts.Module {
n := nn.Seq()
n.Add(nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, *nn.DefaultLinearConfig()))
2020-06-18 16:37:13 +01:00
n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
return xs.MustRelu(false)
}))
2020-06-18 08:14:48 +01:00
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, *nn.DefaultLinearConfig()))
// n.Add(nn.NewLinear(vs, ImageDimNN, LabelNN, nn.DefaultLinearConfig()))
2020-06-18 08:14:48 +01:00
return &n
2020-06-18 08:14:48 +01:00
}
// train runs one full-batch optimization step on the training set, then
// evaluates and prints the loss and test-set accuracy for this epoch.
// All intermediate tensors are dropped to release their C-side memory.
func train(trainX, trainY, testX, testY ts.Tensor, m ts.Module, opt nn.Optimizer, epoch int) {
	logits := m.Forward(trainX)
	loss := logits.CrossEntropyForLogits(trainY)

	// Back-propagate and apply a single optimizer update.
	opt.BackwardStep(loss)

	// BUG FIX: the forward output was never dropped, leaking C memory
	// every epoch. Safe to drop once the backward pass has run.
	logits.MustDrop()

	testLogits := m.Forward(testX)
	testAccuracy := testLogits.AccuracyForLogits(testY)
	testLogits.MustDrop() // likewise, the eval forward output must be freed

	fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy.Values()[0]*100)

	loss.MustDrop()
	testAccuracy.MustDrop()
}
// 2020-06-18 08:14:48 +01:00
func runNN() {
2020-06-18 08:14:48 +01:00
var ds vision.Dataset
ds = vision.LoadMNISTDir(MnistDirNN)
vs := nn.NewVarStore(gotch.CPU)
net := netInit(vs.Root())
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
if err != nil {
log.Fatal(err)
}
for epoch := 0; epoch < epochsNN; epoch++ {
train(ds.TrainImages, ds.TrainLabels, ds.TestImages, ds.TestLabels, net, opt, epoch)
2020-06-18 08:14:48 +01:00
}
}