feat(example/mnist): conv
This commit is contained in:
parent
b792c6af3c
commit
3e08ff3a41
|
@ -3,6 +3,7 @@ package main
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
"github.com/sugarme/gotch"
|
||||||
"github.com/sugarme/gotch/nn"
|
"github.com/sugarme/gotch/nn"
|
||||||
|
@ -13,8 +14,9 @@ import (
|
||||||
const (
|
const (
|
||||||
MnistDirCNN string = "../../data/mnist"
|
MnistDirCNN string = "../../data/mnist"
|
||||||
|
|
||||||
epochsCNN = 10
|
epochsCNN = 100
|
||||||
batchCNN = 256
|
batchCNN = 256
|
||||||
|
batchSize = 256
|
||||||
|
|
||||||
LrCNN = 1e-4
|
LrCNN = 1e-4
|
||||||
)
|
)
|
||||||
|
@ -39,20 +41,47 @@ func newNet(vs *nn.Path) Net {
|
||||||
*fc2}
|
*fc2}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n Net) ForwardT(xs ts.Tensor, train bool) ts.Tensor {
|
func (n Net) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||||
out := xs.MustView([]int64{-1, 1, 28, 28}).Apply(n.conv1).MaxPool2DDefault(2, true)
|
outView1 := xs.MustView([]int64{-1, 1, 28, 28}, false)
|
||||||
out = out.Apply(n.conv2).MaxPool2DDefault(2, true)
|
defer outView1.MustDrop()
|
||||||
out = out.MustView([]int64{-1, 1024}).Apply(&n.fc1).MustRelu(true)
|
|
||||||
out.Dropout_(0.5, train)
|
outC1 := outView1.Apply(n.conv1)
|
||||||
return out.Apply(&n.fc2)
|
// defer outC1.MustDrop()
|
||||||
|
|
||||||
|
outMP1 := outC1.MaxPool2DDefault(2, true)
|
||||||
|
defer outMP1.MustDrop()
|
||||||
|
|
||||||
|
outC2 := outMP1.Apply(n.conv2)
|
||||||
|
// defer outC2.MustDrop()
|
||||||
|
|
||||||
|
outMP2 := outC2.MaxPool2DDefault(2, true)
|
||||||
|
// defer outMP2.MustDrop()
|
||||||
|
|
||||||
|
outView2 := outMP2.MustView([]int64{-1, 1024}, true)
|
||||||
|
defer outView2.MustDrop()
|
||||||
|
|
||||||
|
outFC1 := outView2.Apply(&n.fc1)
|
||||||
|
// defer outFC1.MustDrop()
|
||||||
|
|
||||||
|
outRelu := outFC1.MustRelu(true)
|
||||||
|
defer outRelu.MustDrop()
|
||||||
|
// outRelu.Dropout_(0.5, train)
|
||||||
|
outDropout := ts.MustDropout(outRelu, 0.5, train)
|
||||||
|
defer outDropout.MustDrop()
|
||||||
|
|
||||||
|
return outDropout.Apply(&n.fc2)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func runCNN() {
|
func runCNN() {
|
||||||
|
|
||||||
var ds vision.Dataset
|
var ds vision.Dataset
|
||||||
ds = vision.LoadMNISTDir(MnistDirNN)
|
ds = vision.LoadMNISTDir(MnistDirNN)
|
||||||
|
// testImages := ds.TestImages
|
||||||
|
// testLabels := ds.TestLabels
|
||||||
cuda := gotch.CudaBuilder(0)
|
cuda := gotch.CudaBuilder(0)
|
||||||
vs := nn.NewVarStore(cuda.CudaIfAvailable())
|
vs := nn.NewVarStore(cuda.CudaIfAvailable())
|
||||||
|
// vs := nn.NewVarStore(gotch.CPU)
|
||||||
path := vs.Root()
|
path := vs.Root()
|
||||||
net := newNet(&path)
|
net := newNet(&path)
|
||||||
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
|
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
|
||||||
|
@ -60,28 +89,61 @@ func runCNN() {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for epoch := 0; epoch < epochsCNN; epoch++ {
|
startTime := time.Now()
|
||||||
var count = 0
|
|
||||||
for {
|
|
||||||
iter := ds.TrainIter(batchCNN).Shuffle()
|
|
||||||
item, ok := iter.Next()
|
|
||||||
if !ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
loss := net.ForwardT(item.Data.MustTo(vs.Device(), true), true).CrossEntropyForLogits(item.Label.MustTo(vs.Device(), true))
|
for epoch := 0; epoch < epochsCNN; epoch++ {
|
||||||
opt.BackwardStep(loss)
|
|
||||||
loss.MustDrop()
|
totalSize := ds.TrainImages.MustSize()[0]
|
||||||
count++
|
samples := int(totalSize)
|
||||||
if count == 50 {
|
// index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
|
||||||
|
// imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
|
||||||
|
// labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
|
||||||
|
|
||||||
|
batches := samples / batchSize
|
||||||
|
batchIndex := 0
|
||||||
|
var epocLoss ts.Tensor
|
||||||
|
// var loss ts.Tensor
|
||||||
|
for i := 0; i < batches; i++ {
|
||||||
|
start := batchIndex * batchSize
|
||||||
|
size := batchSize
|
||||||
|
if samples-start < batchSize {
|
||||||
|
// size = samples - start
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
fmt.Printf("completed \t %v batches\n", count)
|
batchIndex += 1
|
||||||
|
|
||||||
|
// Indexing
|
||||||
|
narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
|
||||||
|
bImages := ds.TrainImages.Idx(narrowIndex)
|
||||||
|
bLabels := ds.TrainLabels.Idx(narrowIndex)
|
||||||
|
// bImages := imagesTs.Idx(narrowIndex)
|
||||||
|
// bLabels := labelsTs.Idx(narrowIndex)
|
||||||
|
|
||||||
|
bImages = bImages.MustTo(vs.Device(), true)
|
||||||
|
bLabels = bLabels.MustTo(vs.Device(), true)
|
||||||
|
|
||||||
|
logits := net.ForwardT(bImages, true)
|
||||||
|
loss := logits.CrossEntropyForLogits(bLabels)
|
||||||
|
|
||||||
|
opt.BackwardStep(loss)
|
||||||
|
|
||||||
|
epocLoss = loss.MustShallowClone()
|
||||||
|
epocLoss.Detach_()
|
||||||
|
|
||||||
|
// fmt.Printf("completed \t %v batches\t %.2f\n", i, loss.Values()[0])
|
||||||
|
|
||||||
|
bImages.MustDrop()
|
||||||
|
bLabels.MustDrop()
|
||||||
|
// logits.MustDrop()
|
||||||
|
loss.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
||||||
// testAccuracy := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 1024)
|
// testAccuracy := ts.BatchAccuracyForLogits(net, testImages, testLabels, vs.Device(), 1024)
|
||||||
//
|
|
||||||
// fmt.Printf("Epoch: %v \t Test accuracy: %.2f%%\n", epoch, testAccuracy*100)
|
// fmt.Printf("Epoch: %v \t Test accuracy: %.2f%%\n", epoch, testAccuracy*100)
|
||||||
|
|
||||||
|
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, epocLoss.Values()[0])
|
||||||
|
epocLoss.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,7 +41,7 @@ func runLinear() {
|
||||||
})
|
})
|
||||||
|
|
||||||
testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true)
|
testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true)
|
||||||
testAccuracy := testLogits.MustArgmax(-1, false, true).MustEq1(ds.TestLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
testAccuracy := testLogits.MustArgmax(-1, false, true).MustEq1(ds.TestLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
|
||||||
|
|
||||||
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100)
|
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100)
|
||||||
|
|
||||||
|
|
|
@ -72,7 +72,7 @@ func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
acc := m.ForwardT(item.Data.MustTo(d, true), false).AccuracyForLogits(item.Label.MustTo(d, true)).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
acc := m.ForwardT(item.Data.MustTo(d, true), false).AccuracyForLogits(item.Label.MustTo(d, true)).MustView([]int64{-1}, false).MustFloat64Value([]int64{0})
|
||||||
size := float64(item.Data.MustSize()[0])
|
size := float64(item.Data.MustSize()[0])
|
||||||
sumAccuracy += acc * size
|
sumAccuracy += acc * size
|
||||||
sampleCount += size
|
sampleCount += size
|
||||||
|
|
|
@ -678,9 +678,11 @@ func (ts Tensor) MustMean(dtype int32, del bool) (retVal Tensor) {
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts Tensor) View(sizeData []int64) (retVal Tensor, err error) {
|
func (ts Tensor) View(sizeData []int64, del bool) (retVal Tensor, err error) {
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
defer C.free(unsafe.Pointer(ptr))
|
if del {
|
||||||
|
defer ts.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
lib.AtgView(ptr, ts.ctensor, sizeData, len(sizeData))
|
lib.AtgView(ptr, ts.ctensor, sizeData, len(sizeData))
|
||||||
if err = TorchErr(); err != nil {
|
if err = TorchErr(); err != nil {
|
||||||
|
@ -692,8 +694,8 @@ func (ts Tensor) View(sizeData []int64) (retVal Tensor, err error) {
|
||||||
return retVal, nil
|
return retVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts Tensor) MustView(sizeData []int64) (retVal Tensor) {
|
func (ts Tensor) MustView(sizeData []int64, del bool) (retVal Tensor) {
|
||||||
retVal, err := ts.View(sizeData)
|
retVal, err := ts.View(sizeData, del)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -993,5 +993,5 @@ func (r Reduction) ToInt() (retVal int) {
|
||||||
func (ts Tensor) Values() []float64 {
|
func (ts Tensor) Values() []float64 {
|
||||||
clone := ts.MustShallowClone()
|
clone := ts.MustShallowClone()
|
||||||
clone.Detach_()
|
clone.Detach_()
|
||||||
return []float64{clone.MustView([]int64{-1}).MustFloat64Value([]int64{-1})}
|
return []float64{clone.MustView([]int64{-1}, true).MustFloat64Value([]int64{-1})}
|
||||||
}
|
}
|
||||||
|
|
|
@ -125,7 +125,7 @@ func readImages(filename string) (retVal ts.Tensor) {
|
||||||
err = fmt.Errorf("create images tensor err.")
|
err = fmt.Errorf("create images tensor err.")
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
retVal = imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}).MustTotype(gotch.Float, true).MustDiv1(ts.FloatScalar(255.0), true)
|
retVal = imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}, true).MustTotype(gotch.Float, true).MustDiv1(ts.FloatScalar(255.0), true)
|
||||||
|
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user