feat(example/mnist): conv

sugarme 2020-06-24 01:37:33 +10:00
parent 3e08ff3a41
commit 31a3f0e587
4 changed files with 147 additions and 17 deletions

View File

@@ -73,12 +73,13 @@ func (n Net) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
}
func runCNN() {
func runCNN1() {
var ds vision.Dataset
ds = vision.LoadMNISTDir(MnistDirNN)
// testImages := ds.TestImages
// testLabels := ds.TestLabels
testImages := ds.TestImages
testLabels := ds.TestLabels
cuda := gotch.CudaBuilder(0)
vs := nn.NewVarStore(cuda.CudaIfAvailable())
// vs := nn.NewVarStore(gotch.CPU)
@@ -95,9 +96,9 @@ func runCNN() {
totalSize := ds.TrainImages.MustSize()[0]
samples := int(totalSize)
// index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
// imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
// labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
batches := samples / batchSize
batchIndex := 0
@@ -114,10 +115,10 @@ func runCNN() {
// Indexing
narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
bImages := ds.TrainImages.Idx(narrowIndex)
bLabels := ds.TrainLabels.Idx(narrowIndex)
// bImages := imagesTs.Idx(narrowIndex)
// bLabels := labelsTs.Idx(narrowIndex)
// bImages := ds.TrainImages.Idx(narrowIndex)
// bLabels := ds.TrainLabels.Idx(narrowIndex)
bImages := imagesTs.Idx(narrowIndex)
bLabels := labelsTs.Idx(narrowIndex)
bImages = bImages.MustTo(vs.Device(), true)
bLabels = bLabels.MustTo(vs.Device(), true)
@@ -138,11 +139,69 @@ func runCNN() {
loss.MustDrop()
}
// testAccuracy := ts.BatchAccuracyForLogits(net, testImages, testLabels, vs.Device(), 1024)
// fmt.Printf("Epoch: %v \t Test accuracy: %.2f%%\n", epoch, testAccuracy*100)
// testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
// fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Values()[0], testAccuracy*100)
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, epocLoss.Values()[0])
epocLoss.MustDrop()
imagesTs.MustDrop()
labelsTs.MustDrop()
}
testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
fmt.Printf("Test accuracy: %.2f%%\n", testAccuracy*100)
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
}
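The new runCNN1 above swaps direct slicing of ds.TrainImages for slicing a copy that has been shuffled once per epoch. Stripped of the training details, the shuffle-then-narrow batching pattern looks roughly like the sketch below; the helper name trainBatchesByIndex and the step callback are illustrative only, while the gotch calls are the ones appearing in the diff.

// Sketch of the shuffle-then-narrow batching pattern used by runCNN1
// (not part of the commit; trainBatchesByIndex and step are hypothetical).
func trainBatchesByIndex(ds vision.Dataset, batchSize int, step func(bImages, bLabels ts.Tensor)) {
	totalSize := ds.TrainImages.MustSize()[0]
	// Shuffle once per call by gathering rows in random order.
	index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
	imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
	labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)

	for start := 0; start+batchSize <= int(totalSize); start += batchSize {
		// Narrow selects rows [start, start+batchSize) along dimension 0.
		narrowIndex := ts.NewNarrow(int64(start), int64(start+batchSize))
		bImages := imagesTs.Idx(narrowIndex)
		bLabels := labelsTs.Idx(narrowIndex)
		step(bImages, bLabels) // forward pass, loss and optimizer step go here
		bImages.MustDrop()
		bLabels.MustDrop()
	}
	// Free the shuffled copies held in C memory.
	imagesTs.MustDrop()
	labelsTs.MustDrop()
}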
func runCNN2() {
var ds vision.Dataset
ds = vision.LoadMNISTDir(MnistDirNN)
cuda := gotch.CudaBuilder(0)
vs := nn.NewVarStore(cuda.CudaIfAvailable())
path := vs.Root()
net := newNet(&path)
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
if err != nil {
log.Fatal(err)
}
startTime := time.Now()
var lossVal float64
for epoch := 0; epoch < epochsCNN; epoch++ {
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchCNN)
// iter.Shuffle()
for {
item, ok := iter.Next()
if !ok {
break
}
bImages := item.Data.MustTo(vs.Device(), true)
bLabels := item.Label.MustTo(vs.Device(), true)
_ = ts.MustGradSetEnabled(true)
logits := net.ForwardT(bImages, true)
loss := logits.CrossEntropyForLogits(bLabels)
opt.BackwardStep(loss)
lossVal = loss.Values()[0]
bImages.MustDrop()
bLabels.MustDrop()
loss.MustDrop()
}
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\t Accuracy: %.2f\n", epoch, lossVal, testAcc*100)
}
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())

View File

@@ -21,7 +21,7 @@ func main() {
case "nn":
runNN()
case "cnn":
runCNN()
runCNN2()
default:
panic("No specified model to run")
}

View File

@@ -81,12 +81,12 @@ func MustNewIter2(xs, ys Tensor, batchSize int64) (retVal Iter2) {
//
// The iterator would still run over the whole dataset but the order in
// which elements are grouped in mini-batches is randomized.
func (it Iter2) Shuffle() (retVal Iter2) {
func (it *Iter2) Shuffle() {
index := MustRandperm(it.totalSize, gotch.Int64, gotch.CPU)
it.xs = it.xs.MustIndexSelect(0, index, true)
it.ys = it.ys.MustIndexSelect(0, index, true)
return it
}
// ToDevice transfers the mini-batches to a specified device.
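With the receiver change above, Shuffle now mutates the iterator in place instead of returning a new Iter2. A minimal caller-side sketch, assuming the same ds and batchCNN as in the mnist example (the corresponding call in runCNN2 is still commented out):

// Sketch: shuffling an Iter2 before iterating (caller side, not part of the commit).
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchCNN)
// iter is addressable, so Go takes its address for the pointer receiver.
iter.Shuffle() // reorders xs/ys in place via MustRandperm + MustIndexSelect

for {
	item, ok := iter.Next()
	if !ok {
		break
	}
	// item.Data / item.Label now arrive in randomized mini-batch order.
	item.Data.MustDrop()
	item.Label.MustDrop()
}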

View File

@@ -72,15 +72,86 @@ func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize
break
}
acc := m.ForwardT(item.Data.MustTo(d, true), false).AccuracyForLogits(item.Label.MustTo(d, true)).MustView([]int64{-1}, false).MustFloat64Value([]int64{0})
size := float64(item.Data.MustSize()[0])
sumAccuracy += acc * size
bImages := item.Data.MustTo(d, true)
bLabels := item.Label.MustTo(d, true)
logits := m.ForwardT(bImages, false)
acc := logits.AccuracyForLogits(bLabels)
sumAccuracy += acc.Values()[0] * size
sampleCount += size
bImages.MustDrop()
bLabels.MustDrop()
acc.MustDrop()
}
return sumAccuracy / sampleCount
}
// BatchAccuracyForLogitsIdx is an alternative to BatchAccuracyForLogits for
// calculating batched accuracy with the module weights. It uses tensor
// indexing instead of Iter2.
func BatchAccuracyForLogitsIdx(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize int) (retVal float64) {
var (
sumAccuracy float64 = 0.0
sampleCount float64 = 0.0
)
// Switch Grad off
_ = NewNoGradGuard()
totalSize := xs.MustSize()[0]
samples := int(totalSize)
index := MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
imagesTs := xs.MustIndexSelect(0, index, false)
labelsTs := ys.MustIndexSelect(0, index, false)
batches := samples / batchSize
batchIndex := 0
for i := 0; i < batches; i++ {
start := batchIndex * batchSize
size := batchSize
if samples-start < batchSize {
// size = samples - start
break
}
batchIndex += 1
// Indexing
narrowIndex := NewNarrow(int64(start), int64(start+size))
bImages := imagesTs.Idx(narrowIndex)
bLabels := labelsTs.Idx(narrowIndex)
bImages = bImages.MustTo(d, true)
bLabels = bLabels.MustTo(d, true)
logits := m.ForwardT(bImages, true)
bAccuracy := logits.AccuracyForLogits(bLabels)
accuVal := bAccuracy.Values()[0]
bSamples := float64(xs.MustSize()[0])
sumAccuracy += accuVal * bSamples
sampleCount += bSamples
// Free up tensors on C memory
bImages.MustDrop()
bLabels.MustDrop()
// logits.MustDrop()
bAccuracy.MustDrop()
}
imagesTs.MustDrop()
labelsTs.MustDrop()
// Switch Grad on
// _ = MustGradSetEnabled(true)
return sumAccuracy / sampleCount
}
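On the caller side, the new helper drops in where the commented-out lines in runCNN1 suggest. A short sketch, assuming the net, testImages, testLabels and vs variables from that example (not part of this file):

// Evaluating test accuracy with the new indexing-based helper,
// mirroring the call at the end of runCNN1 above.
testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
fmt.Printf("Test accuracy: %.2f%%\n", testAccuracy*100)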
// Tensor methods for Module and ModuleT:
// ======================================