WIP(example/mnist): nn
This commit is contained in:
parent
b5f112e030
commit
9be19702d1
|
@ -7,22 +7,27 @@
|
|||
|
||||
- Load MNIST data using helper function at `vision` sub-package
|
||||
|
||||
|
||||
## Linear Regression
|
||||
|
||||
- Run with `go clean -cache -testcache && go run main.go linear.go`
|
||||
- Run with `go clean -cache -testcache && go run . -model="linear"`
|
||||
|
||||
|
||||
- Accuraccy should be about **91.68%**.
|
||||
|
||||
|
||||
## Neural Network (NN)
|
||||
|
||||
TODO: update
|
||||
- Run with `go clean -cache -testcache && go run . -model="nn"`
|
||||
|
||||
- Accuraccy should be about **TODO: update%**.
|
||||
|
||||
|
||||
## Convolutional Neural Network (CNN)
|
||||
|
||||
TODO: update
|
||||
|
||||
- Run with `go clean -cache -testcache && go run . -model="cnn"`
|
||||
|
||||
- Accuraccy should be about **TODO: update%**.
|
||||
|
||||
|
||||
|
||||
|
|
9
example/mnist/cnn.go
Normal file
9
example/mnist/cnn.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func runCNN() {
|
||||
fmt.Println("CNN will be implemented soon...!\n")
|
||||
}
|
|
@ -1,7 +1,29 @@
|
|||
package main
|
||||
|
||||
import ()
|
||||
import (
|
||||
"flag"
|
||||
)
|
||||
|
||||
var model string
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&model, "model", "linear", "specify a model to run")
|
||||
|
||||
}
|
||||
|
||||
func main() {
|
||||
runLinear()
|
||||
|
||||
flag.Parse()
|
||||
|
||||
switch model {
|
||||
case "linear":
|
||||
runLinear()
|
||||
case "nn":
|
||||
runNN()
|
||||
case "cnn":
|
||||
runCNN()
|
||||
default:
|
||||
panic("No specified model to run")
|
||||
}
|
||||
|
||||
}
|
||||
|
|
BIN
example/mnist/mnist
Executable file
BIN
example/mnist/mnist
Executable file
Binary file not shown.
58
example/mnist/nn.go
Normal file
58
example/mnist/nn.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
const (
|
||||
ImageDimNN int64 = 784
|
||||
HiddenNodesNN int64 = 128
|
||||
LabelNN int64 = 10
|
||||
MnistDirNN string = "../../data/mnist"
|
||||
|
||||
epochsNN = 200
|
||||
batchSizeNN = 256
|
||||
|
||||
LrNN = 1e-3
|
||||
)
|
||||
|
||||
func netInit(vs nn.Path) ts.Module {
|
||||
n := nn.Seq()
|
||||
|
||||
n.Add(nn.NewLinear(vs.Sub("layer1"), ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig()))
|
||||
n.AddFn(func(xs ts.Tensor) ts.Tensor {
|
||||
return xs.MustRelu()
|
||||
})
|
||||
|
||||
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, nn.DefaultLinearConfig()))
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func runNN() {
|
||||
var ds vision.Dataset
|
||||
ds = vision.LoadMNISTDir(MnistDirNN)
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
net := netInit(vs.Root())
|
||||
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for epoch := 0; epoch < epochsNN; epoch++ {
|
||||
loss := net.Forward(ds.TrainImages).CrossEntropyForLogits(ds.TrainLabels)
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
testAccuracy := net.Forward(ds.TestImages).AccuracyForLogits(ds.TestLabels).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
|
||||
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss, testAccuracy*100)
|
||||
}
|
||||
|
||||
}
|
|
@ -247,3 +247,13 @@ func AtgRandperm(ptr *Ctensor, n int64, optionKind int32, optionDevice int32) {
|
|||
func AtgClamp_(ptr *Ctensor, self Ctensor, min Cscalar, max Cscalar) {
|
||||
C.atg_clamp_(ptr, self, min, max)
|
||||
}
|
||||
|
||||
// void atg_relu(tensor *, tensor self);
|
||||
func AtgRelu(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_relu(ptr, self)
|
||||
}
|
||||
|
||||
// void atg_relu_(tensor *, tensor self);
|
||||
func AtgRelu_(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_relu_(ptr, self)
|
||||
}
|
||||
|
|
|
@ -736,7 +736,41 @@ func (ts Tensor) Clamp_(min Scalar, max Scalar) {
|
|||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgClamp_(ptr, ts.ctensor, min.cscalar, max.cscalar)
|
||||
if err = TorchErr(); err != nil {
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) Relu_() {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgRelu_(ptr, ts.ctensor)
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) Relu() (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgRelu(ptr, ts.ctensor)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustRelu() (retVal Tensor) {
|
||||
retVal, err := ts.Relu()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package vision
|
|||
// A simple dataset structure shared by various computer vision datasets.
|
||||
|
||||
import (
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
|
@ -19,12 +18,12 @@ type Dataset struct {
|
|||
//=================
|
||||
|
||||
// TrainIter creates an iterator of Iter type for train images and labels
|
||||
func (ds Dataset) TrainIter(batchSize int64) (retVal nn.Iter2) {
|
||||
return nn.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchSize)
|
||||
func (ds Dataset) TrainIter(batchSize int64) (retVal ts.Iter2) {
|
||||
return ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchSize)
|
||||
|
||||
}
|
||||
|
||||
// TestIter creates an iterator of Iter type for test images and labels
|
||||
func (ds Dataset) TestIter(batchSize int64) (retVal nn.Iter2) {
|
||||
return nn.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)
|
||||
func (ds Dataset) TestIter(batchSize int64) (retVal ts.Iter2) {
|
||||
return ts.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user