112 lines
3.8 KiB
Go
112 lines
3.8 KiB
Go
package main
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"github.com/sugarme/gotch"
|
|
ts "github.com/sugarme/gotch/tensor"
|
|
"github.com/sugarme/gotch/vision"
|
|
)
|
|
|
|
// Model dimensions, dataset location and training hyper-parameters.
const (
	// ImageDim is the first dimension of the weight matrix built in runLinear,
	// so it has to match the number of columns of the image tensors that are
	// matrix-multiplied with it.
	ImageDim int64 = 784
	// Label is the second dimension of the weight matrix and the length of the
	// bias vector built in runLinear, i.e. the number of logits per sample.
	Label int64 = 10
	// MnistDir is the directory handed to vision.LoadMNISTDir.
	MnistDir string = "../../data/mnist"

	// epochs is the number of passes runLinear makes over the training data.
	// Alternative kept for reference: epochs = 500.
	epochs = 200
	// batchSize is not referenced by any active (non-commented) code visible
	// in this file. Alternative kept for reference: batchSize = 256.
	// NOTE(review): 60000 presumably equals the size of the training set,
	// i.e. a single batch per epoch — confirm.
	batchSize = 60000
)
|
|
|
|
func runLinear() {
|
|
var ds vision.Dataset
|
|
ds = vision.LoadMNISTDir(MnistDir)
|
|
|
|
// fmt.Printf("Train image size: %v\n", ds.TrainImages.MustSize())
|
|
// fmt.Printf("Train label size: %v\n", ds.TrainLabels.MustSize())
|
|
// fmt.Printf("Test image size: %v\n", ds.TestImages.MustSize())
|
|
// fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
|
|
|
|
device := (gotch.CPU).CInt()
|
|
dtype := (gotch.Float).CInt()
|
|
|
|
ws := ts.MustZeros([]int64{ImageDim, Label}, dtype, device).MustSetRequiresGrad(true)
|
|
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
|
|
|
|
for epoch := 0; epoch < epochs; epoch++ {
|
|
/*
|
|
* totalSize := ds.TrainImages.MustSize()[0]
|
|
* samples := int(totalSize)
|
|
* index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
|
|
* imagesTs := ds.TrainImages.MustIndexSelect(0, index)
|
|
* labelsTs := ds.TrainLabels.MustIndexSelect(0, index)
|
|
*
|
|
* batches := samples / batchSize
|
|
* batchIndex := 0
|
|
* for i := 0; i < batches; i++ {
|
|
* start := batchIndex * batchSize
|
|
* size := batchSize
|
|
* if samples-start < batchSize {
|
|
* // size = samples - start
|
|
* break
|
|
* }
|
|
* batchIndex += 1
|
|
*
|
|
* // Indexing
|
|
* narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
|
|
* // bImages := ds.TrainImages.Idx(narrowIndex)
|
|
* // bLabels := ds.TrainLabels.Idx(narrowIndex)
|
|
* bImages := imagesTs.Idx(narrowIndex)
|
|
* bLabels := labelsTs.Idx(narrowIndex)
|
|
*
|
|
* logits := bImages.MustMm(ws).MustAdd(bs)
|
|
* // loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels)
|
|
* loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels)
|
|
*
|
|
* ws.ZeroGrad()
|
|
* bs.ZeroGrad()
|
|
* loss.Backward()
|
|
*
|
|
* bs.MustGrad().Print()
|
|
*
|
|
* ts.NoGrad(func() {
|
|
* ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
|
* bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
|
* })
|
|
* }
|
|
*
|
|
* imagesTs.MustDrop()
|
|
* labelsTs.MustDrop()
|
|
* */
|
|
|
|
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
|
|
// loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
|
|
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
|
|
// loss := ds.TrainImages.MustMm(ws).MustAdd(bs).MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
|
|
|
|
ws.ZeroGrad()
|
|
bs.ZeroGrad()
|
|
// loss.MustBackward()
|
|
loss.Backward()
|
|
|
|
// TODO: why `loss` need to print out to get updated?
|
|
fmt.Printf("loss (epoch %v): %v\n", epoch, loss.MustToString(0))
|
|
// fmt.Printf("bs grad (epoch %v): %v\n", epoch, bs.MustGrad().MustToString(1))
|
|
|
|
ts.NoGrad(func() {
|
|
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
|
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
|
})
|
|
|
|
// fmt.Printf("bs(epoch %v): \n%v\n", epoch, bs.MustToString(1))
|
|
// fmt.Printf("ws mean(epoch %v): \n%v\n", epoch, ws.MustMean(gotch.Float.CInt()).MustToString(1))
|
|
|
|
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
|
|
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
|
// testAccuracy := ds.TestImages.MustMm(ws).MustAdd(bs).MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
|
//
|
|
fmt.Printf("Epoch: %v - Test accuracy: %v\n", epoch, testAccuracy*100)
|
|
}
|
|
}
|