feat(nn/data): data interator Iter2
This commit is contained in:
parent
da950fe881
commit
830f9ad9df
|
@ -13,17 +13,18 @@ const (
|
||||||
Label int64 = 10
|
Label int64 = 10
|
||||||
MnistDir string = "../../data/mnist"
|
MnistDir string = "../../data/mnist"
|
||||||
|
|
||||||
epochs = 200
|
epochs = 100
|
||||||
|
batchSize = 256
|
||||||
)
|
)
|
||||||
|
|
||||||
func runLinear() {
|
func runLinear() {
|
||||||
var ds vision.Dataset
|
var ds vision.Dataset
|
||||||
ds = vision.LoadMNISTDir(MnistDir)
|
ds = vision.LoadMNISTDir(MnistDir)
|
||||||
|
|
||||||
fmt.Printf("Train image size: %v\n", ds.TrainImages.MustSize())
|
// fmt.Printf("Train image size: %v\n", ds.TrainImages.MustSize())
|
||||||
fmt.Printf("Train label size: %v\n", ds.TrainLabels.MustSize())
|
// fmt.Printf("Train label size: %v\n", ds.TrainLabels.MustSize())
|
||||||
fmt.Printf("Test image size: %v\n", ds.TestImages.MustSize())
|
// fmt.Printf("Test image size: %v\n", ds.TestImages.MustSize())
|
||||||
fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
|
// fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
|
||||||
|
|
||||||
device := (gotch.CPU).CInt()
|
device := (gotch.CPU).CInt()
|
||||||
dtype := (gotch.Float).CInt()
|
dtype := (gotch.Float).CInt()
|
||||||
|
@ -32,22 +33,51 @@ func runLinear() {
|
||||||
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
|
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
|
||||||
|
|
||||||
for epoch := 0; epoch < epochs; epoch++ {
|
for epoch := 0; epoch < epochs; epoch++ {
|
||||||
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
|
|
||||||
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
|
|
||||||
|
|
||||||
ws.ZeroGrad()
|
var loss ts.Tensor
|
||||||
bs.ZeroGrad()
|
trainIter := ds.TrainIter(batchSize)
|
||||||
loss.Backward()
|
trainIter.Shuffle().ToDevice(gotch.CPU)
|
||||||
|
// item a pair of images and labels as 2 tensors
|
||||||
|
for {
|
||||||
|
batch, ok := trainIter.Next()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
ts.NoGrad(func() {
|
logits := batch.Images.MustMm(ws).MustAdd(bs)
|
||||||
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
loss = logits.MustLogSoftmax(-1, dtype).MustNllLoss(batch.Labels)
|
||||||
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
ws.ZeroGrad()
|
||||||
})
|
bs.ZeroGrad()
|
||||||
|
loss.Backward()
|
||||||
|
|
||||||
|
ts.NoGrad(func() {
|
||||||
|
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||||
|
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
|
||||||
|
* loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
|
||||||
|
*
|
||||||
|
* ws.ZeroGrad()
|
||||||
|
* bs.ZeroGrad()
|
||||||
|
* loss.Backward()
|
||||||
|
*
|
||||||
|
* ts.NoGrad(func() {
|
||||||
|
* ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||||
|
* bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||||
|
* })
|
||||||
|
* loss.Print()
|
||||||
|
* */
|
||||||
|
|
||||||
|
// bs.MustGrad().Print()
|
||||||
|
|
||||||
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
|
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
|
||||||
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||||
|
|
||||||
fmt.Printf("Epoch: %v - Train loss: %v - Test accuracy: %v\n", epoch, loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}), testAccuracy*100)
|
fmt.Printf("Epoch: %v - Test accuracy: %v\n", epoch, testAccuracy*100)
|
||||||
|
|
||||||
|
// fmt.Printf("Epoch: %v - Train loss: %v - Test accuracy: %v\n", epoch, loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}), testAccuracy*100)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -233,3 +233,12 @@ func AtgView(ptr *Ctensor, self Ctensor, sizeData []int64, sizeLen int) {
|
||||||
func AtgDiv1(ptr *Ctensor, self Ctensor, other Cscalar) {
|
func AtgDiv1(ptr *Ctensor, self Ctensor, other Cscalar) {
|
||||||
C.atg_div1(ptr, self, other)
|
C.atg_div1(ptr, self, other)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// void atg_randperm(tensor *, int64_t n, int options_kind, int options_device);
|
||||||
|
func AtgRandperm(ptr *Ctensor, n int64, optionKind int32, optionDevice int32) {
|
||||||
|
cn := *(*C.int64_t)(unsafe.Pointer(&n))
|
||||||
|
coptionKind := *(*C.int)(unsafe.Pointer(&optionKind))
|
||||||
|
coptionDevice := *(*C.int)(unsafe.Pointer(&optionDevice))
|
||||||
|
|
||||||
|
C.atg_randperm(ptr, cn, coptionKind, coptionDevice)
|
||||||
|
}
|
||||||
|
|
138
nn/data.go
Normal file
138
nn/data.go
Normal file
|
@ -0,0 +1,138 @@
|
||||||
|
package nn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Iter2 is an iterator over a pair of tensors which have the same first dimension
|
||||||
|
// size.
|
||||||
|
// The typical use case is to iterate over batches. Each batch is a pair
|
||||||
|
// containing a (potentially random) slice of each of the two input
|
||||||
|
// tensors.
|
||||||
|
type Iter2 struct {
|
||||||
|
xs ts.Tensor
|
||||||
|
ys ts.Tensor
|
||||||
|
batchIndex int64
|
||||||
|
batchSize int64
|
||||||
|
totalSize int64
|
||||||
|
device gotch.Device
|
||||||
|
returnSmallLastBatch bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIter2 returns a new iterator.
|
||||||
|
//
|
||||||
|
// This takes as input two tensors which first dimension must match. The
|
||||||
|
// returned iterator can be used to range over mini-batches of data of
|
||||||
|
// specified size.
|
||||||
|
// An error is returned if `xs` and `ys` have different first dimension
|
||||||
|
// sizes.
|
||||||
|
//
|
||||||
|
// # Arguments
|
||||||
|
//
|
||||||
|
// * `xs` - the features to be used by the model.
|
||||||
|
// * `ys` - the targets that the model attempts to predict.
|
||||||
|
// * `batch_size` - the size of batches to be returned.
|
||||||
|
func NewIter2(xs, ys ts.Tensor, batchSize int64) (retVal Iter2, err error) {
|
||||||
|
|
||||||
|
totalSize := xs.MustSize()[0]
|
||||||
|
if ys.MustSize()[0] != totalSize {
|
||||||
|
err = fmt.Errorf("Different dimension for the two inputs: %v - %v", xs.MustSize(), ys.MustSize())
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Iter2{
|
||||||
|
xs: xs.MustShallowClone(),
|
||||||
|
ys: ys.MustShallowClone(),
|
||||||
|
batchIndex: 0,
|
||||||
|
batchSize: batchSize,
|
||||||
|
totalSize: totalSize,
|
||||||
|
returnSmallLastBatch: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustNewIter2 returns a new iterator.
|
||||||
|
//
|
||||||
|
// This takes as input two tensors which first dimension must match. The
|
||||||
|
// returned iterator can be used to range over mini-batches of data of
|
||||||
|
// specified size.
|
||||||
|
// Panics if `xs` and `ys` have different first dimension sizes.
|
||||||
|
//
|
||||||
|
// # Arguments
|
||||||
|
//
|
||||||
|
// * `xs` - the features to be used by the model.
|
||||||
|
// * `ys` - the targets that the model attempts to predict.
|
||||||
|
// * `batch_size` - the size of batches to be returned.
|
||||||
|
func MustNewIter2(xs, ys ts.Tensor, batchSize int64) (retVal Iter2) {
|
||||||
|
retVal, err := NewIter2(xs, ys, batchSize)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shuffle shuffles the dataset.
|
||||||
|
//
|
||||||
|
// The iterator would still run over the whole dataset but the order in
|
||||||
|
// which elements are grouped in mini-batches is randomized.
|
||||||
|
func (it Iter2) Shuffle() (retVal Iter2) {
|
||||||
|
index := ts.MustRandperm(it.totalSize, gotch.Int64, gotch.CPU)
|
||||||
|
|
||||||
|
it.xs = it.xs.MustIndexSelect(0, index)
|
||||||
|
it.ys = it.ys.MustIndexSelect(0, index)
|
||||||
|
return it
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToDevice transfers the mini-batches to a specified device.
|
||||||
|
func (it Iter2) ToDevice(device gotch.Device) (retVal Iter2) {
|
||||||
|
it.device = device
|
||||||
|
return it
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReturnSmallLastBatch when set, returns the last batch even if smaller than the batch size.
|
||||||
|
func (it Iter2) ReturnSmallLastBatch() (retVal Iter2) {
|
||||||
|
it.returnSmallLastBatch = true
|
||||||
|
return it
|
||||||
|
}
|
||||||
|
|
||||||
|
type Iter2Item struct {
|
||||||
|
Images ts.Tensor
|
||||||
|
Labels ts.Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next implements iterator for Iter2
|
||||||
|
func (it *Iter2) Next() (item Iter2Item, ok bool) {
|
||||||
|
start := it.batchIndex * it.batchSize
|
||||||
|
size := it.batchSize
|
||||||
|
if it.totalSize-start < it.batchSize {
|
||||||
|
size = it.totalSize - start
|
||||||
|
}
|
||||||
|
|
||||||
|
if (size <= 0) || (!it.returnSmallLastBatch && size < it.batchSize) {
|
||||||
|
// err = fmt.Errorf("Last small batch error")
|
||||||
|
return item, false
|
||||||
|
} else {
|
||||||
|
it.batchIndex += 1
|
||||||
|
|
||||||
|
// Indexing
|
||||||
|
narrowIndex := ts.NewNarrow(start, start+size)
|
||||||
|
|
||||||
|
// ts1 := it.xs.Idx(narrowIndex).MustTo(it.device)
|
||||||
|
// ts2 := it.ys.Idx(narrowIndex).MustTo(it.device)
|
||||||
|
|
||||||
|
ts1 := it.xs.Idx(narrowIndex)
|
||||||
|
ts2 := it.ys.Idx(narrowIndex)
|
||||||
|
|
||||||
|
return Iter2Item{
|
||||||
|
Images: ts1,
|
||||||
|
Labels: ts2,
|
||||||
|
}, true
|
||||||
|
}
|
||||||
|
}
|
25
nn/init.go
25
nn/init.go
|
@ -30,12 +30,12 @@ func NewConstInit(v float64) constInit {
|
||||||
|
|
||||||
func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
||||||
var err error
|
var err error
|
||||||
kind := gotch.DType2CInt(gotch.Float)
|
kind := gotch.Float.CInt()
|
||||||
switch {
|
switch {
|
||||||
case c.value == 0.0:
|
case c.value == 0.0:
|
||||||
retVal = ts.Zeros(dims, kind, device.CInt())
|
retVal = ts.MustZeros(dims, kind, device.CInt())
|
||||||
case c.value == 1.0:
|
case c.value == 1.0:
|
||||||
retVal = ts.Ones(dims, kind, device.CInt())
|
retVal = ts.MustOnes(dims, kind, device.CInt())
|
||||||
default:
|
default:
|
||||||
data := make([]float64, ts.FlattenDim(dims))
|
data := make([]float64, ts.FlattenDim(dims))
|
||||||
for i := range data {
|
for i := range data {
|
||||||
|
@ -57,7 +57,7 @@ func (c constInit) Set(tensor ts.Tensor) {
|
||||||
log.Fatalf("constInit - Set method call error: %v\n", err)
|
log.Fatalf("constInit - Set method call error: %v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ts.Fill_(scalarVal)
|
tensor.Fill_(scalarVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
// randnInit :
|
// randnInit :
|
||||||
|
@ -125,9 +125,9 @@ func NewUniformInit(lo, up float64) uniformInit {
|
||||||
|
|
||||||
func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
||||||
var err error
|
var err error
|
||||||
kind := gotch.DType2CInt(gotch.Float)
|
kind := gotch.Float.CInt()
|
||||||
tmpTs := ts.Zeros(dims, kind, device.CInt())
|
retVal = ts.MustZeros(dims, kind, device.CInt())
|
||||||
retVal, err = tmpTs.Uniform_(u.lo, u.up)
|
retVal.Uniform_(u.lo, u.up)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
|
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
|
||||||
}
|
}
|
||||||
|
@ -150,13 +150,10 @@ func NewKaimingUniformInit() kaimingUniformInit {
|
||||||
func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
||||||
fanIn := factorial(uint64(len(dims) - 1))
|
fanIn := factorial(uint64(len(dims) - 1))
|
||||||
bound := math.Sqrt(1.0 / float64(fanIn))
|
bound := math.Sqrt(1.0 / float64(fanIn))
|
||||||
var err error
|
kind := gotch.Float.CInt()
|
||||||
kind := gotch.DType2CInt(gotch.Float)
|
retVal = ts.MustZeros(dims, kind, device.CInt())
|
||||||
tmpTs := ts.Zeros(dims, kind, device.CInt())
|
retVal.Uniform_(-bound, bound)
|
||||||
retVal, err = tmpTs.Uniform_(-bound, bound)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
|
|
||||||
}
|
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -160,7 +160,9 @@ func (vs *VarStore) Load(filepath string) (err error) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
retValErr, err := ts.NoGrad(ts.Copy_(currTs, namedTs.Tensor))
|
retValErr, err := ts.NoGrad(func() {
|
||||||
|
ts.Copy_(currTs, namedTs.Tensor)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -205,7 +207,9 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// It's matched. Now, copy in-place the loaded tensor value to var-store
|
// It's matched. Now, copy in-place the loaded tensor value to var-store
|
||||||
retValErr, err := ts.NoGrad(ts.Copy_(currTs, namedTs.Tensor))
|
retValErr, err := ts.NoGrad(func() {
|
||||||
|
ts.Copy_(currTs, namedTs.Tensor)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -274,7 +278,9 @@ func (vs *VarStore) Copy(src VarStore) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
retValErr, err := ts.NoGrad(ts.Copy_(v, srcDevTs))
|
retValErr, err := ts.NoGrad(func() {
|
||||||
|
ts.Copy_(v, srcDevTs)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -520,7 +526,7 @@ func (p *Path) Uniform(name string, dims []int64, lo, up float64) (retVal ts.Ten
|
||||||
// will be tracked.
|
// will be tracked.
|
||||||
// The variable uses a float tensor initialized randomly using a
|
// The variable uses a float tensor initialized randomly using a
|
||||||
// uniform distribution which bounds follow Kaiming initialization.
|
// uniform distribution which bounds follow Kaiming initialization.
|
||||||
func (p *Path) Uniform(name string, dims []int64) (retVal ts.Tensor) {
|
func (p *Path) KaimingUniform(name string, dims []int64) (retVal ts.Tensor) {
|
||||||
// TODO: implement it
|
// TODO: implement it
|
||||||
// self.var(name, dims, Init::KaimingUniform)
|
// self.var(name, dims, Init::KaimingUniform)
|
||||||
|
|
||||||
|
@ -542,12 +548,14 @@ func (p *Path) VarCopy(name string, t ts.Tensor) (retVal ts.Tensor) {
|
||||||
}
|
}
|
||||||
v := p.Zeros(name, size)
|
v := p.Zeros(name, size)
|
||||||
|
|
||||||
retValErr, err := ts.NoGrad(ts.Copy_(v, t))
|
retValErr, err := ts.NoGrad(func() {
|
||||||
|
ts.Copy_(v, t)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
if retValErr != nil {
|
if retValErr != nil {
|
||||||
return retValErr.(error)
|
log.Fatal(retValErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return v
|
return v
|
||||||
|
@ -555,14 +563,13 @@ func (p *Path) VarCopy(name string, t ts.Tensor) (retVal ts.Tensor) {
|
||||||
|
|
||||||
// Get gets the tensor corresponding to a given name if present.
|
// Get gets the tensor corresponding to a given name if present.
|
||||||
func (p *Path) Get(name string) (retVal ts.Tensor, err error) {
|
func (p *Path) Get(name string) (retVal ts.Tensor, err error) {
|
||||||
path := p.path(name)
|
|
||||||
|
|
||||||
p.varstore.variables.mutex.Lock()
|
p.varstore.variables.mutex.Lock()
|
||||||
defer p.varstore.variables.mutex.Unlock()
|
defer p.varstore.variables.mutex.Unlock()
|
||||||
|
|
||||||
v, ok := p.varstore.variables.NamedVariables[path]
|
v, ok := p.varstore.variables.NamedVariables[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
err = fmt.Errorf("Path - Get method call error: Cannot find variable for name: %v\n", path)
|
err = fmt.Errorf("Path - Get method call error: Cannot find variable for name: %v\n", name)
|
||||||
return retVal, err
|
return retVal, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -577,7 +584,7 @@ func (p *Path) Entry(name string) (retVal Entry) {
|
||||||
return Entry{
|
return Entry{
|
||||||
name: name,
|
name: name,
|
||||||
variables: p.varstore.variables,
|
variables: p.varstore.variables,
|
||||||
path: &p,
|
path: *p,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -608,12 +615,14 @@ func (e *Entry) OrVarCopy(tensor ts.Tensor) (retVal ts.Tensor) {
|
||||||
}
|
}
|
||||||
v := e.OrZeros(size)
|
v := e.OrZeros(size)
|
||||||
|
|
||||||
retValErr, err := ts.NoGrad(ts.Copy_(v, tensor))
|
retValErr, err := ts.NoGrad(func() {
|
||||||
|
ts.Copy_(v, tensor)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
if retValErr != nil {
|
if retValErr != nil {
|
||||||
return retValErr.(error)
|
log.Fatal(retValErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return v
|
return v
|
||||||
|
|
|
@ -248,7 +248,7 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
||||||
// `spec` is a function type implements `TensorIndexer`
|
// `spec` is a function type implements `TensorIndexer`
|
||||||
for _, spec := range indexSpec {
|
for _, spec := range indexSpec {
|
||||||
|
|
||||||
fmt.Printf("spec type: %v\n", reflect.TypeOf(spec).Name())
|
// fmt.Printf("spec type: %v\n", reflect.TypeOf(spec).Name())
|
||||||
|
|
||||||
switch reflect.TypeOf(spec).Name() {
|
switch reflect.TypeOf(spec).Name() {
|
||||||
case "InsertNewAxis":
|
case "InsertNewAxis":
|
||||||
|
@ -291,8 +291,6 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
||||||
nextIdx = currIdx + 1
|
nextIdx = currIdx + 1
|
||||||
} // end of switch
|
} // end of switch
|
||||||
|
|
||||||
currTensor.Print()
|
|
||||||
|
|
||||||
currTensor = nextTensor
|
currTensor = nextTensor
|
||||||
currIdx = nextIdx
|
currIdx = nextIdx
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,11 +8,11 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
gt "github.com/sugarme/gotch"
|
"github.com/sugarme/gotch"
|
||||||
lib "github.com/sugarme/gotch/libtch"
|
lib "github.com/sugarme/gotch/libtch"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (ts Tensor) To(device gt.Device) (retVal Tensor, err error) {
|
func (ts Tensor) To(device gotch.Device) (retVal Tensor, err error) {
|
||||||
|
|
||||||
// TODO: how to get pointer to CUDA memory???
|
// TODO: how to get pointer to CUDA memory???
|
||||||
// C.cuMemAlloc((*C.ulonglong)(cudaPtr), 1) // 0 byte is invalid
|
// C.cuMemAlloc((*C.ulonglong)(cudaPtr), 1) // 0 byte is invalid
|
||||||
|
@ -28,7 +28,7 @@ func (ts Tensor) To(device gt.Device) (retVal Tensor, err error) {
|
||||||
return Tensor{ctensor: *ptr}, nil
|
return Tensor{ctensor: *ptr}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts Tensor) MustTo(device gt.Device) (retVal Tensor) {
|
func (ts Tensor) MustTo(device gotch.Device) (retVal Tensor) {
|
||||||
var err error
|
var err error
|
||||||
retVal, err = ts.To(device)
|
retVal, err = ts.To(device)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -271,10 +271,10 @@ func (ts Tensor) MustAddG(other Tensor) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Totype casts type of tensor to a new tensor with specified DType
|
// Totype casts type of tensor to a new tensor with specified DType
|
||||||
func (ts Tensor) Totype(dtype gt.DType) (retVal Tensor, err error) {
|
func (ts Tensor) Totype(dtype gotch.DType) (retVal Tensor, err error) {
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
defer C.free(unsafe.Pointer(ptr))
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
cint, err := gt.DType2CInt(dtype)
|
cint, err := gotch.DType2CInt(dtype)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return retVal, err
|
return retVal, err
|
||||||
}
|
}
|
||||||
|
@ -291,7 +291,7 @@ func (ts Tensor) Totype(dtype gt.DType) (retVal Tensor, err error) {
|
||||||
|
|
||||||
// Totype casts type of tensor to a new tensor with specified DType. It will
|
// Totype casts type of tensor to a new tensor with specified DType. It will
|
||||||
// panic if error
|
// panic if error
|
||||||
func (ts Tensor) MustTotype(dtype gt.DType) (retVal Tensor) {
|
func (ts Tensor) MustTotype(dtype gotch.DType) (retVal Tensor) {
|
||||||
retVal, err := ts.Totype(dtype)
|
retVal, err := ts.Totype(dtype)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
|
@ -361,6 +361,14 @@ func (ts Tensor) IndexSelect(dim int64, index Tensor) (retVal Tensor, err error)
|
||||||
|
|
||||||
return retVal, nil
|
return retVal, nil
|
||||||
}
|
}
|
||||||
|
func (ts Tensor) MustIndexSelect(dim int64, index Tensor) (retVal Tensor) {
|
||||||
|
retVal, err := ts.IndexSelect(dim, index)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
func Zeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
|
func Zeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
@ -697,3 +705,26 @@ func (ts Tensor) MustDiv1(other Scalar) (retVal Tensor) {
|
||||||
|
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Randperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (retVal Tensor, err error) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
|
lib.AtgRandperm(ptr, n, optionKind.CInt(), optionDevice.CInt())
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor: *ptr}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func MustRandperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (retVal Tensor) {
|
||||||
|
retVal, err := Randperm(n, optionKind, optionDevice)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package vision
|
||||||
// A simple dataset structure shared by various computer vision datasets.
|
// A simple dataset structure shared by various computer vision datasets.
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/sugarme/gotch/nn"
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,4 +15,16 @@ type Dataset struct {
|
||||||
Labels int64
|
Labels int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: implement methods
|
// Dataset Methods:
|
||||||
|
//=================
|
||||||
|
|
||||||
|
// TrainIter creates an iterator of Iter type for train images and labels
|
||||||
|
func (ds Dataset) TrainIter(batchSize int64) (retVal nn.Iter2) {
|
||||||
|
return nn.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchSize)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIter creates an iterator of Iter type for test images and labels
|
||||||
|
func (ds Dataset) TestIter(batchSize int64) (retVal nn.Iter2) {
|
||||||
|
return nn.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user