WIP(example/mnist): conv

This commit is contained in:
sugarme 2020-06-23 01:07:07 +10:00
parent d480f969bb
commit 9d31337b4f
7 changed files with 225 additions and 26 deletions

View File

@ -2,8 +2,86 @@ package main
import (
"fmt"
"log"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
"github.com/sugarme/gotch/vision"
)
func runCNN() {
fmt.Println("CNN will be implemented soon...!\n")
// Hyper-parameters for the convolutional MNIST example.
const (
	// MnistDirCNN is the directory containing the raw MNIST data files.
	MnistDirCNN string = "../../data/mnist"

	epochsCNN = 10   // number of training epochs
	batchCNN  = 256  // mini-batch size for the training iterator
	LrCNN     = 1e-4 // Adam learning rate
)
// Net is a small convolutional network for MNIST classification:
// two convolution layers followed by two fully-connected layers
// (see newNet for the concrete layer dimensions).
type Net struct {
	conv1 nn.Conv2D // 1 -> 32 channels, 5x5 kernel (per newNet)
	conv2 nn.Conv2D // 32 -> 64 channels, 5x5 kernel (per newNet)
	fc1   nn.Linear // 1024 -> 1024
	fc2   nn.Linear // 1024 -> 10, one logit per digit class
}
// newNet builds the MNIST CNN with its variables registered under vs.
// Layer sizes assume 28x28 single-channel input: two 5x5 conv + 2x2
// max-pool stages reduce it to 64 x 4 x 4 = 1024 features.
func newNet(vs *nn.Path) Net {
	return Net{
		conv1: nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig()),
		conv2: nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig()),
		fc1:   *nn.NewLinear(*vs, 1024, 1024, *nn.DefaultLinearConfig()),
		fc2:   *nn.NewLinear(*vs, 1024, 10, *nn.DefaultLinearConfig()),
	}
}
// ForwardT runs one forward pass. The train flag controls dropout;
// it is applied in place between fc1 and fc2.
func (n Net) ForwardT(xs ts.Tensor, train bool) ts.Tensor {
	// Reshape the flat input to NCHW: (batch, 1, 28, 28).
	images := xs.MustView([]int64{-1, 1, 28, 28})

	pooled1 := images.Apply(n.conv1).MaxPool2DDefault(2, true)
	pooled2 := pooled1.Apply(n.conv2).MaxPool2DDefault(2, true)

	// Flatten to (batch, 1024) for the fully-connected head.
	hidden := pooled2.MustView([]int64{-1, 1024}).Apply(&n.fc1).MustRelu(true)
	hidden.Dropout_(0.5, train)

	return hidden.Apply(&n.fc2)
}
// runCNN trains the convolutional network on MNIST.
//
// Fixes over the original:
//   - uses the constants declared for this example (MnistDirCNN, LrCNN);
//     the original referenced MnistDirNN/LrNN.
//   - the shuffled training iterator is created once per epoch; the
//     original rebuilt it on every pass of the inner loop, so each pass
//     consumed only the first batch of a freshly shuffled iterator.
func runCNN() {
	ds := vision.LoadMNISTDir(MnistDirCNN)

	// Run on GPU 0 when available, otherwise fall back to CPU.
	cuda := gotch.CudaBuilder(0)
	vs := nn.NewVarStore(cuda.CudaIfAvailable())
	path := vs.Root()
	net := newNet(&path)

	opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
	if err != nil {
		log.Fatal(err)
	}

	for epoch := 0; epoch < epochsCNN; epoch++ {
		// One shuffled pass over the training set per epoch.
		iter := ds.TrainIter(batchCNN).Shuffle()

		count := 0
		for {
			item, ok := iter.Next()
			if !ok {
				break
			}

			device := vs.Device()
			loss := net.ForwardT(item.Data.MustTo(device, true), true).CrossEntropyForLogits(item.Label.MustTo(device, true))
			opt.BackwardStep(loss)
			loss.MustDrop()

			count++
			if count == 50 {
				// WIP: cap each epoch at 50 batches while the example is developed.
				break
			}
			fmt.Printf("completed \t %v batches\n", count)
		}

		// testAccuracy := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 1024)
		//
		// fmt.Printf("Epoch: %v \t Test accuracy: %.2f%%\n", epoch, testAccuracy*100)
	}
}

View File

@ -355,3 +355,35 @@ func AtgConv3d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, stride
C.atg_conv3d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cdilationDataPtr, cdilationLen, cgroups)
}
// void atg_max_pool2d(tensor *, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int ceil_mode);

// AtgMaxPool2d wraps atg_max_pool2d, writing the resulting tensor handle
// through ptr. The paired data/len arguments describe the kernel, stride,
// padding and dilation arrays passed to libtorch.
//
// The int scalars are converted by value with C.int(...). The original
// reinterpreted the memory of a Go int (8 bytes on 64-bit platforms) as a
// C.int (4 bytes), which only works on little-endian machines and silently
// truncates.
func AtgMaxPool2d(ptr *Ctensor, self Ctensor, kernelSizeData []int64, kernelSizeLen int, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, ceilMode int) {
	ckernelSizeDataPtr := (*C.int64_t)(unsafe.Pointer(&kernelSizeData[0]))
	ckernelSizeLen := C.int(kernelSizeLen)
	cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
	cstrideLen := C.int(strideLen)
	cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
	cpaddingLen := C.int(paddingLen)
	cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
	cdilationLen := C.int(dilationLen)
	cceilMode := C.int(ceilMode)
	C.atg_max_pool2d(ptr, self, ckernelSizeDataPtr, ckernelSizeLen, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cdilationDataPtr, cdilationLen, cceilMode)
}
// void atg_dropout(tensor *, tensor input, double p, int train);

// AtgDropout wraps atg_dropout, writing the resulting tensor handle
// through ptr. p is the dropout probability; train is 1 during training,
// 0 otherwise.
//
// Scalars are converted by value (C.double / C.int); the original's
// pointer-reinterpretation of a Go int as C.int is endianness-dependent
// and truncating on 64-bit platforms.
func AtgDropout(ptr *Ctensor, input Ctensor, p float64, train int) {
	cp := C.double(p)
	ctrain := C.int(train)
	C.atg_dropout(ptr, input, cp, ctrain)
}
// void atg_dropout_(tensor *, tensor self, double p, int train);

// AtgDropout_ wraps atg_dropout_, the in-place dropout variant.
// Scalars are converted by value (C.double / C.int); see AtgDropout for
// why the original pointer-reinterpretation was unsafe.
func AtgDropout_(ptr *Ctensor, self Ctensor, p float64, train int) {
	cp := C.double(p)
	ctrain := C.int(train)
	C.atg_dropout_(ptr, self, cp, ctrain)
}

View File

@ -7,7 +7,6 @@ import (
)
type Conv1DConfig struct {
Kval int64
Stride []int64
Padding []int64
Dilation []int64
@ -18,7 +17,6 @@ type Conv1DConfig struct {
}
type Conv2DConfig struct {
Kval int64
Stride []int64
Padding []int64
Dilation []int64
@ -29,7 +27,6 @@ type Conv2DConfig struct {
}
type Conv3DConfig struct {
Kval int64
Stride []int64
Padding []int64
Dilation []int64
@ -71,14 +68,14 @@ type Conv1D struct {
Config Conv1DConfig
}
func NewConv1D(vs *Path, inDim, outDim int64, cfg Conv1DConfig) Conv1D {
func NewConv1D(vs *Path, inDim, outDim, k int64, cfg Conv1DConfig) Conv1D {
var conv Conv1D
conv.Config = cfg
if cfg.Bias {
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
}
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
weightSize = append(weightSize, cfg.Kval)
weightSize = append(weightSize, k)
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
return conv
@ -90,14 +87,14 @@ type Conv2D struct {
Config Conv2DConfig
}
func NewConv2D(vs *Path, inDim, outDim int64, cfg Conv2DConfig) Conv2D {
func NewConv2D(vs *Path, inDim, outDim int64, k int64, cfg Conv2DConfig) Conv2D {
var conv Conv2D
conv.Config = cfg
if cfg.Bias {
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
}
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
weightSize = append(weightSize, cfg.Kval, cfg.Kval)
weightSize = append(weightSize, k, k)
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
return conv
@ -109,14 +106,14 @@ type Conv3D struct {
Config Conv3DConfig
}
func NewConv3D(vs *Path, inDim, outDim int64, cfg Conv3DConfig) Conv3D {
func NewConv3D(vs *Path, inDim, outDim, k int64, cfg Conv3DConfig) Conv3D {
var conv Conv3D
conv.Config = cfg
if cfg.Bias {
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
}
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
weightSize = append(weightSize, cfg.Kval, cfg.Kval, cfg.Kval)
weightSize = append(weightSize, k, k, k)
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
return conv

View File

@ -106,20 +106,21 @@ func (s *SequentialT) IsEmpty() (retVal bool) {
// Implement ModuleT interface for SequentialT:
// ==========================================
// Forward threads the input tensor through every stored layer in order.
// An empty container returns a shallow clone of the input unchanged.
func (s SequentialT) Forward(xs ts.Tensor) (retVal ts.Tensor) {
	if s.IsEmpty() {
		return xs.MustShallowClone()
	}

	// Each layer consumes the previous layer's output.
	out := xs
	for _, layer := range s.layers {
		out = layer.Forward(out)
	}

	return out
}
/*
* func (s SequentialT) Forward(xs ts.Tensor) (retVal ts.Tensor) {
* if s.IsEmpty() {
* return xs.MustShallowClone()
* }
*
* // forward sequentially
* var currTs ts.Tensor = xs
* for i := 0; i < len(s.layers); i++ {
* currTs = s.layers[i].Forward(currTs)
* }
*
* return currTs
* }
* */
func (s SequentialT) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
if s.IsEmpty() {
return xs.MustShallowClone()

View File

@ -18,7 +18,7 @@ type Module interface {
// The train parameter is commonly used to have different behavior
// between training and evaluation. E.g. When using dropout or batch-normalization.
// ModuleT is a trainable module: its forward pass takes an explicit
// train flag so dropout/batch-norm style layers can behave differently
// during training and evaluation.
type ModuleT interface {
	// Forward(xs Tensor) Tensor
	ForwardT(xs Tensor, train bool) Tensor
}

View File

@ -17,4 +17,8 @@ func (ts Tensor) AccuracyForLogits(targets Tensor) (retVal Tensor) {
return ts.MustArgmax(-1, false, true).MustEq1(targets).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true)
}
// MaxPool2DDefault applies a square max pooling with the common defaults:
// stride equal to the kernel size, no padding, unit dilation, floor mode.
// del controls whether the input tensor is dropped after the call.
func (ts Tensor) MaxPool2DDefault(ksize int64, del bool) (retVal Tensor) {
	kernel := []int64{ksize, ksize}
	stride := []int64{ksize, ksize}
	noPadding := []int64{0, 0}
	unitDilation := []int64{1, 1}

	return ts.MustMaxPool2D(kernel, stride, noPadding, unitDilation, false, del)
}
// TODO: continue

View File

@ -1087,3 +1087,90 @@ func MustConv3D(input, weight, bias Tensor, stride, padding, dilation []int64, g
return retVal
}
// MaxPool2D applies 2-D max pooling with the given kernel, stride,
// padding and dilation. ceil selects ceil mode for the output size
// computation; del drops the receiver tensor after the call.
// It returns the pooled tensor or the libtorch error.
func (ts Tensor) MaxPool2D(kernel []int64, stride []int64, padding []int64, dilation []int64, ceil bool, del bool) (retVal Tensor, err error) {
	// Reserve storage for the returned C tensor handle. The original used
	// C.malloc(0): the C side then writes a pointer into a zero-byte
	// allocation, which is undefined behavior.
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(C.size_t(unsafe.Sizeof(uintptr(0))))))
	if del {
		defer ts.MustDrop()
	}

	ceilMode := 0
	if ceil {
		ceilMode = 1
	}

	lib.AtgMaxPool2d(ptr, ts.ctensor, kernel, len(kernel), stride, len(stride), padding, len(padding), dilation, len(dilation), ceilMode)
	if err = TorchErr(); err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}

	return retVal, nil
}
// MustMaxPool2D is like MaxPool2D but terminates the program
// (log.Fatal) instead of returning an error.
func (ts Tensor) MustMaxPool2D(kernel []int64, stride []int64, padding []int64, dilation []int64, ceil bool, del bool) (retVal Tensor) {
	res, err := ts.MaxPool2D(kernel, stride, padding, dilation, ceil, del)
	if err != nil {
		log.Fatal(err)
	}

	return res
}
// Dropout returns a new tensor with elements of input randomly zeroed
// with probability p. train selects training behavior (dropout active);
// when false the input passes through scaled appropriately by libtorch.
func Dropout(input Tensor, p float64, train bool) (retVal Tensor, err error) {
	// Reserve storage for the returned C tensor handle; the original's
	// C.malloc(0) gives a zero-byte allocation that the C side then
	// writes a pointer through — undefined behavior.
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(C.size_t(unsafe.Sizeof(uintptr(0))))))

	ctrain := 0
	if train {
		ctrain = 1
	}

	lib.AtgDropout(ptr, input.ctensor, p, ctrain)
	if err = TorchErr(); err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}

	return retVal, nil
}
// MustDropout is like Dropout but terminates the program (log.Fatal)
// instead of returning an error.
func MustDropout(input Tensor, p float64, train bool) (retVal Tensor) {
	res, err := Dropout(input, p, train)
	if err != nil {
		log.Fatal(err)
	}

	return res
}
// Dropout_ applies dropout to ts in place with probability p.
// train selects training behavior; it fatals on a libtorch error.
func (ts Tensor) Dropout_(p float64, train bool) {
	// Reserve pointer-sized storage for the C result handle (the original's
	// C.malloc(0) is a zero-byte allocation written through — UB).
	// NOTE(review): the handle written into ptr, and ptr itself, are never
	// freed here; since this is the in-place variant the result is ts, but
	// the returned C handle still leaks — TODO confirm against libtch.
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(C.size_t(unsafe.Sizeof(uintptr(0))))))

	ctrain := 0
	if train {
		ctrain = 1
	}

	lib.AtgDropout_(ptr, ts.ctensor, p, ctrain)
	if err := TorchErr(); err != nil {
		log.Fatal(err)
	}
}