fix(nn/module): ForwardT; feat(example/cifar)

This commit is contained in:
sugarme 2020-07-08 15:19:40 +10:00
parent f33ca7edf1
commit 8107d429b6
7 changed files with 385 additions and 4 deletions

162
example/cifar/main.go Normal file
View File

@ -0,0 +1,162 @@
package main
// Training various models on the CIFAR-10 dataset.
//
// The dataset can be downloaded from https://www.cs.toronto.edu/~kriz/cifar.html, files
// should be placed in the data/ directory.
//
// The resnet model reaches 95.4% accuracy.
import (
"fmt"
"log"
"time"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
"github.com/sugarme/gotch/vision"
)
// convBn builds a conv(3x3, padding 1, bias disabled) -> batch-norm -> ReLU
// stage with cIn input channels and cOut output channels under path p.
func convBn(p nn.Path, cIn, cOut int64) (retVal nn.SequentialT) {
	cfg := nn.DefaultConv2DConfig()
	cfg.Padding = []int64{1, 1}
	cfg.Bias = false

	stage := nn.SeqT()
	stage.Add(nn.NewConv2D(p, cIn, cOut, 3, cfg))
	stage.Add(nn.BatchNorm2D(p, cOut, nn.DefaultBatchNormConfig()))
	stage.AddFn(nn.NewFunc(func(t ts.Tensor) ts.Tensor {
		return t.MustRelu(false)
	}))

	return stage
}
// layer builds a residual unit: a conv-bn "pre" stage followed by a 2x2
// max-pool, then two conv-bn blocks whose output is added back to the
// pooled activation (a skip connection).
func layer(p nn.Path, cIn, cOut int64) (retVal nn.FuncT) {
	pre := convBn(p.Sub("pre"), cIn, cOut)
	b1 := convBn(p.Sub("b1"), cOut, cOut)
	b2 := convBn(p.Sub("b2"), cOut, cOut)

	return nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
		preOut := xs.ApplyT(pre, train)
		// The `true` flag deletes preOut once the pooled tensor is built.
		pooled := preOut.MaxPool2DDefault(2, true)
		mid := pooled.ApplyT(b1, train)
		branch := mid.ApplyT(b2, train)
		mid.MustDrop()
		// MustAdd with del=true consumes pooled; branch is dropped manually.
		sum := pooled.MustAdd(branch, true)
		branch.MustDrop()
		return sum
	})
}
// fastResnet assembles the CIFAR-10 "fast resnet": a conv-bn stem, two
// residual layers with an intermediate conv-bn stage, global pooling,
// a 512->10 linear classifier, and a final 0.125 output scaling.
func fastResnet(p nn.Path) (retVal nn.SequentialT) {
	net := nn.SeqT()
	net.Add(convBn(p.Sub("pre"), 3, 64))
	net.Add(layer(p.Sub("layer1"), 64, 128))
	net.Add(convBn(p.Sub("inter"), 128, 256))
	net.AddFn(nn.NewFunc(func(t ts.Tensor) ts.Tensor {
		return t.MaxPool2DDefault(2, false)
	}))
	net.Add(layer(p.Sub("layer2"), 256, 512))
	net.AddFn(nn.NewFunc(func(t ts.Tensor) ts.Tensor {
		// Pool down to 1x1 spatially, then flatten to (batch, 512).
		pooled := t.MaxPool2DDefault(4, false)
		flat := pooled.FlatView()
		pooled.MustDrop()
		return flat
	}))
	net.Add(nn.NewLinear(p.Sub("linear"), 512, 10, nn.DefaultLinearConfig()))
	net.AddFn(nn.NewFunc(func(t ts.Tensor) ts.Tensor {
		// Scale logits by 1/8, matching the fast-resnet recipe.
		return t.MustMul1(ts.FloatScalar(0.125), false)
	}))
	return net
}
// learningRate returns the SGD learning rate for the given epoch:
// 0.1 before epoch 50, 0.01 before epoch 100, and 0.001 afterwards.
func learningRate(epoch int) (retVal float64) {
	if epoch < 50 {
		return 0.1
	}
	if epoch < 100 {
		return 0.01
	}
	return 0.001
}
// main trains the fastResnet model on the CIFAR-10 dataset for 150 epochs
// using SGD with momentum and a stepped learning-rate schedule, then
// evaluates accuracy on the held-out test set.
func main() {
// Dataset binaries are expected in this directory (see the file header).
dir := "../../data/cifar10"
ds := vision.CFLoadDir(dir)
fmt.Printf("TrainImages shape: %v\n", ds.TrainImages.MustSize())
fmt.Printf("TrainLabel shape: %v\n", ds.TrainLabels.MustSize())
fmt.Printf("TestImages shape: %v\n", ds.TestImages.MustSize())
fmt.Printf("TestLabel shape: %v\n", ds.TestLabels.MustSize())
fmt.Printf("Number of labels: %v\n", ds.Labels)
// Prefer GPU 0 when CUDA is available; otherwise this falls back to CPU.
cuda := gotch.CudaBuilder(0)
device := cuda.CudaIfAvailable()
// device := gotch.CPU
vs := nn.NewVarStore(device)
net := fastResnet(vs.Root())
// SGD config: momentum 0.9, dampening 0.0, weight decay 5e-4, nesterov on.
optConfig := nn.NewSGDConfig(0.9, 0.0, 5e-4, true)
// The 0.0 learning rate here is a placeholder; the effective rate is set
// at the top of every epoch via opt.SetLR(learningRate(epoch)).
opt, err := optConfig.Build(vs, 0.0)
if err != nil {
log.Fatal(err)
}
var lossVal float64
startTime := time.Now()
for epoch := 0; epoch < 150; epoch++ {
opt.SetLR(learningRate(epoch))
// Mini-batches of 64, reshuffled each epoch and moved to the device.
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
iter.Shuffle()
iter = iter.ToDevice(device)
for {
item, ok := iter.Next()
if !ok {
break
}
// bimages := vision.Augmentation(item.Data, true, 4, 8)
// logits := net.ForwardT(bimages, true)
// NOTE(review): the trailing `true` presumably deletes the source tensor
// after the device copy — confirm against MustTo's del semantics.
bImages := item.Data.MustTo(vs.Device(), true)
bLabels := item.Label.MustTo(vs.Device(), true)
// // logits := net.ForwardT(item.Data, true)
logits := net.ForwardT(bImages, true)
// // loss := logits.CrossEntropyForLogits(item.Label)
loss := logits.CrossEntropyForLogits(bLabels)
opt.BackwardStep(loss)
lossVal = loss.Values()[0]
// logits.MustDrop()
// Explicitly release per-batch tensors to keep device memory bounded.
bImages.MustDrop()
bLabels.MustDrop()
loss.MustDrop()
}
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, lossVal)
}
// Final evaluation on the test set, in batches of 512.
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
fmt.Printf("Loss: \t %.2f\t Accuracy: %.2f\n", lossVal, testAcc*100)
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
}

View File

@ -624,3 +624,21 @@ func AtgConstantPadNd(ptr *Ctensor, self Ctensor, padData []int64, padLen int) {
// AtgSigmoid computes the element-wise sigmoid of self via the libtorch C
// wrapper; the handle of the result tensor is written through ptr.
func AtgSigmoid(ptr *Ctensor, self Ctensor) {
C.atg_sigmoid(ptr, self)
}
// void atg_flip(tensor *, tensor self, int64_t *dims_data, int dims_len);

// AtgFlip reverses self along the dimensions listed in dimsData; the handle
// of the result tensor is written through ptr. dimsData must be non-empty.
func AtgFlip(ptr *Ctensor, self Ctensor, dimsData []int64, dimsLen int) {
	cdimsDataPtr := (*C.int64_t)(unsafe.Pointer(&dimsData[0]))
	// Convert by value, not by reinterpreting memory: the original
	// *(*C.int)(unsafe.Pointer(&dimsLen)) read only the first 4 bytes of an
	// (typically 8-byte) Go int, which yields garbage on big-endian targets.
	cdimsLen := C.int(dimsLen)
	C.atg_flip(ptr, self, cdimsDataPtr, cdimsLen)
}
// void atg_reflection_pad2d(tensor *, tensor self, int64_t *padding_data, int padding_len);

// AtgReflectionPad2d applies 2-D reflection padding (left, right, top,
// bottom) to self; the handle of the result tensor is written through ptr.
// paddingData must be non-empty.
func AtgReflectionPad2d(ptr *Ctensor, self Ctensor, paddingData []int64, paddingLen int) {
	cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
	// Convert by value, not by reinterpreting memory: the original
	// *(*C.int)(unsafe.Pointer(&paddingLen)) read only the first 4 bytes of
	// an (typically 8-byte) Go int, which breaks on big-endian targets.
	cpaddingLen := C.int(paddingLen)
	C.atg_reflection_pad2d(ptr, self, cpaddingDataPtr, cpaddingLen)
}

View File

@ -145,12 +145,21 @@ func (s SequentialT) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
}
// forward sequentially
var currTs ts.Tensor = xs
outs := make([]ts.Tensor, len(s.layers))
for i := 0; i < len(s.layers); i++ {
currTs = s.layers[i].ForwardT(currTs, train)
if i == 0 {
outs[0] = s.layers[i].ForwardT(xs, train)
defer outs[0].MustDrop()
} else if i == len(s.layers)-1 {
return s.layers[i].ForwardT(outs[i-1], train)
} else {
outs[i] = s.layers[i].ForwardT(outs[i-1], train)
defer outs[i].MustDrop()
}
}
return currTs
return
}
// Add appends a layer after all the current layers.

View File

@ -123,9 +123,14 @@ func (it *Iter2) Next() (item Iter2Item, ok bool) {
// Indexing
narrowIndex := NewNarrow(start, start+size)
// data := it.xs.Idx(narrowIndex).MustTo(it.device, false)
// label := it.ys.Idx(narrowIndex).MustTo(it.device, false)
return Iter2Item{
Data: it.xs.Idx(narrowIndex),
Label: it.ys.Idx(narrowIndex),
// Data: data,
// Label: label,
}, true
}
}

View File

@ -1920,3 +1920,49 @@ func (ts Tensor) MustSigmoid(del bool) (retVal Tensor) {
return retVal
}
// Flip reverses the order of the tensor's elements along the given
// dimensions, returning a new tensor or an error reported by libtorch.
func (ts Tensor) Flip(dims []int64) (retVal Tensor, err error) {
	resultPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	lib.AtgFlip(resultPtr, ts.ctensor, dims, len(dims))
	if err = TorchErr(); err != nil {
		return retVal, err
	}
	return Tensor{ctensor: *resultPtr}, nil
}
// MustFlip is like Flip but terminates the program (log.Fatal) on error.
func (ts Tensor) MustFlip(dims []int64) (retVal Tensor) {
	flipped, err := ts.Flip(dims)
	if err != nil {
		log.Fatal(err)
	}
	return flipped
}
// ReflectionPad2d pads the tensor using reflection of its borders. The
// padding values are (left, right, top, bottom). It returns a new tensor
// or an error reported by libtorch.
func (ts Tensor) ReflectionPad2d(paddingData []int64) (retVal Tensor, err error) {
	resultPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	lib.AtgReflectionPad2d(resultPtr, ts.ctensor, paddingData, len(paddingData))
	if err = TorchErr(); err != nil {
		return retVal, err
	}
	return Tensor{ctensor: *resultPtr}, nil
}
// MustReflectionPad2d is like ReflectionPad2d but terminates the program
// (log.Fatal) on error.
func (ts Tensor) MustReflectionPad2d(paddingData []int64) (retVal Tensor) {
	padded, err := ts.ReflectionPad2d(paddingData)
	if err != nil {
		log.Fatal(err)
	}
	return padded
}

View File

@ -46,7 +46,7 @@ func readFile(filename string) (imagesTs ts.Tensor, labelsTs ts.Tensor) {
}
images := ts.MustZeros([]int64{samplesPerFile, cfC, cfH, cfW}, gotch.Float.CInt(), gotch.CPU.CInt())
labels := ts.MustZeros([]int64{samplesPerFile}, gotch.Float.CInt(), gotch.CPU.CInt())
labels := ts.MustZeros([]int64{samplesPerFile}, gotch.Int64.CInt(), gotch.CPU.CInt())
for idx := 0; idx < int(samplesPerFile); idx++ {
contentOffset := int(bytesPerImage) * idx

View File

@ -3,6 +3,10 @@ package vision
// A simple dataset structure shared by various computer vision datasets.
import (
"log"
"math/rand"
"time"
ts "github.com/sugarme/gotch/tensor"
)
@ -27,3 +31,140 @@ func (ds Dataset) TrainIter(batchSize int64) (retVal ts.Iter2) {
func (ds Dataset) TestIter(batchSize int64) (retVal ts.Iter2) {
return ts.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)
}
// RandomFlip randomly applies horizontal flips.
// This expects a 4 dimension NCHW tensor and returns a tensor with
// an identical shape; each image is flipped independently with
// probability 0.5.
func RandomFlip(t ts.Tensor) (retVal ts.Tensor) {
	size := t.MustSize()
	if len(size) != 4 {
		log.Fatalf("Unexpected shape for tensor %v\n", size)
	}

	output, err := t.ZerosLike(false)
	if err != nil {
		// Consistent with the sibling augmentation helpers (was panic).
		log.Fatal(err)
	}

	for batchIdx := 0; batchIdx < int(size[0]); batchIdx++ {
		outputView := output.Idx(ts.NewSelect(int64(batchIdx)))
		tView := t.Idx(ts.NewSelect(int64(batchIdx)))

		var src ts.Tensor
		// BUG FIX: rand.Float64() returns a value in [0.0, 1.0), so the
		// original comparison `== 1.0` was never true and every single image
		// was flipped. Flip with probability 0.5 instead.
		if rand.Float64() < 0.5 {
			src = tView
		} else {
			// For the CHW slice selected above, dim 2 is the width axis,
			// i.e. this is a horizontal flip.
			src = tView.MustFlip([]int64{2})
		}

		outputView.Copy_(src)
	}

	return output
}
// RandomCrop pads the image using reflections and takes some random crops.
// This expects a 4 dimension NCHW tensor and returns a tensor with
// an identical shape.
func RandomCrop(t ts.Tensor, pad int64) (retVal ts.Tensor) {
	size := t.MustSize()
	if len(size) < 4 {
		log.Fatalf("Unexpected shape (%v) for tensor %v\n", size, t)
	}
	szH := size[2]
	szW := size[3]

	// Reflection-pad by `pad` on every side; each crop below is cut from
	// this (H+2*pad) x (W+2*pad) padded image.
	padded := t.MustReflectionPad2d([]int64{pad, pad, pad, pad})
	output, err := t.ZerosLike(false)
	if err != nil {
		log.Fatal(err)
	}

	// BUG FIX: seed once, not per iteration. Reseeding with the wall clock
	// inside the loop can hand out identical offsets for every image when
	// iterations run faster than the clock resolution.
	rand.Seed(time.Now().UnixNano())

	for bidx := 0; bidx < int(size[0]); bidx++ {
		outputView := output.Idx(ts.NewSelect(int64(bidx)))

		// NOTE(review): Intn(2*pad) samples offsets in [0, 2*pad-1]; the
		// bottom/right-most crop (offset 2*pad) is never selected. Confirm
		// whether the inclusive range Intn(2*pad+1) was intended.
		startW := rand.Intn(int(2 * pad))
		startH := rand.Intn(int(2 * pad))

		// NOTE(review): the -1 select on the channel axis follows the
		// original code — presumably a library convention for "all
		// channels"; verify against the Idx implementation.
		srcIdx := []ts.TensorIndexer{
			ts.NewSelect(int64(bidx)),
			ts.NewSelect(int64(-1)),
			ts.NewNarrow(int64(startH), int64(startH)+szH),
			ts.NewNarrow(int64(startW), int64(startW)+szW),
		}
		src := padded.Idx(srcIdx)
		outputView.Copy_(src)
	}

	// The padded intermediate is no longer needed once all crops are copied;
	// drop it to avoid leaking a full batch-sized tensor.
	padded.MustDrop()

	return output
}
// RandomCutout applies cutout: randomly remove some square areas in the original images.
// https://arxiv.org/abs/1708.04552
// Expects a 4 dimension NCHW tensor; the zeroed square is sz x sz and must
// fit inside the image.
func RandomCutout(t ts.Tensor, sz int64) (retVal ts.Tensor) {
	size := t.MustSize()
	if len(size) != 4 || sz > size[2] || sz > size[3] {
		log.Fatalf("Unexpected shape (%v) for tensor %v\n", size, t)
	}
	output, err := t.ZerosLike(false)
	if err != nil {
		log.Fatal(err)
	}
	output.Copy_(t)

	// BUG FIX: seed once, not per iteration — reseeding with the wall clock
	// inside the loop can repeat the same offsets across images.
	rand.Seed(time.Now().UnixNano())

	for bidx := 0; bidx < int(size[0]); bidx++ {
		startH := rand.Intn(int(size[2] - sz + 1))
		startW := rand.Intn(int(size[3] - sz + 1))

		srcIdx := []ts.TensorIndexer{
			ts.NewSelect(int64(bidx)),
			ts.NewSelect(int64(-1)),
			ts.NewNarrow(int64(startH), int64(startH)+sz),
			ts.NewNarrow(int64(startW), int64(startW)+sz),
		}
		// BUG FIX: the original discarded the result of output.Idx(srcIdx)
		// and then called Fill_ on the whole tensor, zeroing every image
		// entirely. Fill only the selected square view instead.
		// NOTE(review): this assumes Idx returns a view into output so that
		// Fill_ mutates it in place — confirm against the Idx implementation.
		view := output.Idx(srcIdx)
		view.Fill_(ts.FloatScalar(0.0))
	}

	return output
}
// Augmentation applies the requested augmentations to a batch of images:
// an optional random horizontal flip, a random crop with `crop` pixels of
// reflection padding when crop > 0, and a cutout of size `cutout` when
// cutout > 0. The input tensor itself is not modified.
func Augmentation(t ts.Tensor, flip bool, crop int64, cutout int64) (retVal ts.Tensor) {
	out := t.MustShallowClone()
	if flip {
		out = RandomFlip(out)
	}
	if crop > 0 {
		out = RandomCrop(out, crop)
	}
	if cutout > 0 {
		out = RandomCutout(out, cutout)
	}
	return out
}