This commit is contained in:
sugarme 2020-07-09 18:24:37 +10:00
parent 49cdc52bd8
commit f4f65963ce
6 changed files with 154 additions and 26 deletions

View File

@ -0,0 +1,89 @@
package main
// Training various models on the CIFAR-10 dataset.
//
// The dataset can be downloaded from https://www.cs.toronto.edu/~kriz/cifar.html, files
// should be placed in the data/ directory.
//
// The resnet model reaches 95.4% accuracy.
import (
"fmt"
// "log"
// "os/exec"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
"github.com/sugarme/gotch/vision"
)
// main loads the CIFAR-10 dataset, then repeatedly iterates over the
// training set applying data augmentation, reporting host-RAM growth
// per epoch. It exists to detect C-side tensor memory leaks: every
// tensor produced in the loop is dropped explicitly, so "Memory Used"
// should stay flat across epochs.
func main() {
	dir := "../../data/cifar10"
	ds := vision.CFLoadDir(dir)

	fmt.Printf("TrainImages shape: %v\n", ds.TrainImages.MustSize())
	fmt.Printf("TrainLabel shape: %v\n", ds.TrainLabels.MustSize())
	fmt.Printf("TestImages shape: %v\n", ds.TestImages.MustSize())
	fmt.Printf("TestLabel shape: %v\n", ds.TestLabels.MustSize())
	fmt.Printf("Number of labels: %v\n", ds.Labels)

	// cuda := gotch.CudaBuilder(0)
	// device := cuda.CudaIfAvailable()
	device := gotch.CPU

	// Record baseline RAM usage so per-epoch growth (a leak signal)
	// can be reported relative to it below.
	si := gotch.GetSysInfo()
	fmt.Printf("Total RAM (MB):\t %8.2f\n", float64(si.TotalRam)/1024)
	fmt.Printf("Used RAM (MB):\t %8.2f\n", float64(si.TotalRam-si.FreeRam)/1024)
	startRAM := si.TotalRam - si.FreeRam

	vs := nn.NewVarStore(device)

	for epoch := 0; epoch < 150; epoch++ {
		iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
		iter.Shuffle()

		for {
			item, ok := iter.Next()
			if !ok {
				// The final (sentinel) item still owns tensors; free them.
				item.Data.MustDrop()
				item.Label.MustDrop()
				break
			}

			devicedData := item.Data.MustTo(vs.Device(), true)
			devicedLabel := item.Label.MustTo(vs.Device(), true)
			// bimages := vision.Augmentation(devicedData, true, 4, 8)
			// NOTE: memory blow-up at augmentation/RandomCutout,
			// so cutout is disabled (last arg = 0) for now.
			bimages := vision.Augmentation(devicedData, true, 4, 0)

			// Drop every tensor created this step to release C memory.
			devicedData.MustDrop()
			devicedLabel.MustDrop()
			bimages.MustDrop()
		}
		iter.Drop()

		si = gotch.GetSysInfo()
		memUsed := (float64(si.TotalRam-si.FreeRam) - float64(startRAM)) / 1024
		fmt.Printf("Epoch:\t %v\t Memory Used:\t [%8.2f MiB]\n", epoch, memUsed)
		/*
		 * // Print out GPU used
		 * nvidia := "nvidia-smi"
		 * cmd := exec.Command(nvidia)
		 * stdout, err := cmd.Output()
		 *
		 * if err != nil {
		 * 	log.Fatal(err.Error())
		 * }
		 *
		 * fmt.Println(string(stdout))
		 * */
	}
}

View File

@ -83,11 +83,11 @@ func fastResnet(p nn.Path) (retVal nn.SequentialT) {
// learningRate returns the step-decay learning-rate schedule for the
// given training epoch: 0.1 for epochs [0, 50), 0.01 for [50, 100),
// and 0.001 from epoch 100 onwards.
//
// The rendered diff left both the old `return float64(0.1)` lines and
// the new untyped-constant lines in the text, producing duplicate
// unreachable returns; only the new form is kept here. The float64
// conversion was redundant anyway — an untyped constant converts to
// the declared result type automatically.
func learningRate(epoch int) (retVal float64) {
	switch {
	case epoch < 50:
		return 0.1
	case epoch < 100:
		return 0.01
	default:
		return 0.001
	}
}
@ -110,7 +110,7 @@ func main() {
net := fastResnet(vs.Root())
optConfig := nn.NewSGDConfig(0.9, 0.0, 5e-4, true)
opt, err := optConfig.Build(vs, 0.1)
opt, err := optConfig.Build(vs, 0.01)
if err != nil {
log.Fatal(err)
}
@ -154,9 +154,9 @@ func main() {
loss.MustDrop()
}
// testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
// fmt.Printf("Epoch:\t %v\t Memory Used:\t [%8.2f MiB]\tLoss: \t %.3f \tAcc: %10.2f%%\n", epoch, memUsed, lossVal, testAcc*100.0)
fmt.Printf("Epoch: %10.0d\tLoss:%10.3f\n", epoch, lossVal)
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
fmt.Printf("Epoch:\t %v\t Loss: \t %.3f \tAcc: %10.2f%%\n", epoch, lossVal, testAcc*100.0)
// fmt.Printf("Epoch: %10.0d\tLoss:%10.3f\n", epoch, lossVal)
iter.Drop()
/*

View File

@ -0,0 +1,27 @@
package main
import (
// "github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
// main builds a small 2x2x4 int64 tensor from nested Go slices,
// prints it, then narrows dimension 0 to the range [0, 3) and prints
// the resulting view.
func main() {
	values := [][]int64{
		{1, 1, 1, 2, 2, 2, 3, 3},
		{1, 1, 1, 2, 2, 2, 4, 4},
	}

	// shape := []int64{2, 8}
	dims := []int64{2, 2, 4}

	tensor, err := ts.NewTensorFromData(values, dims)
	if err != nil {
		panic(err)
	}
	tensor.Print()

	narrowed := tensor.Idx(ts.NewNarrow(0, 3))
	narrowed.Print()
}

View File

@ -251,8 +251,6 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
// `spec` is a function type implements `TensorIndexer`
for _, spec := range indexSpec {
// fmt.Printf("spec type: %v\n", reflect.TypeOf(spec).Name())
switch reflect.TypeOf(spec).Name() {
case "InsertNewAxis":
nextTensor, err = currTensor.Unsqueeze(currIdx, true)
@ -299,7 +297,6 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
}
retVal = currTensor
return retVal, nil
}

View File

@ -2,6 +2,7 @@ package tensor
import (
// "unsafe"
"log"
lib "github.com/sugarme/gotch/libtch"
)
@ -63,3 +64,11 @@ func (sc Scalar) Drop() (err error) {
lib.AtsFree(sc.cscalar)
return TorchErr()
}
// MustDrop frees the C memory backing the scalar and terminates the
// program via log.Fatal if the libtorch API reports an error.
func (sc Scalar) MustDrop() {
	lib.AtsFree(sc.cscalar)

	err := TorchErr()
	if err != nil {
		log.Fatal(err)
	}
}

View File

@ -3,6 +3,7 @@ package vision
// A simple dataset structure shared by various computer vision datasets.
import (
// "fmt"
"log"
"math/rand"
"time"
@ -122,6 +123,11 @@ func RandomCutout(t ts.Tensor, sz int64) (retVal ts.Tensor) {
log.Fatalf("Unexpected shape (%v) for tensor %v\n", size, t)
}
// output, err := t.ShallowClone()
// if err != nil {
// log.Fatal(err)
// }
output, err := t.ZerosLike(false)
if err != nil {
log.Fatal(err)
@ -137,14 +143,17 @@ func RandomCutout(t ts.Tensor, sz int64) (retVal ts.Tensor) {
var srcIdx []ts.TensorIndexer
nIdx := ts.NewSelect(int64(bidx))
cIdx := ts.NewSelect(int64(-1))
cIdx := ts.NewNarrow(0, size[1])
hIdx := ts.NewNarrow(int64(startH), int64(startH)+sz)
wIdx := ts.NewNarrow(int64(startW), int64(startW)+sz)
srcIdx = append(srcIdx, nIdx, cIdx, hIdx, wIdx)
outputView := output.Idx(srcIdx)
outputView.Fill_(ts.FloatScalar(0.0))
outputView.MustDrop()
// TODO: there's memory blow-up here. Need to fix.
// view := output.Idx(srcIdx)
// zeroSc := ts.FloatScalar(0.0)
// view.Fill_(zeroSc)
// zeroSc.MustDrop()
// view.MustDrop()
}
return output
@ -157,28 +166,25 @@ func Augmentation(t ts.Tensor, flip bool, crop int64, cutout int64) (retVal ts.T
var flipTs ts.Tensor
if flip {
flipTs = RandomFlip(tclone)
tclone.MustDrop()
} else {
flipTs = tclone
}
tclone.MustDrop()
var cropTs ts.Tensor
if crop > 0 {
cropTs = RandomCrop(flipTs, crop)
flipTs.MustDrop()
} else {
cropTs = flipTs
}
flipTs.MustDrop()
return cropTs
if cutout > 0 {
retVal = RandomCutout(cropTs, cutout)
cropTs.MustDrop()
} else {
retVal = cropTs
}
// if cutout > 0 {
// retVal = RandomCutout(cropTs, cutout)
// } else {
// retVal = cropTs
// }
//
// cropTs.MustDrop()
// return retVal
return retVal
}