59 lines
1.2 KiB
Go
59 lines
1.2 KiB
Go
|
package main
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"log"
|
||
|
|
||
|
"github.com/sugarme/gotch"
|
||
|
"github.com/sugarme/gotch/nn"
|
||
|
ts "github.com/sugarme/gotch/tensor"
|
||
|
"github.com/sugarme/gotch/vision"
|
||
|
)
|
||
|
|
||
|
// Network dimensions consumed by netInit.
const (
	ImageDimNN    int64 = 784 // input features fed to the first linear layer
	HiddenNodesNN int64 = 128 // width of the hidden layer
	LabelNN       int64 = 10  // number of output logits
)

// MnistDirNN is the directory runNN loads the MNIST dataset from.
const MnistDirNN string = "../../data/mnist"

// Training hyper-parameters.
const (
	epochsNN    = 200
	batchSizeNN = 256  // not referenced by runNN, which trains full-batch
	LrNN        = 1e-3 // learning rate handed to the Adam optimizer
)
|
||
|
|
||
|
func netInit(vs nn.Path) ts.Module {
|
||
|
n := nn.Seq()
|
||
|
|
||
|
n.Add(nn.NewLinear(vs.Sub("layer1"), ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig()))
|
||
|
n.AddFn(func(xs ts.Tensor) ts.Tensor {
|
||
|
return xs.MustRelu()
|
||
|
})
|
||
|
|
||
|
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, nn.DefaultLinearConfig()))
|
||
|
|
||
|
return n
|
||
|
}
|
||
|
|
||
|
func runNN() {
|
||
|
var ds vision.Dataset
|
||
|
ds = vision.LoadMNISTDir(MnistDirNN)
|
||
|
|
||
|
vs := nn.NewVarStore(gotch.CPU)
|
||
|
net := netInit(vs.Root())
|
||
|
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
|
||
|
if err != nil {
|
||
|
log.Fatal(err)
|
||
|
}
|
||
|
|
||
|
for epoch := 0; epoch < epochsNN; epoch++ {
|
||
|
loss := net.Forward(ds.TrainImages).CrossEntropyForLogits(ds.TrainLabels)
|
||
|
opt.BackwardStep(loss)
|
||
|
|
||
|
testAccuracy := net.Forward(ds.TestImages).AccuracyForLogits(ds.TestLabels).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||
|
|
||
|
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss, testAccuracy*100)
|
||
|
}
|
||
|
|
||
|
}
|