updated
This commit is contained in:
parent
49cdc52bd8
commit
f4f65963ce
89
example/augmentation/main.go
Normal file
89
example/augmentation/main.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
package main
|
||||
|
||||
// Training various models on the CIFAR-10 dataset.
|
||||
//
|
||||
// The dataset can be downloaded from https:www.cs.toronto.edu/~kriz/cifar.html, files
|
||||
// should be placed in the data/ directory.
|
||||
//
|
||||
// The resnet model reaches 95.4% accuracy.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
// "log"
|
||||
// "os/exec"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
func main() {
|
||||
dir := "../../data/cifar10"
|
||||
ds := vision.CFLoadDir(dir)
|
||||
|
||||
fmt.Printf("TrainImages shape: %v\n", ds.TrainImages.MustSize())
|
||||
fmt.Printf("TrainLabel shape: %v\n", ds.TrainLabels.MustSize())
|
||||
fmt.Printf("TestImages shape: %v\n", ds.TestImages.MustSize())
|
||||
fmt.Printf("TestLabel shape: %v\n", ds.TestLabels.MustSize())
|
||||
fmt.Printf("Number of labels: %v\n", ds.Labels)
|
||||
|
||||
// cuda := gotch.CudaBuilder(0)
|
||||
// device := cuda.CudaIfAvailable()
|
||||
device := gotch.CPU
|
||||
|
||||
var si *gotch.SI
|
||||
si = gotch.GetSysInfo()
|
||||
fmt.Printf("Total RAM (MB):\t %8.2f\n", float64(si.TotalRam)/1024)
|
||||
fmt.Printf("Used RAM (MB):\t %8.2f\n", float64(si.TotalRam-si.FreeRam)/1024)
|
||||
|
||||
startRAM := si.TotalRam - si.FreeRam
|
||||
|
||||
vs := nn.NewVarStore(device)
|
||||
|
||||
for epoch := 0; epoch < 150; epoch++ {
|
||||
|
||||
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
|
||||
iter.Shuffle()
|
||||
|
||||
for {
|
||||
item, ok := iter.Next()
|
||||
if !ok {
|
||||
item.Data.MustDrop()
|
||||
item.Label.MustDrop()
|
||||
break
|
||||
}
|
||||
|
||||
devicedData := item.Data.MustTo(vs.Device(), true)
|
||||
devicedLabel := item.Label.MustTo(vs.Device(), true)
|
||||
// bimages := vision.Augmentation(devicedData, true, 4, 8)
|
||||
// NOTE: memory blow-up at augmentation/RandomCutout
|
||||
bimages := vision.Augmentation(devicedData, true, 4, 0)
|
||||
|
||||
devicedData.MustDrop()
|
||||
devicedLabel.MustDrop()
|
||||
bimages.MustDrop()
|
||||
|
||||
}
|
||||
|
||||
iter.Drop()
|
||||
|
||||
si = gotch.GetSysInfo()
|
||||
memUsed := (float64(si.TotalRam-si.FreeRam) - float64(startRAM)) / 1024
|
||||
fmt.Printf("Epoch:\t %v\t Memory Used:\t [%8.2f MiB]\n", epoch, memUsed)
|
||||
|
||||
/*
|
||||
* // Print out GPU used
|
||||
* nvidia := "nvidia-smi"
|
||||
* cmd := exec.Command(nvidia)
|
||||
* stdout, err := cmd.Output()
|
||||
*
|
||||
* if err != nil {
|
||||
* log.Fatal(err.Error())
|
||||
* }
|
||||
*
|
||||
* fmt.Println(string(stdout))
|
||||
* */
|
||||
}
|
||||
|
||||
}
|
|
@ -83,11 +83,11 @@ func fastResnet(p nn.Path) (retVal nn.SequentialT) {
|
|||
func learningRate(epoch int) (retVal float64) {
|
||||
switch {
|
||||
case epoch < 50:
|
||||
return float64(0.1)
|
||||
return 0.1
|
||||
case epoch < 100:
|
||||
return float64(0.01)
|
||||
return 0.01
|
||||
default:
|
||||
return float64(0.001)
|
||||
return 0.001
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -110,7 +110,7 @@ func main() {
|
|||
net := fastResnet(vs.Root())
|
||||
|
||||
optConfig := nn.NewSGDConfig(0.9, 0.0, 5e-4, true)
|
||||
opt, err := optConfig.Build(vs, 0.1)
|
||||
opt, err := optConfig.Build(vs, 0.01)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -154,9 +154,9 @@ func main() {
|
|||
loss.MustDrop()
|
||||
}
|
||||
|
||||
// testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
// fmt.Printf("Epoch:\t %v\t Memory Used:\t [%8.2f MiB]\tLoss: \t %.3f \tAcc: %10.2f%%\n", epoch, memUsed, lossVal, testAcc*100.0)
|
||||
fmt.Printf("Epoch: %10.0d\tLoss:%10.3f\n", epoch, lossVal)
|
||||
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
fmt.Printf("Epoch:\t %v\t Loss: \t %.3f \tAcc: %10.2f%%\n", epoch, lossVal, testAcc*100.0)
|
||||
// fmt.Printf("Epoch: %10.0d\tLoss:%10.3f\n", epoch, lossVal)
|
||||
iter.Drop()
|
||||
|
||||
/*
|
||||
|
|
27
example/tensor-index1/main.go
Normal file
27
example/tensor-index1/main.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
// "github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func main() {
|
||||
data := [][]int64{
|
||||
{1, 1, 1, 2, 2, 2, 3, 3},
|
||||
{1, 1, 1, 2, 2, 2, 4, 4},
|
||||
}
|
||||
// shape := []int64{2, 8}
|
||||
shape := []int64{2, 2, 4}
|
||||
|
||||
t, err := ts.NewTensorFromData(data, shape)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
t.Print()
|
||||
|
||||
idx := ts.NewNarrow(0, 3)
|
||||
|
||||
selTs := t.Idx(idx)
|
||||
selTs.Print()
|
||||
}
|
|
@ -251,8 +251,6 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
|||
// `spec` is a function type implements `TensorIndexer`
|
||||
for _, spec := range indexSpec {
|
||||
|
||||
// fmt.Printf("spec type: %v\n", reflect.TypeOf(spec).Name())
|
||||
|
||||
switch reflect.TypeOf(spec).Name() {
|
||||
case "InsertNewAxis":
|
||||
nextTensor, err = currTensor.Unsqueeze(currIdx, true)
|
||||
|
@ -299,7 +297,6 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
|||
}
|
||||
|
||||
retVal = currTensor
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package tensor
|
|||
|
||||
import (
|
||||
// "unsafe"
|
||||
"log"
|
||||
|
||||
lib "github.com/sugarme/gotch/libtch"
|
||||
)
|
||||
|
@ -63,3 +64,11 @@ func (sc Scalar) Drop() (err error) {
|
|||
lib.AtsFree(sc.cscalar)
|
||||
return TorchErr()
|
||||
}
|
||||
|
||||
func (sc Scalar) MustDrop() {
|
||||
lib.AtsFree(sc.cscalar)
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package vision
|
|||
// A simple dataset structure shared by various computer vision datasets.
|
||||
|
||||
import (
|
||||
// "fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
@ -122,6 +123,11 @@ func RandomCutout(t ts.Tensor, sz int64) (retVal ts.Tensor) {
|
|||
log.Fatalf("Unexpected shape (%v) for tensor %v\n", size, t)
|
||||
}
|
||||
|
||||
// output, err := t.ShallowClone()
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
|
||||
output, err := t.ZerosLike(false)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -137,14 +143,17 @@ func RandomCutout(t ts.Tensor, sz int64) (retVal ts.Tensor) {
|
|||
|
||||
var srcIdx []ts.TensorIndexer
|
||||
nIdx := ts.NewSelect(int64(bidx))
|
||||
cIdx := ts.NewSelect(int64(-1))
|
||||
cIdx := ts.NewNarrow(0, size[1])
|
||||
hIdx := ts.NewNarrow(int64(startH), int64(startH)+sz)
|
||||
wIdx := ts.NewNarrow(int64(startW), int64(startW)+sz)
|
||||
srcIdx = append(srcIdx, nIdx, cIdx, hIdx, wIdx)
|
||||
|
||||
outputView := output.Idx(srcIdx)
|
||||
outputView.Fill_(ts.FloatScalar(0.0))
|
||||
outputView.MustDrop()
|
||||
// TODO: there's memory blow-up here. Need to fix.
|
||||
// view := output.Idx(srcIdx)
|
||||
// zeroSc := ts.FloatScalar(0.0)
|
||||
// view.Fill_(zeroSc)
|
||||
// zeroSc.MustDrop()
|
||||
// view.MustDrop()
|
||||
}
|
||||
|
||||
return output
|
||||
|
@ -157,28 +166,25 @@ func Augmentation(t ts.Tensor, flip bool, crop int64, cutout int64) (retVal ts.T
|
|||
var flipTs ts.Tensor
|
||||
if flip {
|
||||
flipTs = RandomFlip(tclone)
|
||||
tclone.MustDrop()
|
||||
} else {
|
||||
flipTs = tclone
|
||||
}
|
||||
|
||||
tclone.MustDrop()
|
||||
|
||||
var cropTs ts.Tensor
|
||||
if crop > 0 {
|
||||
cropTs = RandomCrop(flipTs, crop)
|
||||
flipTs.MustDrop()
|
||||
} else {
|
||||
cropTs = flipTs
|
||||
}
|
||||
|
||||
flipTs.MustDrop()
|
||||
return cropTs
|
||||
if cutout > 0 {
|
||||
retVal = RandomCutout(cropTs, cutout)
|
||||
cropTs.MustDrop()
|
||||
} else {
|
||||
retVal = cropTs
|
||||
}
|
||||
|
||||
// if cutout > 0 {
|
||||
// retVal = RandomCutout(cropTs, cutout)
|
||||
// } else {
|
||||
// retVal = cropTs
|
||||
// }
|
||||
//
|
||||
// cropTs.MustDrop()
|
||||
// return retVal
|
||||
return retVal
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user