WIP(example/mnist): conv
This commit is contained in:
parent
d480f969bb
commit
9d31337b4f
|
@ -2,8 +2,86 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
"github.com/sugarme/gotch/nn"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
"github.com/sugarme/gotch/vision"
|
||||||
)
|
)
|
||||||
|
|
||||||
func runCNN() {
|
const (
|
||||||
fmt.Println("CNN will be implemented soon...!\n")
|
MnistDirCNN string = "../../data/mnist"
|
||||||
|
|
||||||
|
epochsCNN = 10
|
||||||
|
batchCNN = 256
|
||||||
|
|
||||||
|
LrCNN = 1e-4
|
||||||
|
)
|
||||||
|
|
||||||
|
type Net struct {
|
||||||
|
conv1 nn.Conv2D
|
||||||
|
conv2 nn.Conv2D
|
||||||
|
fc1 nn.Linear
|
||||||
|
fc2 nn.Linear
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNet(vs *nn.Path) Net {
|
||||||
|
conv1 := nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig())
|
||||||
|
conv2 := nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig())
|
||||||
|
fc1 := nn.NewLinear(*vs, 1024, 1024, *nn.DefaultLinearConfig())
|
||||||
|
fc2 := nn.NewLinear(*vs, 1024, 10, *nn.DefaultLinearConfig())
|
||||||
|
|
||||||
|
return Net{
|
||||||
|
conv1,
|
||||||
|
conv2,
|
||||||
|
*fc1,
|
||||||
|
*fc2}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n Net) ForwardT(xs ts.Tensor, train bool) ts.Tensor {
|
||||||
|
out := xs.MustView([]int64{-1, 1, 28, 28}).Apply(n.conv1).MaxPool2DDefault(2, true)
|
||||||
|
out = out.Apply(n.conv2).MaxPool2DDefault(2, true)
|
||||||
|
out = out.MustView([]int64{-1, 1024}).Apply(&n.fc1).MustRelu(true)
|
||||||
|
out.Dropout_(0.5, train)
|
||||||
|
return out.Apply(&n.fc2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runCNN() {
|
||||||
|
|
||||||
|
var ds vision.Dataset
|
||||||
|
ds = vision.LoadMNISTDir(MnistDirNN)
|
||||||
|
cuda := gotch.CudaBuilder(0)
|
||||||
|
vs := nn.NewVarStore(cuda.CudaIfAvailable())
|
||||||
|
path := vs.Root()
|
||||||
|
net := newNet(&path)
|
||||||
|
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for epoch := 0; epoch < epochsCNN; epoch++ {
|
||||||
|
var count = 0
|
||||||
|
for {
|
||||||
|
iter := ds.TrainIter(batchCNN).Shuffle()
|
||||||
|
item, ok := iter.Next()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
loss := net.ForwardT(item.Data.MustTo(vs.Device(), true), true).CrossEntropyForLogits(item.Label.MustTo(vs.Device(), true))
|
||||||
|
opt.BackwardStep(loss)
|
||||||
|
loss.MustDrop()
|
||||||
|
count++
|
||||||
|
if count == 50 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
fmt.Printf("completed \t %v batches\n", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// testAccuracy := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 1024)
|
||||||
|
//
|
||||||
|
// fmt.Printf("Epoch: %v \t Test accuracy: %.2f%%\n", epoch, testAccuracy*100)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -355,3 +355,35 @@ func AtgConv3d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, stride
|
||||||
|
|
||||||
C.atg_conv3d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cdilationDataPtr, cdilationLen, cgroups)
|
C.atg_conv3d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cdilationDataPtr, cdilationLen, cgroups)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// void atg_max_pool2d(tensor *, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int ceil_mode);
|
||||||
|
func AtgMaxPool2d(ptr *Ctensor, self Ctensor, kernelSizeData []int64, kernelSizeLen int, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, ceilMode int) {
|
||||||
|
|
||||||
|
ckernelSizeDataPtr := (*C.int64_t)(unsafe.Pointer(&kernelSizeData[0]))
|
||||||
|
ckernelSizeLen := *(*C.int)(unsafe.Pointer(&kernelSizeLen))
|
||||||
|
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
||||||
|
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
|
||||||
|
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
||||||
|
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
||||||
|
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
|
||||||
|
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
|
||||||
|
cceilMode := *(*C.int)(unsafe.Pointer(&ceilMode))
|
||||||
|
|
||||||
|
C.atg_max_pool2d(ptr, self, ckernelSizeDataPtr, ckernelSizeLen, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cdilationDataPtr, cdilationLen, cceilMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// void atg_dropout(tensor *, tensor input, double p, int train);
|
||||||
|
func AtgDropout(ptr *Ctensor, input Ctensor, p float64, train int) {
|
||||||
|
cp := *(*C.double)(unsafe.Pointer(&p))
|
||||||
|
ctrain := *(*C.int)(unsafe.Pointer(&train))
|
||||||
|
|
||||||
|
C.atg_dropout(ptr, input, cp, ctrain)
|
||||||
|
}
|
||||||
|
|
||||||
|
// void atg_dropout_(tensor *, tensor self, double p, int train);
|
||||||
|
func AtgDropout_(ptr *Ctensor, self Ctensor, p float64, train int) {
|
||||||
|
cp := *(*C.double)(unsafe.Pointer(&p))
|
||||||
|
ctrain := *(*C.int)(unsafe.Pointer(&train))
|
||||||
|
|
||||||
|
C.atg_dropout_(ptr, self, cp, ctrain)
|
||||||
|
}
|
||||||
|
|
15
nn/conv.go
15
nn/conv.go
|
@ -7,7 +7,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conv1DConfig struct {
|
type Conv1DConfig struct {
|
||||||
Kval int64
|
|
||||||
Stride []int64
|
Stride []int64
|
||||||
Padding []int64
|
Padding []int64
|
||||||
Dilation []int64
|
Dilation []int64
|
||||||
|
@ -18,7 +17,6 @@ type Conv1DConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Conv2DConfig struct {
|
type Conv2DConfig struct {
|
||||||
Kval int64
|
|
||||||
Stride []int64
|
Stride []int64
|
||||||
Padding []int64
|
Padding []int64
|
||||||
Dilation []int64
|
Dilation []int64
|
||||||
|
@ -29,7 +27,6 @@ type Conv2DConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Conv3DConfig struct {
|
type Conv3DConfig struct {
|
||||||
Kval int64
|
|
||||||
Stride []int64
|
Stride []int64
|
||||||
Padding []int64
|
Padding []int64
|
||||||
Dilation []int64
|
Dilation []int64
|
||||||
|
@ -71,14 +68,14 @@ type Conv1D struct {
|
||||||
Config Conv1DConfig
|
Config Conv1DConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConv1D(vs *Path, inDim, outDim int64, cfg Conv1DConfig) Conv1D {
|
func NewConv1D(vs *Path, inDim, outDim, k int64, cfg Conv1DConfig) Conv1D {
|
||||||
var conv Conv1D
|
var conv Conv1D
|
||||||
conv.Config = cfg
|
conv.Config = cfg
|
||||||
if cfg.Bias {
|
if cfg.Bias {
|
||||||
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||||
}
|
}
|
||||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||||
weightSize = append(weightSize, cfg.Kval)
|
weightSize = append(weightSize, k)
|
||||||
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||||
|
|
||||||
return conv
|
return conv
|
||||||
|
@ -90,14 +87,14 @@ type Conv2D struct {
|
||||||
Config Conv2DConfig
|
Config Conv2DConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConv2D(vs *Path, inDim, outDim int64, cfg Conv2DConfig) Conv2D {
|
func NewConv2D(vs *Path, inDim, outDim int64, k int64, cfg Conv2DConfig) Conv2D {
|
||||||
var conv Conv2D
|
var conv Conv2D
|
||||||
conv.Config = cfg
|
conv.Config = cfg
|
||||||
if cfg.Bias {
|
if cfg.Bias {
|
||||||
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||||
}
|
}
|
||||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||||
weightSize = append(weightSize, cfg.Kval, cfg.Kval)
|
weightSize = append(weightSize, k, k)
|
||||||
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||||
|
|
||||||
return conv
|
return conv
|
||||||
|
@ -109,14 +106,14 @@ type Conv3D struct {
|
||||||
Config Conv3DConfig
|
Config Conv3DConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConv3D(vs *Path, inDim, outDim int64, cfg Conv3DConfig) Conv3D {
|
func NewConv3D(vs *Path, inDim, outDim, k int64, cfg Conv3DConfig) Conv3D {
|
||||||
var conv Conv3D
|
var conv Conv3D
|
||||||
conv.Config = cfg
|
conv.Config = cfg
|
||||||
if cfg.Bias {
|
if cfg.Bias {
|
||||||
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||||
}
|
}
|
||||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||||
weightSize = append(weightSize, cfg.Kval, cfg.Kval, cfg.Kval)
|
weightSize = append(weightSize, k, k, k)
|
||||||
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||||
|
|
||||||
return conv
|
return conv
|
||||||
|
|
|
@ -106,20 +106,21 @@ func (s *SequentialT) IsEmpty() (retVal bool) {
|
||||||
|
|
||||||
// Implement ModuleT interface for SequentialT:
|
// Implement ModuleT interface for SequentialT:
|
||||||
// ==========================================
|
// ==========================================
|
||||||
func (s SequentialT) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
/*
|
||||||
if s.IsEmpty() {
|
* func (s SequentialT) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||||
return xs.MustShallowClone()
|
* if s.IsEmpty() {
|
||||||
}
|
* return xs.MustShallowClone()
|
||||||
|
* }
|
||||||
// forward sequentially
|
*
|
||||||
var currTs ts.Tensor = xs
|
* // forward sequentially
|
||||||
for i := 0; i < len(s.layers); i++ {
|
* var currTs ts.Tensor = xs
|
||||||
currTs = s.layers[i].Forward(currTs)
|
* for i := 0; i < len(s.layers); i++ {
|
||||||
}
|
* currTs = s.layers[i].Forward(currTs)
|
||||||
|
* }
|
||||||
return currTs
|
*
|
||||||
}
|
* return currTs
|
||||||
|
* }
|
||||||
|
* */
|
||||||
func (s SequentialT) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
func (s SequentialT) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||||
if s.IsEmpty() {
|
if s.IsEmpty() {
|
||||||
return xs.MustShallowClone()
|
return xs.MustShallowClone()
|
||||||
|
|
|
@ -18,7 +18,7 @@ type Module interface {
|
||||||
// The train parameter is commonly used to have different behavior
|
// The train parameter is commonly used to have different behavior
|
||||||
// between training and evaluation. E.g. When using dropout or batch-normalization.
|
// between training and evaluation. E.g. When using dropout or batch-normalization.
|
||||||
type ModuleT interface {
|
type ModuleT interface {
|
||||||
Forward(xs Tensor) Tensor
|
// Forward(xs Tensor) Tensor
|
||||||
ForwardT(xs Tensor, train bool) Tensor
|
ForwardT(xs Tensor, train bool) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,4 +17,8 @@ func (ts Tensor) AccuracyForLogits(targets Tensor) (retVal Tensor) {
|
||||||
return ts.MustArgmax(-1, false, true).MustEq1(targets).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true)
|
return ts.MustArgmax(-1, false, true).MustEq1(targets).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MaxPool2DDefault(ksize int64, del bool) (retVal Tensor) {
|
||||||
|
return ts.MustMaxPool2D([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, []int64{1, 1}, false, del)
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: continue
|
// TODO: continue
|
||||||
|
|
|
@ -1087,3 +1087,90 @@ func MustConv3D(input, weight, bias Tensor, stride, padding, dilation []int64, g
|
||||||
|
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MaxPool2D(kernel []int64, stride []int64, padding []int64, dilation []int64, ceil bool, del bool) (retVal Tensor, err error) {
|
||||||
|
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
if del {
|
||||||
|
defer ts.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
|
var ceilMode int
|
||||||
|
switch ceil {
|
||||||
|
case true:
|
||||||
|
ceilMode = 1
|
||||||
|
case false:
|
||||||
|
ceilMode = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
lib.AtgMaxPool2d(ptr, ts.ctensor, kernel, len(kernel), stride, len(stride), padding, len(padding), dilation, len(dilation), ceilMode)
|
||||||
|
|
||||||
|
err = TorchErr()
|
||||||
|
if err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor: *ptr}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustMaxPool2D(kernel []int64, stride []int64, padding []int64, dilation []int64, ceil bool, del bool) (retVal Tensor) {
|
||||||
|
retVal, err := ts.MaxPool2D(kernel, stride, padding, dilation, ceil, del)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func Dropout(input Tensor, p float64, train bool) (retVal Tensor, err error) {
|
||||||
|
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
|
||||||
|
var ctrain int
|
||||||
|
switch train {
|
||||||
|
case true:
|
||||||
|
ctrain = 1
|
||||||
|
case false:
|
||||||
|
ctrain = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
lib.AtgDropout(ptr, input.ctensor, p, ctrain)
|
||||||
|
err = TorchErr()
|
||||||
|
if err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor: *ptr}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func MustDropout(input Tensor, p float64, train bool) (retVal Tensor) {
|
||||||
|
retVal, err := Dropout(input, p, train)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Dropout_(p float64, train bool) {
|
||||||
|
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
|
||||||
|
var ctrain int
|
||||||
|
switch train {
|
||||||
|
case true:
|
||||||
|
ctrain = 1
|
||||||
|
case false:
|
||||||
|
ctrain = 0
|
||||||
|
}
|
||||||
|
lib.AtgDropout_(ptr, ts.ctensor, p, ctrain)
|
||||||
|
err := TorchErr()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user