diff --git a/example/mnist/README.md b/example/mnist/README.md
index 90cce5c..d02603d 100644
--- a/example/mnist/README.md
+++ b/example/mnist/README.md
@@ -7,22 +7,27 @@
 - Load MNIST data using helper function at `vision` sub-package
+
 ## Linear Regression
 
-- Run with `go clean -cache -testcache && go run main.go linear.go`
+- Run with `go clean -cache -testcache && go run . -model="linear"`
+
 - Accuracy should be about **91.68%**.
 
 ## Neural Network (NN)
 
-TODO: update
+- Run with `go clean -cache -testcache && go run . -model="nn"`
+
+- Accuracy should be about **TODO: update%**.
 
 ## Convolutional Neural Network (CNN)
 
-TODO: update
-
+- Run with `go clean -cache -testcache && go run . -model="cnn"`
+
+- Accuracy should be about **TODO: update%**.
diff --git a/example/mnist/cnn.go b/example/mnist/cnn.go
new file mode 100644
index 0000000..54915f3
--- /dev/null
+++ b/example/mnist/cnn.go
@@ -0,0 +1,9 @@
+package main
+
+import (
+	"fmt"
+)
+
+func runCNN() {
+	fmt.Println("CNN will be implemented soon...!\n")
+}
diff --git a/example/mnist/main.go b/example/mnist/main.go
index 15a945a..dd27f15 100644
--- a/example/mnist/main.go
+++ b/example/mnist/main.go
@@ -1,7 +1,29 @@
 package main
 
-import ()
+import (
+	"flag"
+)
+
+var model string
+
+func init() {
+	flag.StringVar(&model, "model", "linear", "specify a model to run")
+
+}
 
 func main() {
-	runLinear()
+
+	flag.Parse()
+
+	switch model {
+	case "linear":
+		runLinear()
+	case "nn":
+		runNN()
+	case "cnn":
+		runCNN()
+	default:
+		panic("No specified model to run")
+	}
+
 }
diff --git a/example/mnist/mnist b/example/mnist/mnist
new file mode 100755
index 0000000..cd4b4ee
Binary files /dev/null and b/example/mnist/mnist differ
diff --git a/example/mnist/nn.go b/example/mnist/nn.go
new file mode 100644
index 0000000..bf50e65
--- /dev/null
+++ b/example/mnist/nn.go
@@ -0,0 +1,58 @@
+package main
+
+import (
+	"fmt"
+	"log"
+
+	"github.com/sugarme/gotch"
+	"github.com/sugarme/gotch/nn"
+	ts "github.com/sugarme/gotch/tensor"
+	"github.com/sugarme/gotch/vision"
+)
+
+const (
+	ImageDimNN    int64  = 784
+	HiddenNodesNN int64  = 128
+	LabelNN       int64  = 10
+	MnistDirNN    string = "../../data/mnist"
+
+	epochsNN    = 200
+	batchSizeNN = 256
+
+	LrNN = 1e-3
+)
+
+func netInit(vs nn.Path) ts.Module {
+	n := nn.Seq()
+
+	n.Add(nn.NewLinear(vs.Sub("layer1"), ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig()))
+	n.AddFn(func(xs ts.Tensor) ts.Tensor {
+		return xs.MustRelu()
+	})
+
+	n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, nn.DefaultLinearConfig()))
+
+	return n
+}
+
+func runNN() {
+	var ds vision.Dataset
+	ds = vision.LoadMNISTDir(MnistDirNN)
+
+	vs := nn.NewVarStore(gotch.CPU)
+	net := netInit(vs.Root())
+	opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	for epoch := 0; epoch < epochsNN; epoch++ {
+		loss := net.Forward(ds.TrainImages).CrossEntropyForLogits(ds.TrainLabels)
+		opt.BackwardStep(loss)
+
+		testAccuracy := net.Forward(ds.TestImages).AccuracyForLogits(ds.TestLabels).MustView([]int64{-1}).MustFloat64Value([]int64{0})
+
+		fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss, testAccuracy*100)
+	}
+
+}
diff --git a/libtch/c-generated-sample.go b/libtch/c-generated-sample.go
index dd1f039..578f7cd 100644
--- a/libtch/c-generated-sample.go
+++ b/libtch/c-generated-sample.go
@@ -247,3 +247,13 @@ func AtgRandperm(ptr *Ctensor, n int64, optionKind int32, optionDevice int32) {
 func AtgClamp_(ptr *Ctensor, self Ctensor, min Cscalar, max Cscalar) {
 	C.atg_clamp_(ptr, self, min, max)
 }
+
+// void atg_relu(tensor *, tensor self);
+func AtgRelu(ptr *Ctensor, self Ctensor) {
+	C.atg_relu(ptr, self)
+}
+
+// void atg_relu_(tensor *, tensor self);
+func AtgRelu_(ptr *Ctensor, self Ctensor) {
+	C.atg_relu_(ptr, self)
+}
diff --git a/tensor/tensor-generated-sample.go b/tensor/tensor-generated-sample.go
index c587e5f..64849ee 100644
--- a/tensor/tensor-generated-sample.go
+++ b/tensor/tensor-generated-sample.go
@@ -736,7 +736,41 @@ func (ts Tensor) Clamp_(min Scalar, max Scalar) {
 	defer C.free(unsafe.Pointer(ptr))
 
 	lib.AtgClamp_(ptr, ts.ctensor, min.cscalar, max.cscalar)
-	if err = TorchErr(); err != nil {
+	if err := TorchErr(); err != nil {
 		log.Fatal(err)
 	}
 }
+
+func (ts Tensor) Relu_() {
+	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	defer C.free(unsafe.Pointer(ptr))
+
+	lib.AtgRelu_(ptr, ts.ctensor)
+	if err := TorchErr(); err != nil {
+		log.Fatal(err)
+	}
+}
+
+func (ts Tensor) Relu() (retVal Tensor, err error) {
+	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	defer C.free(unsafe.Pointer(ptr))
+
+	lib.AtgRelu(ptr, ts.ctensor)
+	err = TorchErr()
+	if err != nil {
+		return retVal, err
+	}
+
+	retVal = Tensor{ctensor: *ptr}
+
+	return retVal, nil
+}
+
+func (ts Tensor) MustRelu() (retVal Tensor) {
+	retVal, err := ts.Relu()
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
diff --git a/vision/dataset.go b/vision/dataset.go
index 7ccd8b6..14d50f0 100644
--- a/vision/dataset.go
+++ b/vision/dataset.go
@@ -3,7 +3,6 @@ package vision
 // A simple dataset structure shared by various computer vision datasets.
 
 import (
-	"github.com/sugarme/gotch/nn"
 	ts "github.com/sugarme/gotch/tensor"
 )
 
@@ -19,12 +18,12 @@ type Dataset struct {
 //=================
 
 // TrainIter creates an iterator of Iter type for train images and labels
-func (ds Dataset) TrainIter(batchSize int64) (retVal nn.Iter2) {
-	return nn.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchSize)
+func (ds Dataset) TrainIter(batchSize int64) (retVal ts.Iter2) {
+	return ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchSize)
 }
 
 // TestIter creates an iterator of Iter type for test images and labels
-func (ds Dataset) TestIter(batchSize int64) (retVal nn.Iter2) {
-	return nn.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)
+func (ds Dataset) TestIter(batchSize int64) (retVal ts.Iter2) {
+	return ts.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)
 }
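
For reference, below is a minimal usage sketch of the pieces this patch introduces: the new `Relu`/`MustRelu` tensor methods and the `Dataset` iterators that now live in the `tensor` package instead of `nn`. It is not part of the patch; the standalone `main`, the hard-coded data path, and the batch size of 256 are illustrative assumptions mirroring the constants in `nn.go` above.

package main

import (
	"fmt"
	"log"

	"github.com/sugarme/gotch/vision"
)

func main() {
	// Load MNIST the same way runNN does; the path mirrors MnistDirNN.
	ds := vision.LoadMNISTDir("../../data/mnist")

	// Relu returns (Tensor, error) for callers that want to handle
	// libtorch errors themselves...
	if _, err := ds.TrainImages.Relu(); err != nil {
		log.Fatal(err)
	}

	// ...while MustRelu follows the package's Must* convention and calls
	// log.Fatal internally, as netInit uses it inside AddFn.
	_ = ds.TestImages.MustRelu()

	// TrainIter now returns ts.Iter2 rather than nn.Iter2, so batching no
	// longer requires importing the nn package.
	_ = ds.TrainIter(256)

	fmt.Println("relu applied to MNIST tensors without error")
}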