WIP(example/mnist): nn
This commit is contained in:
parent
b5f112e030
commit
9be19702d1
|
@ -7,22 +7,27 @@
|
||||||
|
|
||||||
- Load MNIST data using helper function at `vision` sub-package
|
- Load MNIST data using helper function at `vision` sub-package
|
||||||
|
|
||||||
|
|
||||||
## Linear Regression
|
## Linear Regression
|
||||||
|
|
||||||
- Run with `go clean -cache -testcache && go run main.go linear.go`
|
- Run with `go clean -cache -testcache && go run . -model="linear"`
|
||||||
|
|
||||||
|
|
||||||
- Accuraccy should be about **91.68%**.
|
- Accuraccy should be about **91.68%**.
|
||||||
|
|
||||||
|
|
||||||
## Neural Network (NN)
|
## Neural Network (NN)
|
||||||
|
|
||||||
TODO: update
|
- Run with `go clean -cache -testcache && go run . -model="nn"`
|
||||||
|
|
||||||
|
- Accuraccy should be about **TODO: update%**.
|
||||||
|
|
||||||
|
|
||||||
## Convolutional Neural Network (CNN)
|
## Convolutional Neural Network (CNN)
|
||||||
|
|
||||||
TODO: update
|
- Run with `go clean -cache -testcache && go run . -model="cnn"`
|
||||||
|
|
||||||
|
- Accuraccy should be about **TODO: update%**.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
9
example/mnist/cnn.go
Normal file
9
example/mnist/cnn.go
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func runCNN() {
|
||||||
|
fmt.Println("CNN will be implemented soon...!\n")
|
||||||
|
}
|
|
@ -1,7 +1,29 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import ()
|
import (
|
||||||
|
"flag"
|
||||||
|
)
|
||||||
|
|
||||||
|
var model string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
flag.StringVar(&model, "model", "linear", "specify a model to run")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
runLinear()
|
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
switch model {
|
||||||
|
case "linear":
|
||||||
|
runLinear()
|
||||||
|
case "nn":
|
||||||
|
runNN()
|
||||||
|
case "cnn":
|
||||||
|
runCNN()
|
||||||
|
default:
|
||||||
|
panic("No specified model to run")
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
BIN
example/mnist/mnist
Executable file
BIN
example/mnist/mnist
Executable file
Binary file not shown.
58
example/mnist/nn.go
Normal file
58
example/mnist/nn.go
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
"github.com/sugarme/gotch/nn"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
"github.com/sugarme/gotch/vision"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ImageDimNN int64 = 784
|
||||||
|
HiddenNodesNN int64 = 128
|
||||||
|
LabelNN int64 = 10
|
||||||
|
MnistDirNN string = "../../data/mnist"
|
||||||
|
|
||||||
|
epochsNN = 200
|
||||||
|
batchSizeNN = 256
|
||||||
|
|
||||||
|
LrNN = 1e-3
|
||||||
|
)
|
||||||
|
|
||||||
|
func netInit(vs nn.Path) ts.Module {
|
||||||
|
n := nn.Seq()
|
||||||
|
|
||||||
|
n.Add(nn.NewLinear(vs.Sub("layer1"), ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig()))
|
||||||
|
n.AddFn(func(xs ts.Tensor) ts.Tensor {
|
||||||
|
return xs.MustRelu()
|
||||||
|
})
|
||||||
|
|
||||||
|
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, nn.DefaultLinearConfig()))
|
||||||
|
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func runNN() {
|
||||||
|
var ds vision.Dataset
|
||||||
|
ds = vision.LoadMNISTDir(MnistDirNN)
|
||||||
|
|
||||||
|
vs := nn.NewVarStore(gotch.CPU)
|
||||||
|
net := netInit(vs.Root())
|
||||||
|
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for epoch := 0; epoch < epochsNN; epoch++ {
|
||||||
|
loss := net.Forward(ds.TrainImages).CrossEntropyForLogits(ds.TrainLabels)
|
||||||
|
opt.BackwardStep(loss)
|
||||||
|
|
||||||
|
testAccuracy := net.Forward(ds.TestImages).AccuracyForLogits(ds.TestLabels).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||||
|
|
||||||
|
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss, testAccuracy*100)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -247,3 +247,13 @@ func AtgRandperm(ptr *Ctensor, n int64, optionKind int32, optionDevice int32) {
|
||||||
func AtgClamp_(ptr *Ctensor, self Ctensor, min Cscalar, max Cscalar) {
|
func AtgClamp_(ptr *Ctensor, self Ctensor, min Cscalar, max Cscalar) {
|
||||||
C.atg_clamp_(ptr, self, min, max)
|
C.atg_clamp_(ptr, self, min, max)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// void atg_relu(tensor *, tensor self);
|
||||||
|
func AtgRelu(ptr *Ctensor, self Ctensor) {
|
||||||
|
C.atg_relu(ptr, self)
|
||||||
|
}
|
||||||
|
|
||||||
|
// void atg_relu_(tensor *, tensor self);
|
||||||
|
func AtgRelu_(ptr *Ctensor, self Ctensor) {
|
||||||
|
C.atg_relu_(ptr, self)
|
||||||
|
}
|
||||||
|
|
|
@ -736,7 +736,41 @@ func (ts Tensor) Clamp_(min Scalar, max Scalar) {
|
||||||
defer C.free(unsafe.Pointer(ptr))
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
lib.AtgClamp_(ptr, ts.ctensor, min.cscalar, max.cscalar)
|
lib.AtgClamp_(ptr, ts.ctensor, min.cscalar, max.cscalar)
|
||||||
if err = TorchErr(); err != nil {
|
if err := TorchErr(); err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Relu_() {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
|
lib.AtgRelu_(ptr, ts.ctensor)
|
||||||
|
if err := TorchErr(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Relu() (retVal Tensor, err error) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
|
lib.AtgRelu(ptr, ts.ctensor)
|
||||||
|
err = TorchErr()
|
||||||
|
if err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor: *ptr}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustRelu() (retVal Tensor) {
|
||||||
|
retVal, err := ts.Relu()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package vision
|
||||||
// A simple dataset structure shared by various computer vision datasets.
|
// A simple dataset structure shared by various computer vision datasets.
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/sugarme/gotch/nn"
|
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -19,12 +18,12 @@ type Dataset struct {
|
||||||
//=================
|
//=================
|
||||||
|
|
||||||
// TrainIter creates an iterator of Iter type for train images and labels
|
// TrainIter creates an iterator of Iter type for train images and labels
|
||||||
func (ds Dataset) TrainIter(batchSize int64) (retVal nn.Iter2) {
|
func (ds Dataset) TrainIter(batchSize int64) (retVal ts.Iter2) {
|
||||||
return nn.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchSize)
|
return ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchSize)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestIter creates an iterator of Iter type for test images and labels
|
// TestIter creates an iterator of Iter type for test images and labels
|
||||||
func (ds Dataset) TestIter(batchSize int64) (retVal nn.Iter2) {
|
func (ds Dataset) TestIter(batchSize int64) (retVal ts.Iter2) {
|
||||||
return nn.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)
|
return ts.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user