feat(example/mnist): get works when printing out tensor
This commit is contained in:
parent
830f9ad9df
commit
613cd93443
48
example/linear-regression/main.go
Normal file
48
example/linear-regression/main.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
// mockup data
|
||||
var (
|
||||
n int = 20
|
||||
xvals []float32
|
||||
yvals []float32
|
||||
epochs = 10
|
||||
)
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
xvals = append(xvals, float32(i))
|
||||
yvals = append(yvals, float32(2*i+1))
|
||||
}
|
||||
|
||||
xtrain, err := ts.NewTensorFromData(xvals, []int64{int64(n), 1})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
ytrain, err := ts.NewTensorFromData(yvals, []int64{int64(n), 1})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ws := ts.MustZeros([]int64{1, int64(n)}, gotch.Float.CInt(), gotch.CPU.CInt())
|
||||
bs := ts.MustZeros([]int64{1, int64(n)}, gotch.Float.CInt(), gotch.CPU.CInt())
|
||||
|
||||
for epoch := 0; epoch < epochs; epoch++ {
|
||||
|
||||
logit := ws.MustMatMul(xtrain).MustAdd(bs)
|
||||
loss := ts.NewTensor().MustLogSoftmax(-1, gotch.Float.CInt())
|
||||
|
||||
ws.MustGrad()
|
||||
bs.MustGrad()
|
||||
loss.MustBackward()
|
||||
|
||||
}
|
||||
}
|
|
@ -13,8 +13,10 @@ const (
|
|||
Label int64 = 10
|
||||
MnistDir string = "../../data/mnist"
|
||||
|
||||
epochs = 100
|
||||
batchSize = 256
|
||||
// epochs = 500
|
||||
// batchSize = 256
|
||||
epochs = 200
|
||||
batchSize = 60000
|
||||
)
|
||||
|
||||
func runLinear() {
|
||||
|
@ -33,51 +35,77 @@ func runLinear() {
|
|||
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
|
||||
|
||||
for epoch := 0; epoch < epochs; epoch++ {
|
||||
|
||||
var loss ts.Tensor
|
||||
trainIter := ds.TrainIter(batchSize)
|
||||
trainIter.Shuffle().ToDevice(gotch.CPU)
|
||||
// item a pair of images and labels as 2 tensors
|
||||
for {
|
||||
batch, ok := trainIter.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
logits := batch.Images.MustMm(ws).MustAdd(bs)
|
||||
loss = logits.MustLogSoftmax(-1, dtype).MustNllLoss(batch.Labels)
|
||||
ws.ZeroGrad()
|
||||
bs.ZeroGrad()
|
||||
loss.Backward()
|
||||
|
||||
ts.NoGrad(func() {
|
||||
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
* logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
|
||||
* loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
|
||||
* totalSize := ds.TrainImages.MustSize()[0]
|
||||
* samples := int(totalSize)
|
||||
* index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
|
||||
* imagesTs := ds.TrainImages.MustIndexSelect(0, index)
|
||||
* labelsTs := ds.TrainLabels.MustIndexSelect(0, index)
|
||||
*
|
||||
* batches := samples / batchSize
|
||||
* batchIndex := 0
|
||||
* for i := 0; i < batches; i++ {
|
||||
* start := batchIndex * batchSize
|
||||
* size := batchSize
|
||||
* if samples-start < batchSize {
|
||||
* // size = samples - start
|
||||
* break
|
||||
* }
|
||||
* batchIndex += 1
|
||||
*
|
||||
* // Indexing
|
||||
* narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
|
||||
* // bImages := ds.TrainImages.Idx(narrowIndex)
|
||||
* // bLabels := ds.TrainLabels.Idx(narrowIndex)
|
||||
* bImages := imagesTs.Idx(narrowIndex)
|
||||
* bLabels := labelsTs.Idx(narrowIndex)
|
||||
*
|
||||
* logits := bImages.MustMm(ws).MustAdd(bs)
|
||||
* // loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels)
|
||||
* loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels)
|
||||
*
|
||||
* ws.ZeroGrad()
|
||||
* bs.ZeroGrad()
|
||||
* loss.Backward()
|
||||
*
|
||||
* bs.MustGrad().Print()
|
||||
*
|
||||
* ts.NoGrad(func() {
|
||||
* ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
* bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
* })
|
||||
* loss.Print()
|
||||
* }
|
||||
*
|
||||
* imagesTs.MustDrop()
|
||||
* labelsTs.MustDrop()
|
||||
* */
|
||||
|
||||
// bs.MustGrad().Print()
|
||||
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
|
||||
// loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
|
||||
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
|
||||
// loss := ds.TrainImages.MustMm(ws).MustAdd(bs).MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
|
||||
|
||||
ws.ZeroGrad()
|
||||
bs.ZeroGrad()
|
||||
// loss.MustBackward()
|
||||
loss.Backward()
|
||||
|
||||
// TODO: why `loss` need to print out to get updated?
|
||||
fmt.Printf("loss (epoch %v): %v\n", epoch, loss.MustToString(0))
|
||||
// fmt.Printf("bs grad (epoch %v): %v\n", epoch, bs.MustGrad().MustToString(1))
|
||||
|
||||
ts.NoGrad(func() {
|
||||
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
})
|
||||
|
||||
// fmt.Printf("bs(epoch %v): \n%v\n", epoch, bs.MustToString(1))
|
||||
// fmt.Printf("ws mean(epoch %v): \n%v\n", epoch, ws.MustMean(gotch.Float.CInt()).MustToString(1))
|
||||
|
||||
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
|
||||
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
|
||||
// testAccuracy := ds.TestImages.MustMm(ws).MustAdd(bs).MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
//
|
||||
fmt.Printf("Epoch: %v - Test accuracy: %v\n", epoch, testAccuracy*100)
|
||||
|
||||
// fmt.Printf("Epoch: %v - Train loss: %v - Test accuracy: %v\n", epoch, loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}), testAccuracy*100)
|
||||
}
|
||||
}
|
||||
|
|
10
nn/data.go
10
nn/data.go
|
@ -124,15 +124,9 @@ func (it *Iter2) Next() (item Iter2Item, ok bool) {
|
|||
// Indexing
|
||||
narrowIndex := ts.NewNarrow(start, start+size)
|
||||
|
||||
// ts1 := it.xs.Idx(narrowIndex).MustTo(it.device)
|
||||
// ts2 := it.ys.Idx(narrowIndex).MustTo(it.device)
|
||||
|
||||
ts1 := it.xs.Idx(narrowIndex)
|
||||
ts2 := it.ys.Idx(narrowIndex)
|
||||
|
||||
return Iter2Item{
|
||||
Images: ts1,
|
||||
Labels: ts2,
|
||||
Images: it.xs.Idx(narrowIndex),
|
||||
Labels: it.ys.Idx(narrowIndex),
|
||||
}, true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -825,6 +825,7 @@ func (ts Tensor) ToString(lw int64) (retVal string, err error) {
|
|||
|
||||
// MustToString returns a string representation for the tensor. It will be panic
|
||||
// if error.
|
||||
// lw : line width (size)
|
||||
func (ts Tensor) MustToString(lw int64) (retVal string) {
|
||||
retVal, err := ts.ToString(lw)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue
Block a user