WIP(example/mnist): conv
This commit is contained in:
parent
d480f969bb
commit
9d31337b4f
|
@ -2,8 +2,86 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
func runCNN() {
|
||||
fmt.Println("CNN will be implemented soon...!\n")
|
||||
const (
|
||||
MnistDirCNN string = "../../data/mnist"
|
||||
|
||||
epochsCNN = 10
|
||||
batchCNN = 256
|
||||
|
||||
LrCNN = 1e-4
|
||||
)
|
||||
|
||||
// Net is a small convolutional network for MNIST classification:
// two Conv2D feature extractors followed by two fully-connected layers.
// Layer dimensions are fixed by newNet.
type Net struct {
	conv1 nn.Conv2D // first conv layer (1 -> 32 channels, 5x5 kernel; see newNet)
	conv2 nn.Conv2D // second conv layer (32 -> 64 channels, 5x5 kernel)
	fc1 nn.Linear // hidden fully-connected layer (1024 -> 1024)
	fc2 nn.Linear // output layer (1024 -> 10 class logits)
}
|
||||
|
||||
func newNet(vs *nn.Path) Net {
|
||||
conv1 := nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig())
|
||||
conv2 := nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig())
|
||||
fc1 := nn.NewLinear(*vs, 1024, 1024, *nn.DefaultLinearConfig())
|
||||
fc2 := nn.NewLinear(*vs, 1024, 10, *nn.DefaultLinearConfig())
|
||||
|
||||
return Net{
|
||||
conv1,
|
||||
conv2,
|
||||
*fc1,
|
||||
*fc2}
|
||||
}
|
||||
|
||||
func (n Net) ForwardT(xs ts.Tensor, train bool) ts.Tensor {
|
||||
out := xs.MustView([]int64{-1, 1, 28, 28}).Apply(n.conv1).MaxPool2DDefault(2, true)
|
||||
out = out.Apply(n.conv2).MaxPool2DDefault(2, true)
|
||||
out = out.MustView([]int64{-1, 1024}).Apply(&n.fc1).MustRelu(true)
|
||||
out.Dropout_(0.5, train)
|
||||
return out.Apply(&n.fc2)
|
||||
}
|
||||
|
||||
func runCNN() {
|
||||
|
||||
var ds vision.Dataset
|
||||
ds = vision.LoadMNISTDir(MnistDirNN)
|
||||
cuda := gotch.CudaBuilder(0)
|
||||
vs := nn.NewVarStore(cuda.CudaIfAvailable())
|
||||
path := vs.Root()
|
||||
net := newNet(&path)
|
||||
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for epoch := 0; epoch < epochsCNN; epoch++ {
|
||||
var count = 0
|
||||
for {
|
||||
iter := ds.TrainIter(batchCNN).Shuffle()
|
||||
item, ok := iter.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
loss := net.ForwardT(item.Data.MustTo(vs.Device(), true), true).CrossEntropyForLogits(item.Label.MustTo(vs.Device(), true))
|
||||
opt.BackwardStep(loss)
|
||||
loss.MustDrop()
|
||||
count++
|
||||
if count == 50 {
|
||||
break
|
||||
}
|
||||
fmt.Printf("completed \t %v batches\n", count)
|
||||
}
|
||||
|
||||
// testAccuracy := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 1024)
|
||||
//
|
||||
// fmt.Printf("Epoch: %v \t Test accuracy: %.2f%%\n", epoch, testAccuracy*100)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -355,3 +355,35 @@ func AtgConv3d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, stride
|
|||
|
||||
C.atg_conv3d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cdilationDataPtr, cdilationLen, cgroups)
|
||||
}
|
||||
|
||||
// void atg_max_pool2d(tensor *, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int ceil_mode);
|
||||
func AtgMaxPool2d(ptr *Ctensor, self Ctensor, kernelSizeData []int64, kernelSizeLen int, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, ceilMode int) {
|
||||
|
||||
ckernelSizeDataPtr := (*C.int64_t)(unsafe.Pointer(&kernelSizeData[0]))
|
||||
ckernelSizeLen := *(*C.int)(unsafe.Pointer(&kernelSizeLen))
|
||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
||||
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
|
||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
||||
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
|
||||
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
|
||||
cceilMode := *(*C.int)(unsafe.Pointer(&ceilMode))
|
||||
|
||||
C.atg_max_pool2d(ptr, self, ckernelSizeDataPtr, ckernelSizeLen, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cdilationDataPtr, cdilationLen, cceilMode)
|
||||
}
|
||||
|
||||
// void atg_dropout(tensor *, tensor input, double p, int train);
|
||||
func AtgDropout(ptr *Ctensor, input Ctensor, p float64, train int) {
|
||||
cp := *(*C.double)(unsafe.Pointer(&p))
|
||||
ctrain := *(*C.int)(unsafe.Pointer(&train))
|
||||
|
||||
C.atg_dropout(ptr, input, cp, ctrain)
|
||||
}
|
||||
|
||||
// void atg_dropout_(tensor *, tensor self, double p, int train);
|
||||
func AtgDropout_(ptr *Ctensor, self Ctensor, p float64, train int) {
|
||||
cp := *(*C.double)(unsafe.Pointer(&p))
|
||||
ctrain := *(*C.int)(unsafe.Pointer(&train))
|
||||
|
||||
C.atg_dropout_(ptr, self, cp, ctrain)
|
||||
}
|
||||
|
|
15
nn/conv.go
15
nn/conv.go
|
@ -7,7 +7,6 @@ import (
|
|||
)
|
||||
|
||||
type Conv1DConfig struct {
|
||||
Kval int64
|
||||
Stride []int64
|
||||
Padding []int64
|
||||
Dilation []int64
|
||||
|
@ -18,7 +17,6 @@ type Conv1DConfig struct {
|
|||
}
|
||||
|
||||
type Conv2DConfig struct {
|
||||
Kval int64
|
||||
Stride []int64
|
||||
Padding []int64
|
||||
Dilation []int64
|
||||
|
@ -29,7 +27,6 @@ type Conv2DConfig struct {
|
|||
}
|
||||
|
||||
type Conv3DConfig struct {
|
||||
Kval int64
|
||||
Stride []int64
|
||||
Padding []int64
|
||||
Dilation []int64
|
||||
|
@ -71,14 +68,14 @@ type Conv1D struct {
|
|||
Config Conv1DConfig
|
||||
}
|
||||
|
||||
func NewConv1D(vs *Path, inDim, outDim int64, cfg Conv1DConfig) Conv1D {
|
||||
func NewConv1D(vs *Path, inDim, outDim, k int64, cfg Conv1DConfig) Conv1D {
|
||||
var conv Conv1D
|
||||
conv.Config = cfg
|
||||
if cfg.Bias {
|
||||
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, cfg.Kval)
|
||||
weightSize = append(weightSize, k)
|
||||
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
|
||||
return conv
|
||||
|
@ -90,14 +87,14 @@ type Conv2D struct {
|
|||
Config Conv2DConfig
|
||||
}
|
||||
|
||||
func NewConv2D(vs *Path, inDim, outDim int64, cfg Conv2DConfig) Conv2D {
|
||||
func NewConv2D(vs *Path, inDim, outDim int64, k int64, cfg Conv2DConfig) Conv2D {
|
||||
var conv Conv2D
|
||||
conv.Config = cfg
|
||||
if cfg.Bias {
|
||||
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, cfg.Kval, cfg.Kval)
|
||||
weightSize = append(weightSize, k, k)
|
||||
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
|
||||
return conv
|
||||
|
@ -109,14 +106,14 @@ type Conv3D struct {
|
|||
Config Conv3DConfig
|
||||
}
|
||||
|
||||
func NewConv3D(vs *Path, inDim, outDim int64, cfg Conv3DConfig) Conv3D {
|
||||
func NewConv3D(vs *Path, inDim, outDim, k int64, cfg Conv3DConfig) Conv3D {
|
||||
var conv Conv3D
|
||||
conv.Config = cfg
|
||||
if cfg.Bias {
|
||||
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, cfg.Kval, cfg.Kval, cfg.Kval)
|
||||
weightSize = append(weightSize, k, k, k)
|
||||
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
|
||||
return conv
|
||||
|
|
|
@ -106,20 +106,21 @@ func (s *SequentialT) IsEmpty() (retVal bool) {
|
|||
|
||||
// Implement ModuleT interface for SequentialT:
|
||||
// ==========================================
|
||||
func (s SequentialT) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||
if s.IsEmpty() {
|
||||
return xs.MustShallowClone()
|
||||
}
|
||||
|
||||
// forward sequentially
|
||||
var currTs ts.Tensor = xs
|
||||
for i := 0; i < len(s.layers); i++ {
|
||||
currTs = s.layers[i].Forward(currTs)
|
||||
}
|
||||
|
||||
return currTs
|
||||
}
|
||||
|
||||
/*
|
||||
* func (s SequentialT) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||
* if s.IsEmpty() {
|
||||
* return xs.MustShallowClone()
|
||||
* }
|
||||
*
|
||||
* // forward sequentially
|
||||
* var currTs ts.Tensor = xs
|
||||
* for i := 0; i < len(s.layers); i++ {
|
||||
* currTs = s.layers[i].Forward(currTs)
|
||||
* }
|
||||
*
|
||||
* return currTs
|
||||
* }
|
||||
* */
|
||||
func (s SequentialT) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||
if s.IsEmpty() {
|
||||
return xs.MustShallowClone()
|
||||
|
|
|
@ -18,7 +18,7 @@ type Module interface {
|
|||
// The train parameter is commonly used to have different behavior
|
||||
// between training and evaluation. E.g. When using dropout or batch-normalization.
|
||||
type ModuleT interface {
|
||||
Forward(xs Tensor) Tensor
|
||||
// Forward(xs Tensor) Tensor
|
||||
ForwardT(xs Tensor, train bool) Tensor
|
||||
}
|
||||
|
||||
|
|
|
@ -17,4 +17,8 @@ func (ts Tensor) AccuracyForLogits(targets Tensor) (retVal Tensor) {
|
|||
return ts.MustArgmax(-1, false, true).MustEq1(targets).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true)
|
||||
}
|
||||
|
||||
func (ts Tensor) MaxPool2DDefault(ksize int64, del bool) (retVal Tensor) {
|
||||
return ts.MustMaxPool2D([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, []int64{1, 1}, false, del)
|
||||
}
|
||||
|
||||
// TODO: continue
|
||||
|
|
|
@ -1087,3 +1087,90 @@ func MustConv3D(input, weight, bias Tensor, stride, padding, dilation []int64, g
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) MaxPool2D(kernel []int64, stride []int64, padding []int64, dilation []int64, ceil bool, del bool) (retVal Tensor, err error) {
|
||||
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
var ceilMode int
|
||||
switch ceil {
|
||||
case true:
|
||||
ceilMode = 1
|
||||
case false:
|
||||
ceilMode = 0
|
||||
}
|
||||
|
||||
lib.AtgMaxPool2d(ptr, ts.ctensor, kernel, len(kernel), stride, len(stride), padding, len(padding), dilation, len(dilation), ceilMode)
|
||||
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustMaxPool2D(kernel []int64, stride []int64, padding []int64, dilation []int64, ceil bool, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.MaxPool2D(kernel, stride, padding, dilation, ceil, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func Dropout(input Tensor, p float64, train bool) (retVal Tensor, err error) {
|
||||
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
var ctrain int
|
||||
switch train {
|
||||
case true:
|
||||
ctrain = 1
|
||||
case false:
|
||||
ctrain = 0
|
||||
}
|
||||
|
||||
lib.AtgDropout(ptr, input.ctensor, p, ctrain)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
|
||||
}
|
||||
|
||||
func MustDropout(input Tensor, p float64, train bool) (retVal Tensor) {
|
||||
retVal, err := Dropout(input, p, train)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Dropout_ applies dropout to ts in place with probability p; dropout is only
// active when train is true. Terminates the program on a torch error.
func (ts Tensor) Dropout_(p float64, train bool) {

	// Storage for the tensor handle the C side writes back.
	// NOTE(review): the in-place C call still returns a handle through ptr,
	// which is discarded here along with the malloc'd slot — looks like a
	// leaked handle; confirm against the atg_dropout_ contract.
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))

	// Convert the bool train flag to the C int the wrapper expects.
	var ctrain int
	switch train {
	case true:
		ctrain = 1
	case false:
		ctrain = 0
	}
	lib.AtgDropout_(ptr, ts.ctensor, p, ctrain)
	err := TorchErr()
	if err != nil {
		log.Fatal(err)
	}
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user