feat(example/mnist): conv

This commit is contained in:
sugarme 2020-06-24 01:37:33 +10:00
parent 3e08ff3a41
commit 31a3f0e587
4 changed files with 147 additions and 17 deletions

View File

@ -73,12 +73,13 @@ func (n Net) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
} }
func runCNN() { func runCNN1() {
var ds vision.Dataset var ds vision.Dataset
ds = vision.LoadMNISTDir(MnistDirNN) ds = vision.LoadMNISTDir(MnistDirNN)
// testImages := ds.TestImages testImages := ds.TestImages
// testLabels := ds.TestLabels testLabels := ds.TestLabels
cuda := gotch.CudaBuilder(0) cuda := gotch.CudaBuilder(0)
vs := nn.NewVarStore(cuda.CudaIfAvailable()) vs := nn.NewVarStore(cuda.CudaIfAvailable())
// vs := nn.NewVarStore(gotch.CPU) // vs := nn.NewVarStore(gotch.CPU)
@ -95,9 +96,9 @@ func runCNN() {
totalSize := ds.TrainImages.MustSize()[0] totalSize := ds.TrainImages.MustSize()[0]
samples := int(totalSize) samples := int(totalSize)
// index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU) index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
// imagesTs := ds.TrainImages.MustIndexSelect(0, index, false) imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
// labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false) labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
batches := samples / batchSize batches := samples / batchSize
batchIndex := 0 batchIndex := 0
@ -114,10 +115,10 @@ func runCNN() {
// Indexing // Indexing
narrowIndex := ts.NewNarrow(int64(start), int64(start+size)) narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
bImages := ds.TrainImages.Idx(narrowIndex) // bImages := ds.TrainImages.Idx(narrowIndex)
bLabels := ds.TrainLabels.Idx(narrowIndex) // bLabels := ds.TrainLabels.Idx(narrowIndex)
// bImages := imagesTs.Idx(narrowIndex) bImages := imagesTs.Idx(narrowIndex)
// bLabels := labelsTs.Idx(narrowIndex) bLabels := labelsTs.Idx(narrowIndex)
bImages = bImages.MustTo(vs.Device(), true) bImages = bImages.MustTo(vs.Device(), true)
bLabels = bLabels.MustTo(vs.Device(), true) bLabels = bLabels.MustTo(vs.Device(), true)
@ -138,11 +139,69 @@ func runCNN() {
loss.MustDrop() loss.MustDrop()
} }
// testAccuracy := ts.BatchAccuracyForLogits(net, testImages, testLabels, vs.Device(), 1024) // testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
// fmt.Printf("Epoch: %v \t Test accuracy: %.2f%%\n", epoch, testAccuracy*100) // fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Values()[0], testAccuracy*100)
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, epocLoss.Values()[0]) fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, epocLoss.Values()[0])
epocLoss.MustDrop() epocLoss.MustDrop()
imagesTs.MustDrop()
labelsTs.MustDrop()
}
testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
fmt.Printf("Test accuracy: %.2f%%\n", testAccuracy*100)
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
}
func runCNN2() {
var ds vision.Dataset
ds = vision.LoadMNISTDir(MnistDirNN)
cuda := gotch.CudaBuilder(0)
vs := nn.NewVarStore(cuda.CudaIfAvailable())
path := vs.Root()
net := newNet(&path)
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
if err != nil {
log.Fatal(err)
}
startTime := time.Now()
var lossVal float64
for epoch := 0; epoch < epochsCNN; epoch++ {
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchCNN)
// iter.Shuffle()
for {
item, ok := iter.Next()
if !ok {
break
}
bImages := item.Data.MustTo(vs.Device(), true)
bLabels := item.Label.MustTo(vs.Device(), true)
_ = ts.MustGradSetEnabled(true)
logits := net.ForwardT(bImages, true)
loss := logits.CrossEntropyForLogits(bLabels)
opt.BackwardStep(loss)
lossVal = loss.Values()[0]
bImages.MustDrop()
bLabels.MustDrop()
loss.MustDrop()
}
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\t Accuracy: %.2f\n", epoch, lossVal, testAcc*100)
} }
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes()) fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())

View File

@ -21,7 +21,7 @@ func main() {
case "nn": case "nn":
runNN() runNN()
case "cnn": case "cnn":
runCNN() runCNN2()
default: default:
panic("No specified model to run") panic("No specified model to run")
} }

View File

@ -81,12 +81,12 @@ func MustNewIter2(xs, ys Tensor, batchSize int64) (retVal Iter2) {
// //
// The iterator would still run over the whole dataset but the order in // The iterator would still run over the whole dataset but the order in
// which elements are grouped in mini-batches is randomized. // which elements are grouped in mini-batches is randomized.
func (it Iter2) Shuffle() (retVal Iter2) { func (it *Iter2) Shuffle() {
index := MustRandperm(it.totalSize, gotch.Int64, gotch.CPU) index := MustRandperm(it.totalSize, gotch.Int64, gotch.CPU)
it.xs = it.xs.MustIndexSelect(0, index, true) it.xs = it.xs.MustIndexSelect(0, index, true)
it.ys = it.ys.MustIndexSelect(0, index, true) it.ys = it.ys.MustIndexSelect(0, index, true)
return it
} }
// ToDevice transfers the mini-batches to a specified device. // ToDevice transfers the mini-batches to a specified device.

View File

@ -72,15 +72,86 @@ func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize
break break
} }
acc := m.ForwardT(item.Data.MustTo(d, true), false).AccuracyForLogits(item.Label.MustTo(d, true)).MustView([]int64{-1}, false).MustFloat64Value([]int64{0})
size := float64(item.Data.MustSize()[0]) size := float64(item.Data.MustSize()[0])
sumAccuracy += acc * size bImages := item.Data.MustTo(d, true)
bLabels := item.Label.MustTo(d, true)
logits := m.ForwardT(bImages, false)
acc := logits.AccuracyForLogits(bLabels)
sumAccuracy += acc.Values()[0] * size
sampleCount += size sampleCount += size
bImages.MustDrop()
bLabels.MustDrop()
acc.MustDrop()
} }
return sumAccuracy / sampleCount return sumAccuracy / sampleCount
} }
// BatchAccuracyForLogitsIdx is an alternative of BatchAccuracyForLogits to
// calculate accuracy for specified batch on module weight. It uses tensor
// indexing instead of Iter2.
func BatchAccuracyForLogitsIdx(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize int) (retVal float64) {
	var (
		sumAccuracy float64 = 0.0
		sampleCount float64 = 0.0
	)

	// Switch Grad off so evaluation does not build an autograd graph.
	_ = NewNoGradGuard()

	totalSize := xs.MustSize()[0]
	samples := int(totalSize)

	// Shuffle once via a random permutation; evaluation order does not
	// matter, but this mirrors the training-side indexing path.
	index := MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
	imagesTs := xs.MustIndexSelect(0, index, false)
	labelsTs := ys.MustIndexSelect(0, index, false)
	// FIX: the permutation tensor itself was never freed (C-memory leak).
	index.MustDrop()

	batches := samples / batchSize
	batchIndex := 0
	for i := 0; i < batches; i++ {
		start := batchIndex * batchSize
		size := batchSize
		if samples-start < batchSize {
			// Trailing partial batch is skipped, consistent with the
			// training loop's handling.
			break
		}
		batchIndex += 1

		// Narrow-slice this mini-batch out of the shuffled tensors.
		narrowIndex := NewNarrow(int64(start), int64(start+size))
		bImages := imagesTs.Idx(narrowIndex)
		bLabels := labelsTs.Idx(narrowIndex)

		bImages = bImages.MustTo(d, true)
		bLabels = bLabels.MustTo(d, true)

		// FIX: evaluate in inference mode (train=false). The original
		// passed true, which enables dropout/batch-norm training
		// behavior and skews the reported accuracy; the sibling
		// BatchAccuracyForLogits correctly passes false.
		logits := m.ForwardT(bImages, false)
		bAccuracy := logits.AccuracyForLogits(bLabels)

		accuVal := bAccuracy.Values()[0]
		// FIX: weight each batch by its own size, not by the whole
		// dataset size (xs.MustSize()[0]). With equal-sized batches the
		// old weights cancelled out, but they were semantically wrong
		// and would break if partial batches were ever processed.
		bSamples := float64(size)
		sumAccuracy += accuVal * bSamples
		sampleCount += bSamples

		// Free up tensors on C memory.
		// NOTE(review): logits is not dropped here — presumably
		// AccuracyForLogits consumes it; confirm against the ts API.
		bImages.MustDrop()
		bLabels.MustDrop()
		bAccuracy.MustDrop()
	}

	imagesTs.MustDrop()
	labelsTs.MustDrop()

	return sumAccuracy / sampleCount
}
// Tensor methods for Module and ModuleT: // Tensor methods for Module and ModuleT:
// ====================================== // ======================================