fix(nn/module): ForwardT; feat(example/cifar)
This commit is contained in:
parent
f33ca7edf1
commit
8107d429b6
162
example/cifar/main.go
Normal file
162
example/cifar/main.go
Normal file
|
@ -0,0 +1,162 @@
|
|||
package main
|
||||
|
||||
// Training various models on the CIFAR-10 dataset.
|
||||
//
|
||||
// The dataset can be downloaded from https://www.cs.toronto.edu/~kriz/cifar.html, files
|
||||
// should be placed in the data/ directory.
|
||||
//
|
||||
// The resnet model reaches 95.4% accuracy.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
func convBn(p nn.Path, cIn, cOut int64) (retVal nn.SequentialT) {
|
||||
config := nn.DefaultConv2DConfig()
|
||||
config.Padding = []int64{1, 1}
|
||||
config.Bias = false
|
||||
|
||||
seq := nn.SeqT()
|
||||
|
||||
seq.Add(nn.NewConv2D(p, cIn, cOut, 3, config))
|
||||
seq.Add(nn.BatchNorm2D(p, cOut, nn.DefaultBatchNormConfig()))
|
||||
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||
return xs.MustRelu(false)
|
||||
}))
|
||||
|
||||
return seq
|
||||
}
|
||||
|
||||
func layer(p nn.Path, cIn, cOut int64) (retVal nn.FuncT) {
|
||||
pre := convBn(p.Sub("pre"), cIn, cOut)
|
||||
|
||||
block1 := convBn(p.Sub("b1"), cOut, cOut)
|
||||
block2 := convBn(p.Sub("b2"), cOut, cOut)
|
||||
|
||||
return nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
||||
tmp1 := xs.ApplyT(pre, train)
|
||||
preTs := tmp1.MaxPool2DDefault(2, true)
|
||||
tmp2 := preTs.ApplyT(block1, train)
|
||||
ys := tmp2.ApplyT(block2, train)
|
||||
tmp2.MustDrop()
|
||||
|
||||
res := preTs.MustAdd(ys, true)
|
||||
ys.MustDrop()
|
||||
|
||||
return res
|
||||
})
|
||||
}
|
||||
|
||||
func fastResnet(p nn.Path) (retVal nn.SequentialT) {
|
||||
seq := nn.SeqT()
|
||||
|
||||
seq.Add(convBn(p.Sub("pre"), 3, 64))
|
||||
seq.Add(layer(p.Sub("layer1"), 64, 128))
|
||||
seq.Add(convBn(p.Sub("inter"), 128, 256))
|
||||
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||
return xs.MaxPool2DDefault(2, false)
|
||||
}))
|
||||
seq.Add(layer(p.Sub("layer2"), 256, 512))
|
||||
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||
tmp := xs.MaxPool2DDefault(4, false)
|
||||
res := tmp.FlatView()
|
||||
tmp.MustDrop()
|
||||
|
||||
return res
|
||||
}))
|
||||
|
||||
seq.Add(nn.NewLinear(p.Sub("linear"), 512, 10, nn.DefaultLinearConfig()))
|
||||
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||
return xs.MustMul1(ts.FloatScalar(0.125), false)
|
||||
}))
|
||||
|
||||
return seq
|
||||
}
|
||||
|
||||
// learningRate returns the step-decayed learning rate for the given epoch:
// 0.1 for epochs < 50, 0.01 for epochs < 100, and 0.001 afterwards.
func learningRate(epoch int) (retVal float64) {
	if epoch < 50 {
		return 0.1
	}
	if epoch < 100 {
		return 0.01
	}
	return 0.001
}
|
||||
|
||||
// main trains the fast-resnet model on CIFAR-10 for 150 epochs using SGD
// with momentum and a step-decayed learning rate, then reports test-set
// accuracy and total wall-clock time.
func main() {
	// Dataset binary files are expected under ../../data/cifar10.
	dir := "../../data/cifar10"
	ds := vision.CFLoadDir(dir)

	fmt.Printf("TrainImages shape: %v\n", ds.TrainImages.MustSize())
	fmt.Printf("TrainLabel shape: %v\n", ds.TrainLabels.MustSize())
	fmt.Printf("TestImages shape: %v\n", ds.TestImages.MustSize())
	fmt.Printf("TestLabel shape: %v\n", ds.TestLabels.MustSize())
	fmt.Printf("Number of labels: %v\n", ds.Labels)

	// Prefer CUDA device 0 when available; otherwise fall back to CPU.
	cuda := gotch.CudaBuilder(0)
	device := cuda.CudaIfAvailable()
	// device := gotch.CPU

	vs := nn.NewVarStore(device)

	net := fastResnet(vs.Root())

	// SGD: momentum 0.9, dampening 0.0, weight decay 5e-4, Nesterov enabled.
	// The initial learning rate (0.0) is overwritten by SetLR every epoch.
	optConfig := nn.NewSGDConfig(0.9, 0.0, 5e-4, true)
	opt, err := optConfig.Build(vs, 0.0)
	if err != nil {
		log.Fatal(err)
	}

	var lossVal float64
	startTime := time.Now()

	for epoch := 0; epoch < 150; epoch++ {
		opt.SetLR(learningRate(epoch))

		// Fresh shuffled mini-batch iterator (batch size 64) each epoch.
		iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
		iter.Shuffle()
		iter = iter.ToDevice(device)

		for {
			item, ok := iter.Next()
			if !ok {
				break
			}

			// bimages := vision.Augmentation(item.Data, true, 4, 8)
			// logits := net.ForwardT(bimages, true)

			// Move the batch to the var-store device; del=true frees the
			// source tensors after the copy.
			bImages := item.Data.MustTo(vs.Device(), true)
			bLabels := item.Label.MustTo(vs.Device(), true)

			// // logits := net.ForwardT(item.Data, true)
			logits := net.ForwardT(bImages, true)
			// // loss := logits.CrossEntropyForLogits(item.Label)
			loss := logits.CrossEntropyForLogits(bLabels)
			opt.BackwardStep(loss)

			lossVal = loss.Values()[0]

			// Explicitly release per-batch tensors so memory does not grow
			// across iterations.
			// logits.MustDrop()
			bImages.MustDrop()
			bLabels.MustDrop()
			loss.MustDrop()

		}

		fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, lossVal)

	}

	// Evaluate on the test set in batches of 512.
	testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
	fmt.Printf("Loss: \t %.2f\t Accuracy: %.2f\n", lossVal, testAcc*100)
	fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
}
|
|
@ -624,3 +624,21 @@ func AtgConstantPadNd(ptr *Ctensor, self Ctensor, padData []int64, padLen int) {
|
|||
// AtgSigmoid calls the C function atg_sigmoid; the resulting tensor handle
// is written through ptr.
func AtgSigmoid(ptr *Ctensor, self Ctensor) {
	C.atg_sigmoid(ptr, self)
}
|
||||
|
||||
// void atg_flip(tensor *, tensor self, int64_t *dims_data, int dims_len);
|
||||
func AtgFlip(ptr *Ctensor, self Ctensor, dimsData []int64, dimsLen int) {
|
||||
|
||||
cdimsDataPtr := (*C.int64_t)(unsafe.Pointer(&dimsData[0]))
|
||||
cdimsLen := *(*C.int)(unsafe.Pointer(&dimsLen))
|
||||
|
||||
C.atg_flip(ptr, self, cdimsDataPtr, cdimsLen)
|
||||
}
|
||||
|
||||
// void atg_reflection_pad2d(tensor *, tensor self, int64_t *padding_data, int padding_len);
|
||||
func AtgReflectionPad2d(ptr *Ctensor, self Ctensor, paddingData []int64, paddingLen int) {
|
||||
|
||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
||||
|
||||
C.atg_reflection_pad2d(ptr, self, cpaddingDataPtr, cpaddingLen)
|
||||
}
|
||||
|
|
|
@ -145,12 +145,21 @@ func (s SequentialT) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
|||
}
|
||||
|
||||
// forward sequentially
|
||||
var currTs ts.Tensor = xs
|
||||
outs := make([]ts.Tensor, len(s.layers))
|
||||
for i := 0; i < len(s.layers); i++ {
|
||||
currTs = s.layers[i].ForwardT(currTs, train)
|
||||
if i == 0 {
|
||||
outs[0] = s.layers[i].ForwardT(xs, train)
|
||||
defer outs[0].MustDrop()
|
||||
} else if i == len(s.layers)-1 {
|
||||
return s.layers[i].ForwardT(outs[i-1], train)
|
||||
} else {
|
||||
outs[i] = s.layers[i].ForwardT(outs[i-1], train)
|
||||
defer outs[i].MustDrop()
|
||||
}
|
||||
}
|
||||
|
||||
return currTs
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
// Add appends a layer after all the current layers.
|
||||
|
|
|
@ -123,9 +123,14 @@ func (it *Iter2) Next() (item Iter2Item, ok bool) {
|
|||
// Indexing
|
||||
narrowIndex := NewNarrow(start, start+size)
|
||||
|
||||
// data := it.xs.Idx(narrowIndex).MustTo(it.device, false)
|
||||
// label := it.ys.Idx(narrowIndex).MustTo(it.device, false)
|
||||
|
||||
return Iter2Item{
|
||||
Data: it.xs.Idx(narrowIndex),
|
||||
Label: it.ys.Idx(narrowIndex),
|
||||
// Data: data,
|
||||
// Label: label,
|
||||
}, true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1920,3 +1920,49 @@ func (ts Tensor) MustSigmoid(del bool) (retVal Tensor) {
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Flip returns a new tensor with elements reversed along the given
// dimensions (wraps the C function atg_flip via lib.AtgFlip).
func (ts Tensor) Flip(dims []int64) (retVal Tensor, err error) {
	// NOTE(review): C.malloc(0) as the result slot follows the pattern used
	// throughout this package — the C side writes the new tensor handle
	// through this pointer; confirm it is freed by the tensor lifecycle.
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	lib.AtgFlip(ptr, ts.ctensor, dims, len(dims))
	// Check the thread-local Torch error state set by the C call.
	err = TorchErr()
	if err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}

	return retVal, nil

}
|
||||
|
||||
func (ts Tensor) MustFlip(dims []int64) (retVal Tensor) {
|
||||
retVal, err := ts.Flip(dims)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// ReflectionPad2d returns a new tensor padded using reflection of the input
// boundary (wraps the C function atg_reflection_pad2d). paddingData holds
// the per-side padding sizes.
func (ts Tensor) ReflectionPad2d(paddingData []int64) (retVal Tensor, err error) {
	// NOTE(review): C.malloc(0) as the result slot follows the pattern used
	// throughout this package; the C side writes the new tensor handle here.
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	lib.AtgReflectionPad2d(ptr, ts.ctensor, paddingData, len(paddingData))
	// Check the thread-local Torch error state set by the C call.
	err = TorchErr()
	if err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}

	return retVal, nil

}
|
||||
|
||||
func (ts Tensor) MustReflectionPad2d(paddingData []int64) (retVal Tensor) {
|
||||
retVal, err := ts.ReflectionPad2d(paddingData)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ func readFile(filename string) (imagesTs ts.Tensor, labelsTs ts.Tensor) {
|
|||
}
|
||||
|
||||
images := ts.MustZeros([]int64{samplesPerFile, cfC, cfH, cfW}, gotch.Float.CInt(), gotch.CPU.CInt())
|
||||
labels := ts.MustZeros([]int64{samplesPerFile}, gotch.Float.CInt(), gotch.CPU.CInt())
|
||||
labels := ts.MustZeros([]int64{samplesPerFile}, gotch.Int64.CInt(), gotch.CPU.CInt())
|
||||
|
||||
for idx := 0; idx < int(samplesPerFile); idx++ {
|
||||
contentOffset := int(bytesPerImage) * idx
|
||||
|
|
|
@ -3,6 +3,10 @@ package vision
|
|||
// A simple dataset structure shared by various computer vision datasets.
|
||||
|
||||
import (
|
||||
"log"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
|
@ -27,3 +31,140 @@ func (ds Dataset) TrainIter(batchSize int64) (retVal ts.Iter2) {
|
|||
// TestIter returns a batched iterator over the test images and labels.
func (ds Dataset) TestIter(batchSize int64) (retVal ts.Iter2) {
	return ts.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)
}
|
||||
|
||||
// RandomFlip randomly applies horizontal flips
|
||||
// This expects a 4 dimension NCHW tensor and returns a tensor with
|
||||
// an identical shape.
|
||||
func RandomFlip(t ts.Tensor) (retVal ts.Tensor) {
|
||||
|
||||
size := t.MustSize()
|
||||
|
||||
if len(size) != 4 {
|
||||
log.Fatalf("Unexpected shape for tensor %v\n", size)
|
||||
}
|
||||
|
||||
output, err := t.ZerosLike(false)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for batchIdx := 0; batchIdx < int(size[0]); batchIdx++ {
|
||||
outputView := output.Idx(ts.NewSelect(int64(batchIdx)))
|
||||
tView := t.Idx(ts.NewSelect(int64(batchIdx)))
|
||||
|
||||
var src ts.Tensor
|
||||
if rand.Float64() == 1.0 {
|
||||
src = tView
|
||||
} else {
|
||||
src = tView.MustFlip([]int64{2})
|
||||
}
|
||||
|
||||
outputView.Copy_(src)
|
||||
}
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
// Pad the image using reflections and take some random crops.
// This expects a 4 dimension NCHW tensor and returns a tensor with
// an identical shape.
func RandomCrop(t ts.Tensor, pad int64) (retVal ts.Tensor) {

	size := t.MustSize()

	if len(size) < 4 {
		log.Fatalf("Unexpected shape (%v) for tensor %v\n", size, t)
	}

	szH := size[2] // original image height
	szW := size[3] // original image width
	// Reflection-pad by `pad` pixels on every side of H and W.
	// NOTE(review): `padded` is never dropped — looks like a tensor leak;
	// confirm against this package's memory-management conventions.
	padded := t.MustReflectionPad2d([]int64{pad, pad, pad, pad})
	output, err := t.ZerosLike(false)
	if err != nil {
		log.Fatal(err)
	}

	for bidx := 0; bidx < int(size[0]); bidx++ {
		idx := ts.NewSelect(int64(bidx))
		outputView := output.Idx(idx)

		// NOTE(review): re-seeding the global RNG on every iteration is
		// unusual — seeding once at program start is the normal pattern, and
		// iterations within the same nanosecond would repeat values.
		rand.Seed(time.Now().UnixNano())
		// Random crop origin inside the padded border.
		// NOTE(review): Intn(2*pad) never yields 2*pad, so the bottom/right
		// -most crop position is unreachable — confirm whether
		// Intn(2*pad + 1) was intended.
		startW := rand.Intn(int(2 * pad))
		startH := rand.Intn(int(2 * pad))

		var srcIdx []ts.TensorIndexer
		nIdx := ts.NewSelect(int64(bidx))
		// NOTE(review): NewSelect(-1) on the channel axis — presumably this
		// keeps/selects channels under gotch's indexing semantics; verify,
		// since a literal "select index -1" would pick only the last channel.
		cIdx := ts.NewSelect(int64(-1))
		hIdx := ts.NewNarrow(int64(startH), int64(startH)+szH)
		wIdx := ts.NewNarrow(int64(startW), int64(startW)+szW)
		srcIdx = append(srcIdx, nIdx, cIdx, hIdx, wIdx)
		src := padded.Idx(srcIdx)
		outputView.Copy_(src)
	}

	return output
}
|
||||
|
||||
// Applies cutout: randomly remove some square areas in the original images.
|
||||
// https://arxiv.org/abs/1708.04552
|
||||
func RandomCutout(t ts.Tensor, sz int64) (retVal ts.Tensor) {
|
||||
|
||||
size := t.MustSize()
|
||||
|
||||
if len(size) != 4 || sz > size[2] || sz > size[3] {
|
||||
log.Fatalf("Unexpected shape (%v) for tensor %v\n", size, t)
|
||||
}
|
||||
|
||||
output, err := t.ZerosLike(false)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
output.Copy_(t)
|
||||
|
||||
for bidx := 0; bidx < int(size[0]); bidx++ {
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
startH := rand.Intn(int(size[2] - sz + 1))
|
||||
startW := rand.Intn(int(size[3] - sz + 1))
|
||||
|
||||
var srcIdx []ts.TensorIndexer
|
||||
nIdx := ts.NewSelect(int64(bidx))
|
||||
cIdx := ts.NewSelect(int64(-1))
|
||||
hIdx := ts.NewNarrow(int64(startH), int64(startH)+sz)
|
||||
wIdx := ts.NewNarrow(int64(startW), int64(startW)+sz)
|
||||
srcIdx = append(srcIdx, nIdx, cIdx, hIdx, wIdx)
|
||||
|
||||
output.Idx(srcIdx)
|
||||
output.Fill_(ts.FloatScalar(0.0))
|
||||
}
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
func Augmentation(t ts.Tensor, flip bool, crop int64, cutout int64) (retVal ts.Tensor) {
|
||||
|
||||
tclone := t.MustShallowClone()
|
||||
|
||||
var flipTs ts.Tensor
|
||||
if flip {
|
||||
flipTs = RandomFlip(tclone)
|
||||
} else {
|
||||
flipTs = tclone
|
||||
}
|
||||
|
||||
var cropTs ts.Tensor
|
||||
if crop > 0 {
|
||||
cropTs = RandomCrop(flipTs, crop)
|
||||
} else {
|
||||
cropTs = flipTs
|
||||
}
|
||||
|
||||
if cutout > 0 {
|
||||
retVal = RandomCutout(cropTs, cutout)
|
||||
} else {
|
||||
retVal = cropTs
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user