WIP(example/mnist): nn

This commit is contained in:
sugarme 2020-06-18 17:14:48 +10:00
parent b5f112e030
commit 9be19702d1
8 changed files with 149 additions and 12 deletions

View File

@ -7,22 +7,27 @@
- Load MNIST data using helper function at `vision` sub-package
## Linear Regression
- Run with `go clean -cache -testcache && go run main.go linear.go`
- Run with `go clean -cache -testcache && go run . -model="linear"`
- Accuraccy should be about **91.68%**.
## Neural Network (NN)
TODO: update
- Run with `go clean -cache -testcache && go run . -model="nn"`
- Accuraccy should be about **TODO: update%**.
## Convolutional Neural Network (CNN)
TODO: update
- Run with `go clean -cache -testcache && go run . -model="cnn"`
- Accuraccy should be about **TODO: update%**.

9
example/mnist/cnn.go Normal file
View File

@ -0,0 +1,9 @@
package main
import (
"fmt"
)
func runCNN() {
fmt.Println("CNN will be implemented soon...!\n")
}

View File

@ -1,7 +1,29 @@
package main
import ()
import (
"flag"
)
var model string
func init() {
flag.StringVar(&model, "model", "linear", "specify a model to run")
}
func main() {
runLinear()
flag.Parse()
switch model {
case "linear":
runLinear()
case "nn":
runNN()
case "cnn":
runCNN()
default:
panic("No specified model to run")
}
}

BIN
example/mnist/mnist Executable file

Binary file not shown.

58
example/mnist/nn.go Normal file
View File

@ -0,0 +1,58 @@
package main
import (
"fmt"
"log"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
"github.com/sugarme/gotch/vision"
)
const (
ImageDimNN int64 = 784
HiddenNodesNN int64 = 128
LabelNN int64 = 10
MnistDirNN string = "../../data/mnist"
epochsNN = 200
batchSizeNN = 256
LrNN = 1e-3
)
func netInit(vs nn.Path) ts.Module {
n := nn.Seq()
n.Add(nn.NewLinear(vs.Sub("layer1"), ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig()))
n.AddFn(func(xs ts.Tensor) ts.Tensor {
return xs.MustRelu()
})
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, nn.DefaultLinearConfig()))
return n
}
func runNN() {
var ds vision.Dataset
ds = vision.LoadMNISTDir(MnistDirNN)
vs := nn.NewVarStore(gotch.CPU)
net := netInit(vs.Root())
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
if err != nil {
log.Fatal(err)
}
for epoch := 0; epoch < epochsNN; epoch++ {
loss := net.Forward(ds.TrainImages).CrossEntropyForLogits(ds.TrainLabels)
opt.BackwardStep(loss)
testAccuracy := net.Forward(ds.TestImages).AccuracyForLogits(ds.TestLabels).MustView([]int64{-1}).MustFloat64Value([]int64{0})
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss, testAccuracy*100)
}
}

View File

@ -247,3 +247,13 @@ func AtgRandperm(ptr *Ctensor, n int64, optionKind int32, optionDevice int32) {
func AtgClamp_(ptr *Ctensor, self Ctensor, min Cscalar, max Cscalar) {
C.atg_clamp_(ptr, self, min, max)
}
// void atg_relu(tensor *, tensor self);
func AtgRelu(ptr *Ctensor, self Ctensor) {
C.atg_relu(ptr, self)
}
// void atg_relu_(tensor *, tensor self);
func AtgRelu_(ptr *Ctensor, self Ctensor) {
C.atg_relu_(ptr, self)
}

View File

@ -736,7 +736,41 @@ func (ts Tensor) Clamp_(min Scalar, max Scalar) {
defer C.free(unsafe.Pointer(ptr))
lib.AtgClamp_(ptr, ts.ctensor, min.cscalar, max.cscalar)
if err = TorchErr(); err != nil {
if err := TorchErr(); err != nil {
log.Fatal(err)
}
}
func (ts Tensor) Relu_() {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgRelu_(ptr, ts.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
}
func (ts Tensor) Relu() (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgRelu(ptr, ts.ctensor)
err = TorchErr()
if err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func (ts Tensor) MustRelu() (retVal Tensor) {
retVal, err := ts.Relu()
if err != nil {
log.Fatal(err)
}
return retVal
}

View File

@ -3,7 +3,6 @@ package vision
// A simple dataset structure shared by various computer vision datasets.
import (
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
)
@ -19,12 +18,12 @@ type Dataset struct {
//=================
// TrainIter creates an iterator of Iter type for train images and labels
func (ds Dataset) TrainIter(batchSize int64) (retVal nn.Iter2) {
return nn.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchSize)
func (ds Dataset) TrainIter(batchSize int64) (retVal ts.Iter2) {
return ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchSize)
}
// TestIter creates an iterator of Iter type for test images and labels
func (ds Dataset) TestIter(batchSize int64) (retVal nn.Iter2) {
return nn.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)
func (ds Dataset) TestIter(batchSize int64) (retVal ts.Iter2) {
return ts.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)
}