feat(example/mnist): conv
parent 3e08ff3a41
commit 31a3f0e587
@@ -73,12 +73,13 @@ func (n Net) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
 }
 
-func runCNN() {
+func runCNN1() {
 
 	var ds vision.Dataset
 	ds = vision.LoadMNISTDir(MnistDirNN)
-	// testImages := ds.TestImages
-	// testLabels := ds.TestLabels
+	testImages := ds.TestImages
+	testLabels := ds.TestLabels
 
 	cuda := gotch.CudaBuilder(0)
 	vs := nn.NewVarStore(cuda.CudaIfAvailable())
 	// vs := nn.NewVarStore(gotch.CPU)
@@ -95,9 +96,9 @@ func runCNN() {
 
 		totalSize := ds.TrainImages.MustSize()[0]
 		samples := int(totalSize)
-		// index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
-		// imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
-		// labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
+		index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
+		imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
+		labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
 
 		batches := samples / batchSize
 		batchIndex := 0
@@ -114,10 +115,10 @@ func runCNN() {
 
 			// Indexing
 			narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
-			bImages := ds.TrainImages.Idx(narrowIndex)
-			bLabels := ds.TrainLabels.Idx(narrowIndex)
-			// bImages := imagesTs.Idx(narrowIndex)
-			// bLabels := labelsTs.Idx(narrowIndex)
+			// bImages := ds.TrainImages.Idx(narrowIndex)
+			// bLabels := ds.TrainLabels.Idx(narrowIndex)
+			bImages := imagesTs.Idx(narrowIndex)
+			bLabels := labelsTs.Idx(narrowIndex)
 
 			bImages = bImages.MustTo(vs.Device(), true)
 			bLabels = bLabels.MustTo(vs.Device(), true)
@@ -138,11 +139,69 @@ func runCNN() {
 			loss.MustDrop()
 		}
 
-		// testAccuracy := ts.BatchAccuracyForLogits(net, testImages, testLabels, vs.Device(), 1024)
-		// fmt.Printf("Epoch: %v \t Test accuracy: %.2f%%\n", epoch, testAccuracy*100)
+		// testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
+		// fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Values()[0], testAccuracy*100)
 
 		fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, epocLoss.Values()[0])
 		epocLoss.MustDrop()
+		imagesTs.MustDrop()
+		labelsTs.MustDrop()
 	}
 
+	testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
+	fmt.Printf("Test accuracy: %.2f%%\n", testAccuracy*100)
+
 	fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
 }
+
+func runCNN2() {
+
+	var ds vision.Dataset
+	ds = vision.LoadMNISTDir(MnistDirNN)
+
+	cuda := gotch.CudaBuilder(0)
+	vs := nn.NewVarStore(cuda.CudaIfAvailable())
+	path := vs.Root()
+	net := newNet(&path)
+	opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	startTime := time.Now()
+
+	var lossVal float64
+	for epoch := 0; epoch < epochsCNN; epoch++ {
+
+		iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchCNN)
+		// iter.Shuffle()
+
+		for {
+			item, ok := iter.Next()
+			if !ok {
+				break
+			}
+
+			bImages := item.Data.MustTo(vs.Device(), true)
+			bLabels := item.Label.MustTo(vs.Device(), true)
+
+			_ = ts.MustGradSetEnabled(true)
+
+			logits := net.ForwardT(bImages, true)
+			loss := logits.CrossEntropyForLogits(bLabels)
+
+			opt.BackwardStep(loss)
+
+			lossVal = loss.Values()[0]
+
+			bImages.MustDrop()
+			bLabels.MustDrop()
+			loss.MustDrop()
+		}
+
+		testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
+
+		fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\t Accuracy: %.2f\n", epoch, lossVal, testAcc*100)
+	}
+
+	fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
+}
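The hunks above change runCNN1 from slicing ds.TrainImages/ds.TrainLabels directly to slicing the Randperm-shuffled copies imagesTs/labelsTs, and drop those copies at the end of each epoch. Collected in one place, the per-epoch pattern looks roughly like the fragment below. It is only a sketch: it reuses the ds, vs, net, opt, totalSize, samples and batchSize values already set up in the example, and every call in it appears in the diff above.

	// Shuffle the whole training set once per epoch ...
	index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
	imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
	labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)

	// ... then carve contiguous mini-batches out of the shuffled copies.
	for i := 0; i < samples/batchSize; i++ {
		start := i * batchSize
		narrowIndex := ts.NewNarrow(int64(start), int64(start+batchSize))
		bImages := imagesTs.Idx(narrowIndex).MustTo(vs.Device(), true)
		bLabels := labelsTs.Idx(narrowIndex).MustTo(vs.Device(), true)

		logits := net.ForwardT(bImages, true)
		loss := logits.CrossEntropyForLogits(bLabels)
		opt.BackwardStep(loss)

		// Tensors are allocated on the C side, so free each batch explicitly.
		bImages.MustDrop()
		bLabels.MustDrop()
		loss.MustDrop()
	}

	// Release the shuffled copies before the next epoch reshuffles.
	imagesTs.MustDrop()
	labelsTs.MustDrop()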
@@ -21,7 +21,7 @@ func main() {
 	case "nn":
 		runNN()
 	case "cnn":
-		runCNN()
+		runCNN2()
 	default:
 		panic("No specified model to run")
 	}
@@ -81,12 +81,12 @@ func MustNewIter2(xs, ys Tensor, batchSize int64) (retVal Iter2) {
 //
 // The iterator would still run over the whole dataset but the order in
 // which elements are grouped in mini-batches is randomized.
-func (it Iter2) Shuffle() (retVal Iter2) {
+func (it *Iter2) Shuffle() {
 	index := MustRandperm(it.totalSize, gotch.Int64, gotch.CPU)
 
 	it.xs = it.xs.MustIndexSelect(0, index, true)
 	it.ys = it.ys.MustIndexSelect(0, index, true)
-	return it
+
 }
 
 // ToDevice transfers the mini-batches to a specified device.
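With the pointer receiver, Shuffle now permutes the iterator's tensors in place instead of returning a new Iter2, so callers drop the reassignment. A minimal usage sketch, assuming the same ds, vs and batchCNN values as in runCNN2 above (whose commented-out iter.Shuffle() marks the intended call site); the loop body is abbreviated:

	iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchCNN)
	iter.Shuffle() // in place; the old value-receiver version returned the shuffled Iter2 instead

	for {
		item, ok := iter.Next()
		if !ok {
			break
		}

		bImages := item.Data.MustTo(vs.Device(), true)
		bLabels := item.Label.MustTo(vs.Device(), true)
		// ... forward pass, loss, opt.BackwardStep(loss) ...
		bImages.MustDrop()
		bLabels.MustDrop()
	}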
@@ -72,15 +72,86 @@ func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize
 			break
 		}
 
-		acc := m.ForwardT(item.Data.MustTo(d, true), false).AccuracyForLogits(item.Label.MustTo(d, true)).MustView([]int64{-1}, false).MustFloat64Value([]int64{0})
 		size := float64(item.Data.MustSize()[0])
-		sumAccuracy += acc * size
+		bImages := item.Data.MustTo(d, true)
+		bLabels := item.Label.MustTo(d, true)
+
+		logits := m.ForwardT(bImages, false)
+		acc := logits.AccuracyForLogits(bLabels)
+		sumAccuracy += acc.Values()[0] * size
 		sampleCount += size
+
+		bImages.MustDrop()
+		bLabels.MustDrop()
+		acc.MustDrop()
 	}
 
 	return sumAccuracy / sampleCount
 }
+
+// BatchAccuracyForLogitsIdx is an alternative to BatchAccuracyForLogits that
+// calculates the accuracy of the module weights over batches of the given
+// size. It uses tensor indexing instead of Iter2.
+func BatchAccuracyForLogitsIdx(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize int) (retVal float64) {
+	var (
+		sumAccuracy float64 = 0.0
+		sampleCount float64 = 0.0
+	)
+
+	// Switch Grad off
+	_ = NewNoGradGuard()
+
+	totalSize := xs.MustSize()[0]
+	samples := int(totalSize)
+
+	index := MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
+	imagesTs := xs.MustIndexSelect(0, index, false)
+	labelsTs := ys.MustIndexSelect(0, index, false)
+
+	batches := samples / batchSize
+	batchIndex := 0
+
+	for i := 0; i < batches; i++ {
+		start := batchIndex * batchSize
+		size := batchSize
+		if samples-start < batchSize {
+			// size = samples - start
+			break
+		}
+		batchIndex += 1
+
+		// Indexing
+		narrowIndex := NewNarrow(int64(start), int64(start+size))
+		bImages := imagesTs.Idx(narrowIndex)
+		bLabels := labelsTs.Idx(narrowIndex)
+
+		bImages = bImages.MustTo(d, true)
+		bLabels = bLabels.MustTo(d, true)
+
+		logits := m.ForwardT(bImages, true)
+		bAccuracy := logits.AccuracyForLogits(bLabels)
+
+		accuVal := bAccuracy.Values()[0]
+		bSamples := float64(xs.MustSize()[0])
+		sumAccuracy += accuVal * bSamples
+		sampleCount += bSamples
+
+		// Free up tensors on C memory
+		bImages.MustDrop()
+		bLabels.MustDrop()
+		// logits.MustDrop()
+		bAccuracy.MustDrop()
+	}
+
+	imagesTs.MustDrop()
+	labelsTs.MustDrop()
+
+	// Switch Grad on
+	// _ = MustGradSetEnabled(true)
+
+	return sumAccuracy / sampleCount
+}
+
 // Tensor methods for Module and ModuleT:
 // ======================================
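For reference, the new helper is exactly what runCNN1 above calls once training finishes: it switches gradients off through NewNoGradGuard, shuffles and slices xs/ys where they already live, and moves each mini-batch to the target device d itself, so the caller can keep testImages/testLabels on the host. Both lines below appear verbatim in the runCNN1 hunk:

	testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
	fmt.Printf("Test accuracy: %.2f%%\n", testAccuracy*100)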