feat(example/mnist): training only works when printing out the tensor

This commit is contained in:
sugarme 2020-06-17 11:23:00 +10:00
parent 830f9ad9df
commit 613cd93443
4 changed files with 111 additions and 40 deletions

View File

@ -0,0 +1,48 @@
package main
import (
"fmt"
"log"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
// main builds a tiny mock dataset (y = 2x + 1) and runs a few
// gradient-descent steps with gotch tensors. It is a toy example
// demonstrating the forward / backward / update cycle, not real training.
func main() {
	// mockup data: n points of y = 2x + 1
	var (
		n      int = 20
		xvals  []float32
		yvals  []float32
		epochs = 10
	)
	for i := 0; i < n; i++ {
		xvals = append(xvals, float32(i))
		yvals = append(yvals, float32(2*i+1))
	}

	xtrain, err := ts.NewTensorFromData(xvals, []int64{int64(n), 1})
	if err != nil {
		log.Fatal(err)
	}
	ytrain, err := ts.NewTensorFromData(yvals, []int64{int64(n), 1})
	if err != nil {
		log.Fatal(err)
	}
	// TODO(review): ytrain should feed a proper regression loss (e.g. MSE);
	// the blank assignment keeps the example compiling until then (the
	// original left ytrain unused, which is a compile error in Go).
	_ = ytrain

	// Weights and bias must track gradients, otherwise MustGrad() after
	// Backward has nothing to return (same pattern as the mnist runLinear
	// example). Shapes: ws [1,1], bs [1] so xtrain [n,1] x ws -> [n,1].
	ws := ts.MustZeros([]int64{1, 1}, gotch.Float.CInt(), gotch.CPU.CInt()).MustSetRequiresGrad(true)
	bs := ts.MustZeros([]int64{1}, gotch.Float.CInt(), gotch.CPU.CInt()).MustSetRequiresGrad(true)

	for epoch := 0; epoch < epochs; epoch++ {
		// Forward pass.
		logit := xtrain.MustMatMul(ws).MustAdd(bs)

		// BUG FIX: derive the loss from the model output. The original
		// built it from a fresh empty tensor (ts.NewTensor()), leaving
		// `logit` unused (a compile error) and the loss graph
		// disconnected from ws/bs, so no gradients could ever flow.
		// Reduced to a scalar with MustMean so Backward is well-defined.
		loss := logit.MustLogSoftmax(-1, gotch.Float.CInt()).MustMean(gotch.Float.CInt())

		// Clear stale gradients before the backward pass. The original
		// called MustGrad() here, which reads (and discards) gradients
		// that do not exist yet.
		ws.ZeroGrad()
		bs.ZeroGrad()
		loss.MustBackward()

		// SGD step with learning rate 1.0, outside the autograd tape.
		ts.NoGrad(func() {
			ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
			bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
		})

		fmt.Printf("epoch %v - loss: %v\n", epoch, loss.MustToString(0))
	}
}

View File

@ -13,8 +13,10 @@ const (
Label int64 = 10
MnistDir string = "../../data/mnist"
epochs = 100
batchSize = 256
// epochs = 500
// batchSize = 256
epochs = 200
batchSize = 60000
)
func runLinear() {
@ -33,51 +35,77 @@ func runLinear() {
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
for epoch := 0; epoch < epochs; epoch++ {
var loss ts.Tensor
trainIter := ds.TrainIter(batchSize)
trainIter.Shuffle().ToDevice(gotch.CPU)
// item a pair of images and labels as 2 tensors
for {
batch, ok := trainIter.Next()
if !ok {
break
}
logits := batch.Images.MustMm(ws).MustAdd(bs)
loss = logits.MustLogSoftmax(-1, dtype).MustNllLoss(batch.Labels)
ws.ZeroGrad()
bs.ZeroGrad()
loss.Backward()
ts.NoGrad(func() {
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
})
}
/*
* logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
* loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
* totalSize := ds.TrainImages.MustSize()[0]
* samples := int(totalSize)
* index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
* imagesTs := ds.TrainImages.MustIndexSelect(0, index)
* labelsTs := ds.TrainLabels.MustIndexSelect(0, index)
*
* batches := samples / batchSize
* batchIndex := 0
* for i := 0; i < batches; i++ {
* start := batchIndex * batchSize
* size := batchSize
* if samples-start < batchSize {
* // size = samples - start
* break
* }
* batchIndex += 1
*
* // Indexing
* narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
* // bImages := ds.TrainImages.Idx(narrowIndex)
* // bLabels := ds.TrainLabels.Idx(narrowIndex)
* bImages := imagesTs.Idx(narrowIndex)
* bLabels := labelsTs.Idx(narrowIndex)
*
* logits := bImages.MustMm(ws).MustAdd(bs)
* // loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels)
* loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels)
*
* ws.ZeroGrad()
* bs.ZeroGrad()
* loss.Backward()
*
* bs.MustGrad().Print()
*
* ts.NoGrad(func() {
* ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
* bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
* })
* loss.Print()
* }
*
* imagesTs.MustDrop()
* labelsTs.MustDrop()
* */
// bs.MustGrad().Print()
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
// loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
// loss := ds.TrainImages.MustMm(ws).MustAdd(bs).MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
ws.ZeroGrad()
bs.ZeroGrad()
// loss.MustBackward()
loss.Backward()
// TODO: why `loss` need to print out to get updated?
fmt.Printf("loss (epoch %v): %v\n", epoch, loss.MustToString(0))
// fmt.Printf("bs grad (epoch %v): %v\n", epoch, bs.MustGrad().MustToString(1))
ts.NoGrad(func() {
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
})
// fmt.Printf("bs(epoch %v): \n%v\n", epoch, bs.MustToString(1))
// fmt.Printf("ws mean(epoch %v): \n%v\n", epoch, ws.MustMean(gotch.Float.CInt()).MustToString(1))
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
// testAccuracy := ds.TestImages.MustMm(ws).MustAdd(bs).MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
//
fmt.Printf("Epoch: %v - Test accuracy: %v\n", epoch, testAccuracy*100)
// fmt.Printf("Epoch: %v - Train loss: %v - Test accuracy: %v\n", epoch, loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}), testAccuracy*100)
}
}

View File

@ -124,15 +124,9 @@ func (it *Iter2) Next() (item Iter2Item, ok bool) {
// Indexing
narrowIndex := ts.NewNarrow(start, start+size)
// ts1 := it.xs.Idx(narrowIndex).MustTo(it.device)
// ts2 := it.ys.Idx(narrowIndex).MustTo(it.device)
ts1 := it.xs.Idx(narrowIndex)
ts2 := it.ys.Idx(narrowIndex)
return Iter2Item{
Images: ts1,
Labels: ts2,
Images: it.xs.Idx(narrowIndex),
Labels: it.ys.Idx(narrowIndex),
}, true
}
}

View File

@ -825,6 +825,7 @@ func (ts Tensor) ToString(lw int64) (retVal string, err error) {
// MustToString returns a string representation for the tensor. It will be panic
// if error.
// lw : line width (size)
func (ts Tensor) MustToString(lw int64) (retVal string) {
retVal, err := ts.ToString(lw)
if err != nil {