2020-06-18 08:14:48 +01:00
|
|
|
package main
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
"log"
|
|
|
|
|
|
|
|
"github.com/sugarme/gotch"
|
|
|
|
"github.com/sugarme/gotch/nn"
|
2022-03-12 07:20:20 +00:00
|
|
|
"github.com/sugarme/gotch/ts"
|
2020-06-18 08:14:48 +01:00
|
|
|
"github.com/sugarme/gotch/vision"
|
|
|
|
)
|
|
|
|
|
|
|
|
const (
	// ImageDimNN is the flattened size of one MNIST image (28x28 = 784).
	ImageDimNN int64 = 784
	// HiddenNodesNN is the width of the single hidden layer.
	HiddenNodesNN int64 = 128
	// LabelNN is the number of output classes (digits 0-9).
	LabelNN int64 = 10
	// MnistDirNN is the relative path to the MNIST data files.
	MnistDirNN string = "../../data/mnist"

	// epochsNN is the number of full-batch training iterations.
	epochsNN = 200

	// LrNN is the learning rate passed to the Adam optimizer.
	LrNN = 1e-3
)
|
|
|
|
|
2020-06-18 16:37:13 +01:00
|
|
|
// l is a package-level Linear layer.
// NOTE(review): it is not referenced anywhere in this file — confirm whether
// another file in package main uses it; otherwise it looks like leftover code.
var l nn.Linear
|
|
|
|
|
2020-10-31 11:11:50 +00:00
|
|
|
func netInit(vs *nn.Path) ts.Module {
|
2020-06-18 08:14:48 +01:00
|
|
|
n := nn.Seq()
|
|
|
|
|
2020-07-03 02:20:52 +01:00
|
|
|
n.Add(nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig()))
|
2020-06-18 16:37:13 +01:00
|
|
|
|
2020-10-31 11:11:50 +00:00
|
|
|
n.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
|
2020-06-23 06:21:16 +01:00
|
|
|
return xs.MustRelu(false)
|
|
|
|
}))
|
2020-06-18 08:14:48 +01:00
|
|
|
|
2020-07-03 02:20:52 +01:00
|
|
|
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, nn.DefaultLinearConfig()))
|
2020-06-18 08:14:48 +01:00
|
|
|
|
2020-10-31 11:11:50 +00:00
|
|
|
return n
|
2020-06-18 08:14:48 +01:00
|
|
|
}
|
|
|
|
|
2020-10-31 11:11:50 +00:00
|
|
|
// train runs one full-batch training step: forward pass over the whole
// training set, cross-entropy loss, a single optimizer update, then an
// evaluation pass over the test set. It prints the loss and test accuracy
// for the given epoch. Tensor lifetimes are managed manually via MustDrop,
// so statement order here matters.
func train(trainX, trainY, testX, testY *ts.Tensor, m ts.Module, opt *nn.Optimizer, epoch int) {
	logits := m.Forward(trainX)

	loss := logits.CrossEntropyForLogits(trainY)

	// Backward pass and parameter update in one call.
	opt.BackwardStep(loss)

	// Evaluation on the held-out test set.
	testLogits := m.Forward(testX)

	testAccuracy := testLogits.AccuracyForLogits(testY)

	// Float64Values copies the scalar out, so the tensor can be dropped next.
	accuracy := testAccuracy.Float64Values()[0] * 100

	testAccuracy.MustDrop()

	lossVal := loss.Float64Values()[0]

	loss.MustDrop()

	// NOTE(review): logits and testLogits are never MustDrop'ed here —
	// confirm whether CrossEntropyForLogits/AccuracyForLogits consume their
	// receivers; if not, this function may leak C-allocated tensors per epoch.
	fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, lossVal, accuracy)
}
|
|
|
|
|
2020-06-18 08:14:48 +01:00
|
|
|
func runNN() {
|
2020-06-21 01:57:29 +01:00
|
|
|
|
2020-10-31 11:11:50 +00:00
|
|
|
var ds *vision.Dataset
|
2020-06-18 08:14:48 +01:00
|
|
|
ds = vision.LoadMNISTDir(MnistDirNN)
|
|
|
|
vs := nn.NewVarStore(gotch.CPU)
|
|
|
|
net := netInit(vs.Root())
|
|
|
|
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
for epoch := 0; epoch < epochsNN; epoch++ {
|
2020-06-21 01:57:29 +01:00
|
|
|
train(ds.TrainImages, ds.TrainLabels, ds.TestImages, ds.TestLabels, net, opt, epoch)
|
2020-06-18 08:14:48 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
}
|