feat(nn/data): data iterator Iter2

This commit is contained in:
sugarme 2020-06-16 13:39:02 +10:00
parent da950fe881
commit 830f9ad9df
8 changed files with 278 additions and 53 deletions

View File

@ -13,17 +13,18 @@ const (
Label int64 = 10
MnistDir string = "../../data/mnist"
epochs = 200
epochs = 100
batchSize = 256
)
func runLinear() {
var ds vision.Dataset
ds = vision.LoadMNISTDir(MnistDir)
fmt.Printf("Train image size: %v\n", ds.TrainImages.MustSize())
fmt.Printf("Train label size: %v\n", ds.TrainLabels.MustSize())
fmt.Printf("Test image size: %v\n", ds.TestImages.MustSize())
fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
// fmt.Printf("Train image size: %v\n", ds.TrainImages.MustSize())
// fmt.Printf("Train label size: %v\n", ds.TrainLabels.MustSize())
// fmt.Printf("Test image size: %v\n", ds.TestImages.MustSize())
// fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
device := (gotch.CPU).CInt()
dtype := (gotch.Float).CInt()
@ -32,22 +33,51 @@ func runLinear() {
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
for epoch := 0; epoch < epochs; epoch++ {
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
ws.ZeroGrad()
bs.ZeroGrad()
loss.Backward()
var loss ts.Tensor
trainIter := ds.TrainIter(batchSize)
trainIter.Shuffle().ToDevice(gotch.CPU)
// item a pair of images and labels as 2 tensors
for {
batch, ok := trainIter.Next()
if !ok {
break
}
ts.NoGrad(func() {
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
})
logits := batch.Images.MustMm(ws).MustAdd(bs)
loss = logits.MustLogSoftmax(-1, dtype).MustNllLoss(batch.Labels)
ws.ZeroGrad()
bs.ZeroGrad()
loss.Backward()
ts.NoGrad(func() {
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
})
}
/*
* logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
* loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
*
* ws.ZeroGrad()
* bs.ZeroGrad()
* loss.Backward()
*
* ts.NoGrad(func() {
* ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
* bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
* })
* loss.Print()
* */
// bs.MustGrad().Print()
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
fmt.Printf("Epoch: %v - Train loss: %v - Test accuracy: %v\n", epoch, loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}), testAccuracy*100)
fmt.Printf("Epoch: %v - Test accuracy: %v\n", epoch, testAccuracy*100)
// fmt.Printf("Epoch: %v - Train loss: %v - Test accuracy: %v\n", epoch, loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}), testAccuracy*100)
}
}

View File

@ -233,3 +233,12 @@ func AtgView(ptr *Ctensor, self Ctensor, sizeData []int64, sizeLen int) {
func AtgDiv1(ptr *Ctensor, self Ctensor, other Cscalar) {
C.atg_div1(ptr, self, other)
}
// AtgRandperm is the Go binding for the C function:
//
//	void atg_randperm(tensor *, int64_t n, int options_kind, int options_device);
//
// Each Go argument is reinterpreted in place as its C counterpart and
// forwarded unchanged to C.atg_randperm together with ptr.
// NOTE(review): the resulting tensor handle is presumably written through
// ptr by the C side — confirm against the C wrapper.
func AtgRandperm(ptr *Ctensor, n int64, optionKind int32, optionDevice int32) {
	C.atg_randperm(
		ptr,
		*(*C.int64_t)(unsafe.Pointer(&n)),
		*(*C.int)(unsafe.Pointer(&optionKind)),
		*(*C.int)(unsafe.Pointer(&optionDevice)),
	)
}

138
nn/data.go Normal file
View File

@ -0,0 +1,138 @@
package nn
import (
"fmt"
"log"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
// Iter2 iterates over two tensors whose first dimensions have the same size.
//
// Its main purpose is mini-batch iteration: every batch is a pair made of
// one slice of each tensor, taken over the same range of the first
// dimension (in random order if the iterator was shuffled).
type Iter2 struct {
	// xs and ys are the paired tensors being iterated over.
	xs, ys ts.Tensor

	// batchIndex counts the batches already handed out, batchSize is the
	// requested number of rows per batch and totalSize is the size of the
	// first dimension of xs.
	batchIndex, batchSize, totalSize int64

	// device is the device selected through ToDevice.
	device gotch.Device

	// returnSmallLastBatch, when true, lets Next hand out a final batch
	// that holds fewer than batchSize rows.
	returnSmallLastBatch bool
}
// NewIter2 returns a new iterator over the paired tensors `xs` and `ys`.
//
// The two tensors must have the same first-dimension size. The returned
// iterator can be used to range over mini-batches of the requested size.
// It starts at the first batch, does not return a smaller last batch unless
// ReturnSmallLastBatch is applied, and leaves the device at its zero value
// until ToDevice is applied.
//
// An error is returned if either tensor has no dimensions, or if `xs` and
// `ys` have different first-dimension sizes.
//
// # Arguments
//
// * `xs` - the features to be used by the model.
// * `ys` - the targets that the model attempts to predict.
// * `batchSize` - the size of batches to be returned.
func NewIter2(xs, ys ts.Tensor, batchSize int64) (retVal Iter2, err error) {
	xsSize := xs.MustSize()
	ysSize := ys.MustSize()

	// A tensor without dimensions has no first dimension to batch over.
	// Report that through the error return rather than panicking on the
	// index expressions below.
	if len(xsSize) == 0 || len(ysSize) == 0 {
		err = fmt.Errorf("NewIter2: both inputs must have at least one dimension: %v - %v", xsSize, ysSize)
		return retVal, err
	}

	totalSize := xsSize[0]
	if ysSize[0] != totalSize {
		err = fmt.Errorf("Different dimension for the two inputs: %v - %v", xsSize, ysSize)
		return retVal, err
	}

	retVal = Iter2{
		// Shallow clones give the iterator its own tensor handles.
		xs:                   xs.MustShallowClone(),
		ys:                   ys.MustShallowClone(),
		batchIndex:           0,
		batchSize:            batchSize,
		totalSize:            totalSize,
		returnSmallLastBatch: false,
	}

	return retVal, nil
}
// MustNewIter2 builds an iterator exactly like NewIter2 but without an
// error return.
//
// If NewIter2 reports an error (for instance when `xs` and `ys` have
// different first dimension sizes) the error is passed to log.Fatal, which
// logs it and terminates the program.
//
// # Arguments
//
// * `xs` - the features to be used by the model.
// * `ys` - the targets that the model attempts to predict.
// * `batchSize` - the size of batches to be returned.
func MustNewIter2(xs, ys ts.Tensor, batchSize int64) (retVal Iter2) {
	var err error
	if retVal, err = NewIter2(xs, ys, batchSize); err != nil {
		log.Fatal(err)
	}

	return retVal
}
// Shuffle returns a copy of the iterator whose two tensors have been
// re-ordered along the first dimension with one shared index tensor.
//
// The iterator still covers the whole dataset; only the order in which
// elements are grouped into mini-batches changes. Because the receiver is a
// value, the iterator Shuffle is called on is left untouched: callers must
// use the returned iterator.
func (it Iter2) Shuffle() (retVal Iter2) {
	// The same index tensor is applied to both tensors so that each row of
	// xs stays aligned with its row of ys.
	perm := ts.MustRandperm(it.totalSize, gotch.Int64, gotch.CPU)

	retVal = it
	retVal.xs = it.xs.MustIndexSelect(0, perm)
	retVal.ys = it.ys.MustIndexSelect(0, perm)

	return retVal
}
// ToDevice returns a copy of the iterator with its device set to `device`.
//
// The receiver is a value, so the iterator this is called on is not
// modified; callers must use the returned iterator.
func (it Iter2) ToDevice(device gotch.Device) (retVal Iter2) {
	retVal = it
	retVal.device = device

	return retVal
}
// ReturnSmallLastBatch returns a copy of the iterator configured to also
// yield the last batch when it holds fewer elements than the batch size.
//
// The receiver is a value, so the iterator this is called on is not
// modified; callers must use the returned iterator.
func (it Iter2) ReturnSmallLastBatch() (retVal Iter2) {
	retVal = it
	retVal.returnSmallLastBatch = true

	return retVal
}
// Iter2Item is one mini-batch produced by Iter2.Next.
type Iter2Item struct {
	// Images is the slice taken from the iterator's first tensor and Labels
	// the slice taken from its second tensor, both over the same range of
	// the first dimension.
	Images, Labels ts.Tensor
}
// Next advances the iterator and returns the next mini-batch.
//
// `ok` is false, and `item` is the zero value, once no rows are left or
// when the remaining rows do not fill a whole batch and
// returnSmallLastBatch is not set. Otherwise the batch counter is
// incremented and the matching slices of both tensors are returned.
//
// Note: the device stored on the iterator is not applied here; the slices
// are returned as produced by the indexing call.
func (it *Iter2) Next() (item Iter2Item, ok bool) {
	start := it.batchIndex * it.batchSize

	// Clamp the batch to whatever is left in the dataset.
	remaining := it.totalSize - start
	size := it.batchSize
	if remaining < size {
		size = remaining
	}

	exhausted := size <= 0
	partial := size < it.batchSize
	if exhausted || (partial && !it.returnSmallLastBatch) {
		return item, false
	}

	it.batchIndex++

	// The same window over the first dimension is used for both tensors.
	window := ts.NewNarrow(start, start+size)
	images := it.xs.Idx(window)
	labels := it.ys.Idx(window)

	return Iter2Item{Images: images, Labels: labels}, true
}

View File

@ -30,12 +30,12 @@ func NewConstInit(v float64) constInit {
func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
var err error
kind := gotch.DType2CInt(gotch.Float)
kind := gotch.Float.CInt()
switch {
case c.value == 0.0:
retVal = ts.Zeros(dims, kind, device.CInt())
retVal = ts.MustZeros(dims, kind, device.CInt())
case c.value == 1.0:
retVal = ts.Ones(dims, kind, device.CInt())
retVal = ts.MustOnes(dims, kind, device.CInt())
default:
data := make([]float64, ts.FlattenDim(dims))
for i := range data {
@ -57,7 +57,7 @@ func (c constInit) Set(tensor ts.Tensor) {
log.Fatalf("constInit - Set method call error: %v\n", err)
}
ts.Fill_(scalarVal)
tensor.Fill_(scalarVal)
}
// randnInit :
@ -125,9 +125,9 @@ func NewUniformInit(lo, up float64) uniformInit {
func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
var err error
kind := gotch.DType2CInt(gotch.Float)
tmpTs := ts.Zeros(dims, kind, device.CInt())
retVal, err = tmpTs.Uniform_(u.lo, u.up)
kind := gotch.Float.CInt()
retVal = ts.MustZeros(dims, kind, device.CInt())
retVal.Uniform_(u.lo, u.up)
if err != nil {
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
}
@ -150,13 +150,10 @@ func NewKaimingUniformInit() kaimingUniformInit {
func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
fanIn := factorial(uint64(len(dims) - 1))
bound := math.Sqrt(1.0 / float64(fanIn))
var err error
kind := gotch.DType2CInt(gotch.Float)
tmpTs := ts.Zeros(dims, kind, device.CInt())
retVal, err = tmpTs.Uniform_(-bound, bound)
if err != nil {
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
}
kind := gotch.Float.CInt()
retVal = ts.MustZeros(dims, kind, device.CInt())
retVal.Uniform_(-bound, bound)
return retVal
}

View File

@ -160,7 +160,9 @@ func (vs *VarStore) Load(filepath string) (err error) {
return err
}
retValErr, err := ts.NoGrad(ts.Copy_(currTs, namedTs.Tensor))
retValErr, err := ts.NoGrad(func() {
ts.Copy_(currTs, namedTs.Tensor)
})
if err != nil {
return err
}
@ -205,7 +207,9 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
}
// It's matched. Now, copy in-place the loaded tensor value to var-store
retValErr, err := ts.NoGrad(ts.Copy_(currTs, namedTs.Tensor))
retValErr, err := ts.NoGrad(func() {
ts.Copy_(currTs, namedTs.Tensor)
})
if err != nil {
return nil, err
}
@ -274,7 +278,9 @@ func (vs *VarStore) Copy(src VarStore) (err error) {
if err != nil {
return err
}
retValErr, err := ts.NoGrad(ts.Copy_(v, srcDevTs))
retValErr, err := ts.NoGrad(func() {
ts.Copy_(v, srcDevTs)
})
if err != nil {
return err
}
@ -520,7 +526,7 @@ func (p *Path) Uniform(name string, dims []int64, lo, up float64) (retVal ts.Ten
// will be tracked.
// The variable uses a float tensor initialized randomly using a
// uniform distribution which bounds follow Kaiming initialization.
func (p *Path) Uniform(name string, dims []int64) (retVal ts.Tensor) {
func (p *Path) KaimingUniform(name string, dims []int64) (retVal ts.Tensor) {
// TODO: implement it
// self.var(name, dims, Init::KaimingUniform)
@ -542,12 +548,14 @@ func (p *Path) VarCopy(name string, t ts.Tensor) (retVal ts.Tensor) {
}
v := p.Zeros(name, size)
retValErr, err := ts.NoGrad(ts.Copy_(v, t))
retValErr, err := ts.NoGrad(func() {
ts.Copy_(v, t)
})
if err != nil {
return err
log.Fatal(err)
}
if retValErr != nil {
return retValErr.(error)
log.Fatal(retValErr)
}
return v
@ -555,14 +563,13 @@ func (p *Path) VarCopy(name string, t ts.Tensor) (retVal ts.Tensor) {
// Get gets the tensor corresponding to a given name if present.
func (p *Path) Get(name string) (retVal ts.Tensor, err error) {
path := p.path(name)
p.varstore.variables.mutex.Lock()
defer p.varstore.variables.mutex.Unlock()
v, ok := p.varstore.variables.NamedVariables[path]
v, ok := p.varstore.variables.NamedVariables[name]
if !ok {
err = fmt.Errorf("Path - Get method call error: Cannot find variable for name: %v\n", path)
err = fmt.Errorf("Path - Get method call error: Cannot find variable for name: %v\n", name)
return retVal, err
}
@ -577,7 +584,7 @@ func (p *Path) Entry(name string) (retVal Entry) {
return Entry{
name: name,
variables: p.varstore.variables,
path: &p,
path: *p,
}
}
@ -608,12 +615,14 @@ func (e *Entry) OrVarCopy(tensor ts.Tensor) (retVal ts.Tensor) {
}
v := e.OrZeros(size)
retValErr, err := ts.NoGrad(ts.Copy_(v, tensor))
retValErr, err := ts.NoGrad(func() {
ts.Copy_(v, tensor)
})
if err != nil {
return err
log.Fatal(err)
}
if retValErr != nil {
return retValErr.(error)
log.Fatal(retValErr)
}
return v

View File

@ -248,7 +248,7 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
// `spec` is a function type implements `TensorIndexer`
for _, spec := range indexSpec {
fmt.Printf("spec type: %v\n", reflect.TypeOf(spec).Name())
// fmt.Printf("spec type: %v\n", reflect.TypeOf(spec).Name())
switch reflect.TypeOf(spec).Name() {
case "InsertNewAxis":
@ -291,8 +291,6 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
nextIdx = currIdx + 1
} // end of switch
currTensor.Print()
currTensor = nextTensor
currIdx = nextIdx
}

View File

@ -8,11 +8,11 @@ import (
"log"
"unsafe"
gt "github.com/sugarme/gotch"
"github.com/sugarme/gotch"
lib "github.com/sugarme/gotch/libtch"
)
func (ts Tensor) To(device gt.Device) (retVal Tensor, err error) {
func (ts Tensor) To(device gotch.Device) (retVal Tensor, err error) {
// TODO: how to get pointer to CUDA memory???
// C.cuMemAlloc((*C.ulonglong)(cudaPtr), 1) // 0 byte is invalid
@ -28,7 +28,7 @@ func (ts Tensor) To(device gt.Device) (retVal Tensor, err error) {
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) MustTo(device gt.Device) (retVal Tensor) {
func (ts Tensor) MustTo(device gotch.Device) (retVal Tensor) {
var err error
retVal, err = ts.To(device)
if err != nil {
@ -271,10 +271,10 @@ func (ts Tensor) MustAddG(other Tensor) {
}
// Totype casts type of tensor to a new tensor with specified DType
func (ts Tensor) Totype(dtype gt.DType) (retVal Tensor, err error) {
func (ts Tensor) Totype(dtype gotch.DType) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
cint, err := gt.DType2CInt(dtype)
cint, err := gotch.DType2CInt(dtype)
if err != nil {
return retVal, err
}
@ -291,7 +291,7 @@ func (ts Tensor) Totype(dtype gt.DType) (retVal Tensor, err error) {
// Totype casts type of tensor to a new tensor with specified DType. It will
// panic if error
func (ts Tensor) MustTotype(dtype gt.DType) (retVal Tensor) {
func (ts Tensor) MustTotype(dtype gotch.DType) (retVal Tensor) {
retVal, err := ts.Totype(dtype)
if err != nil {
log.Fatal(err)
@ -361,6 +361,14 @@ func (ts Tensor) IndexSelect(dim int64, index Tensor) (retVal Tensor, err error)
return retVal, nil
}
// MustIndexSelect calls IndexSelect with the same arguments and returns its
// tensor. If IndexSelect reports an error, the error is passed to
// log.Fatal, which logs it and terminates the program.
func (ts Tensor) MustIndexSelect(dim int64, index Tensor) (retVal Tensor) {
	var err error
	if retVal, err = ts.IndexSelect(dim, index); err != nil {
		log.Fatal(err)
	}

	return retVal
}
func Zeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
@ -697,3 +705,26 @@ func (ts Tensor) MustDiv1(other Scalar) (retVal Tensor) {
return retVal
}
// Randperm creates a tensor by calling lib.AtgRandperm with `n` and the C
// integer codes of `optionKind` and `optionDevice`.
//
// It returns the error reported by TorchErr, if any; otherwise the tensor
// wrapping the handle written by the library call.
// NOTE(review): presumably the result is a random permutation of the
// integers 0..n-1, as with torch.randperm — confirm against libtorch.
func Randperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (retVal Tensor, err error) {
	// Scratch slot for the handle written by the C side; released on return.
	// NOTE(review): the slot is allocated with size 0, matching the rest of
	// this file — confirm this is large enough for the value written.
	out := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	defer C.free(unsafe.Pointer(out))

	kind := optionKind.CInt()
	device := optionDevice.CInt()
	lib.AtgRandperm(out, n, kind, device)

	if err = TorchErr(); err != nil {
		return retVal, err
	}

	return Tensor{ctensor: *out}, nil
}
// MustRandperm calls Randperm with the same arguments and returns its
// tensor. If Randperm reports an error, the error is passed to log.Fatal,
// which logs it and terminates the program.
func MustRandperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (retVal Tensor) {
	var err error
	if retVal, err = Randperm(n, optionKind, optionDevice); err != nil {
		log.Fatal(err)
	}

	return retVal
}

View File

@ -3,6 +3,7 @@ package vision
// A simple dataset structure shared by various computer vision datasets.
import (
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
)
@ -14,4 +15,16 @@ type Dataset struct {
Labels int64
}
// TODO: implement methods
// Dataset Methods:
//=================
// TrainIter builds an nn.Iter2 over the dataset's training images and
// training labels, using batches of `batchSize`.
//
// The iterator is created with nn.MustNewIter2, so no error is returned
// from this method.
func (ds Dataset) TrainIter(batchSize int64) (retVal nn.Iter2) {
	retVal = nn.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchSize)

	return retVal
}
// TestIter builds an nn.Iter2 over the dataset's test images and test
// labels, using batches of `batchSize`.
//
// The iterator is created with nn.MustNewIter2, so no error is returned
// from this method.
func (ds Dataset) TestIter(batchSize int64) (retVal nn.Iter2) {
	retVal = nn.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)

	return retVal
}