feat(nn/data): data interator Iter2
This commit is contained in:
parent
da950fe881
commit
830f9ad9df
|
@ -13,17 +13,18 @@ const (
|
|||
Label int64 = 10
|
||||
MnistDir string = "../../data/mnist"
|
||||
|
||||
epochs = 200
|
||||
epochs = 100
|
||||
batchSize = 256
|
||||
)
|
||||
|
||||
func runLinear() {
|
||||
var ds vision.Dataset
|
||||
ds = vision.LoadMNISTDir(MnistDir)
|
||||
|
||||
fmt.Printf("Train image size: %v\n", ds.TrainImages.MustSize())
|
||||
fmt.Printf("Train label size: %v\n", ds.TrainLabels.MustSize())
|
||||
fmt.Printf("Test image size: %v\n", ds.TestImages.MustSize())
|
||||
fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
|
||||
// fmt.Printf("Train image size: %v\n", ds.TrainImages.MustSize())
|
||||
// fmt.Printf("Train label size: %v\n", ds.TrainLabels.MustSize())
|
||||
// fmt.Printf("Test image size: %v\n", ds.TestImages.MustSize())
|
||||
// fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
|
||||
|
||||
device := (gotch.CPU).CInt()
|
||||
dtype := (gotch.Float).CInt()
|
||||
|
@ -32,22 +33,51 @@ func runLinear() {
|
|||
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
|
||||
|
||||
for epoch := 0; epoch < epochs; epoch++ {
|
||||
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
|
||||
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
|
||||
|
||||
ws.ZeroGrad()
|
||||
bs.ZeroGrad()
|
||||
loss.Backward()
|
||||
var loss ts.Tensor
|
||||
trainIter := ds.TrainIter(batchSize)
|
||||
trainIter.Shuffle().ToDevice(gotch.CPU)
|
||||
// item a pair of images and labels as 2 tensors
|
||||
for {
|
||||
batch, ok := trainIter.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
ts.NoGrad(func() {
|
||||
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
})
|
||||
logits := batch.Images.MustMm(ws).MustAdd(bs)
|
||||
loss = logits.MustLogSoftmax(-1, dtype).MustNllLoss(batch.Labels)
|
||||
ws.ZeroGrad()
|
||||
bs.ZeroGrad()
|
||||
loss.Backward()
|
||||
|
||||
ts.NoGrad(func() {
|
||||
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
* logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
|
||||
* loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
|
||||
*
|
||||
* ws.ZeroGrad()
|
||||
* bs.ZeroGrad()
|
||||
* loss.Backward()
|
||||
*
|
||||
* ts.NoGrad(func() {
|
||||
* ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
* bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
* })
|
||||
* loss.Print()
|
||||
* */
|
||||
|
||||
// bs.MustGrad().Print()
|
||||
|
||||
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
|
||||
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
|
||||
fmt.Printf("Epoch: %v - Train loss: %v - Test accuracy: %v\n", epoch, loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}), testAccuracy*100)
|
||||
fmt.Printf("Epoch: %v - Test accuracy: %v\n", epoch, testAccuracy*100)
|
||||
|
||||
// fmt.Printf("Epoch: %v - Train loss: %v - Test accuracy: %v\n", epoch, loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}), testAccuracy*100)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -233,3 +233,12 @@ func AtgView(ptr *Ctensor, self Ctensor, sizeData []int64, sizeLen int) {
|
|||
func AtgDiv1(ptr *Ctensor, self Ctensor, other Cscalar) {
|
||||
C.atg_div1(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_randperm(tensor *, int64_t n, int options_kind, int options_device);
|
||||
func AtgRandperm(ptr *Ctensor, n int64, optionKind int32, optionDevice int32) {
|
||||
cn := *(*C.int64_t)(unsafe.Pointer(&n))
|
||||
coptionKind := *(*C.int)(unsafe.Pointer(&optionKind))
|
||||
coptionDevice := *(*C.int)(unsafe.Pointer(&optionDevice))
|
||||
|
||||
C.atg_randperm(ptr, cn, coptionKind, coptionDevice)
|
||||
}
|
||||
|
|
138
nn/data.go
Normal file
138
nn/data.go
Normal file
|
@ -0,0 +1,138 @@
|
|||
package nn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// Iter2 is an iterator over a pair of tensors which have the same first dimension
|
||||
// size.
|
||||
// The typical use case is to iterate over batches. Each batch is a pair
|
||||
// containing a (potentially random) slice of each of the two input
|
||||
// tensors.
|
||||
type Iter2 struct {
|
||||
xs ts.Tensor
|
||||
ys ts.Tensor
|
||||
batchIndex int64
|
||||
batchSize int64
|
||||
totalSize int64
|
||||
device gotch.Device
|
||||
returnSmallLastBatch bool
|
||||
}
|
||||
|
||||
// NewIter2 returns a new iterator.
|
||||
//
|
||||
// This takes as input two tensors which first dimension must match. The
|
||||
// returned iterator can be used to range over mini-batches of data of
|
||||
// specified size.
|
||||
// An error is returned if `xs` and `ys` have different first dimension
|
||||
// sizes.
|
||||
//
|
||||
// # Arguments
|
||||
//
|
||||
// * `xs` - the features to be used by the model.
|
||||
// * `ys` - the targets that the model attempts to predict.
|
||||
// * `batch_size` - the size of batches to be returned.
|
||||
func NewIter2(xs, ys ts.Tensor, batchSize int64) (retVal Iter2, err error) {
|
||||
|
||||
totalSize := xs.MustSize()[0]
|
||||
if ys.MustSize()[0] != totalSize {
|
||||
err = fmt.Errorf("Different dimension for the two inputs: %v - %v", xs.MustSize(), ys.MustSize())
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Iter2{
|
||||
xs: xs.MustShallowClone(),
|
||||
ys: ys.MustShallowClone(),
|
||||
batchIndex: 0,
|
||||
batchSize: batchSize,
|
||||
totalSize: totalSize,
|
||||
returnSmallLastBatch: false,
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// MustNewIter2 returns a new iterator.
|
||||
//
|
||||
// This takes as input two tensors which first dimension must match. The
|
||||
// returned iterator can be used to range over mini-batches of data of
|
||||
// specified size.
|
||||
// Panics if `xs` and `ys` have different first dimension sizes.
|
||||
//
|
||||
// # Arguments
|
||||
//
|
||||
// * `xs` - the features to be used by the model.
|
||||
// * `ys` - the targets that the model attempts to predict.
|
||||
// * `batch_size` - the size of batches to be returned.
|
||||
func MustNewIter2(xs, ys ts.Tensor, batchSize int64) (retVal Iter2) {
|
||||
retVal, err := NewIter2(xs, ys, batchSize)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Shuffle shuffles the dataset.
|
||||
//
|
||||
// The iterator would still run over the whole dataset but the order in
|
||||
// which elements are grouped in mini-batches is randomized.
|
||||
func (it Iter2) Shuffle() (retVal Iter2) {
|
||||
index := ts.MustRandperm(it.totalSize, gotch.Int64, gotch.CPU)
|
||||
|
||||
it.xs = it.xs.MustIndexSelect(0, index)
|
||||
it.ys = it.ys.MustIndexSelect(0, index)
|
||||
return it
|
||||
}
|
||||
|
||||
// ToDevice transfers the mini-batches to a specified device.
|
||||
func (it Iter2) ToDevice(device gotch.Device) (retVal Iter2) {
|
||||
it.device = device
|
||||
return it
|
||||
}
|
||||
|
||||
// ReturnSmallLastBatch when set, returns the last batch even if smaller than the batch size.
|
||||
func (it Iter2) ReturnSmallLastBatch() (retVal Iter2) {
|
||||
it.returnSmallLastBatch = true
|
||||
return it
|
||||
}
|
||||
|
||||
type Iter2Item struct {
|
||||
Images ts.Tensor
|
||||
Labels ts.Tensor
|
||||
}
|
||||
|
||||
// Next implements iterator for Iter2
|
||||
func (it *Iter2) Next() (item Iter2Item, ok bool) {
|
||||
start := it.batchIndex * it.batchSize
|
||||
size := it.batchSize
|
||||
if it.totalSize-start < it.batchSize {
|
||||
size = it.totalSize - start
|
||||
}
|
||||
|
||||
if (size <= 0) || (!it.returnSmallLastBatch && size < it.batchSize) {
|
||||
// err = fmt.Errorf("Last small batch error")
|
||||
return item, false
|
||||
} else {
|
||||
it.batchIndex += 1
|
||||
|
||||
// Indexing
|
||||
narrowIndex := ts.NewNarrow(start, start+size)
|
||||
|
||||
// ts1 := it.xs.Idx(narrowIndex).MustTo(it.device)
|
||||
// ts2 := it.ys.Idx(narrowIndex).MustTo(it.device)
|
||||
|
||||
ts1 := it.xs.Idx(narrowIndex)
|
||||
ts2 := it.ys.Idx(narrowIndex)
|
||||
|
||||
return Iter2Item{
|
||||
Images: ts1,
|
||||
Labels: ts2,
|
||||
}, true
|
||||
}
|
||||
}
|
25
nn/init.go
25
nn/init.go
|
@ -30,12 +30,12 @@ func NewConstInit(v float64) constInit {
|
|||
|
||||
func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
||||
var err error
|
||||
kind := gotch.DType2CInt(gotch.Float)
|
||||
kind := gotch.Float.CInt()
|
||||
switch {
|
||||
case c.value == 0.0:
|
||||
retVal = ts.Zeros(dims, kind, device.CInt())
|
||||
retVal = ts.MustZeros(dims, kind, device.CInt())
|
||||
case c.value == 1.0:
|
||||
retVal = ts.Ones(dims, kind, device.CInt())
|
||||
retVal = ts.MustOnes(dims, kind, device.CInt())
|
||||
default:
|
||||
data := make([]float64, ts.FlattenDim(dims))
|
||||
for i := range data {
|
||||
|
@ -57,7 +57,7 @@ func (c constInit) Set(tensor ts.Tensor) {
|
|||
log.Fatalf("constInit - Set method call error: %v\n", err)
|
||||
}
|
||||
|
||||
ts.Fill_(scalarVal)
|
||||
tensor.Fill_(scalarVal)
|
||||
}
|
||||
|
||||
// randnInit :
|
||||
|
@ -125,9 +125,9 @@ func NewUniformInit(lo, up float64) uniformInit {
|
|||
|
||||
func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
||||
var err error
|
||||
kind := gotch.DType2CInt(gotch.Float)
|
||||
tmpTs := ts.Zeros(dims, kind, device.CInt())
|
||||
retVal, err = tmpTs.Uniform_(u.lo, u.up)
|
||||
kind := gotch.Float.CInt()
|
||||
retVal = ts.MustZeros(dims, kind, device.CInt())
|
||||
retVal.Uniform_(u.lo, u.up)
|
||||
if err != nil {
|
||||
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
|
||||
}
|
||||
|
@ -150,13 +150,10 @@ func NewKaimingUniformInit() kaimingUniformInit {
|
|||
func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
||||
fanIn := factorial(uint64(len(dims) - 1))
|
||||
bound := math.Sqrt(1.0 / float64(fanIn))
|
||||
var err error
|
||||
kind := gotch.DType2CInt(gotch.Float)
|
||||
tmpTs := ts.Zeros(dims, kind, device.CInt())
|
||||
retVal, err = tmpTs.Uniform_(-bound, bound)
|
||||
if err != nil {
|
||||
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
|
||||
}
|
||||
kind := gotch.Float.CInt()
|
||||
retVal = ts.MustZeros(dims, kind, device.CInt())
|
||||
retVal.Uniform_(-bound, bound)
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
|
|
|
@ -160,7 +160,9 @@ func (vs *VarStore) Load(filepath string) (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
retValErr, err := ts.NoGrad(ts.Copy_(currTs, namedTs.Tensor))
|
||||
retValErr, err := ts.NoGrad(func() {
|
||||
ts.Copy_(currTs, namedTs.Tensor)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -205,7 +207,9 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
|
|||
}
|
||||
|
||||
// It's matched. Now, copy in-place the loaded tensor value to var-store
|
||||
retValErr, err := ts.NoGrad(ts.Copy_(currTs, namedTs.Tensor))
|
||||
retValErr, err := ts.NoGrad(func() {
|
||||
ts.Copy_(currTs, namedTs.Tensor)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -274,7 +278,9 @@ func (vs *VarStore) Copy(src VarStore) (err error) {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
retValErr, err := ts.NoGrad(ts.Copy_(v, srcDevTs))
|
||||
retValErr, err := ts.NoGrad(func() {
|
||||
ts.Copy_(v, srcDevTs)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -520,7 +526,7 @@ func (p *Path) Uniform(name string, dims []int64, lo, up float64) (retVal ts.Ten
|
|||
// will be tracked.
|
||||
// The variable uses a float tensor initialized randomly using a
|
||||
// uniform distribution which bounds follow Kaiming initialization.
|
||||
func (p *Path) Uniform(name string, dims []int64) (retVal ts.Tensor) {
|
||||
func (p *Path) KaimingUniform(name string, dims []int64) (retVal ts.Tensor) {
|
||||
// TODO: implement it
|
||||
// self.var(name, dims, Init::KaimingUniform)
|
||||
|
||||
|
@ -542,12 +548,14 @@ func (p *Path) VarCopy(name string, t ts.Tensor) (retVal ts.Tensor) {
|
|||
}
|
||||
v := p.Zeros(name, size)
|
||||
|
||||
retValErr, err := ts.NoGrad(ts.Copy_(v, t))
|
||||
retValErr, err := ts.NoGrad(func() {
|
||||
ts.Copy_(v, t)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
log.Fatal(err)
|
||||
}
|
||||
if retValErr != nil {
|
||||
return retValErr.(error)
|
||||
log.Fatal(retValErr)
|
||||
}
|
||||
|
||||
return v
|
||||
|
@ -555,14 +563,13 @@ func (p *Path) VarCopy(name string, t ts.Tensor) (retVal ts.Tensor) {
|
|||
|
||||
// Get gets the tensor corresponding to a given name if present.
|
||||
func (p *Path) Get(name string) (retVal ts.Tensor, err error) {
|
||||
path := p.path(name)
|
||||
|
||||
p.varstore.variables.mutex.Lock()
|
||||
defer p.varstore.variables.mutex.Unlock()
|
||||
|
||||
v, ok := p.varstore.variables.NamedVariables[path]
|
||||
v, ok := p.varstore.variables.NamedVariables[name]
|
||||
if !ok {
|
||||
err = fmt.Errorf("Path - Get method call error: Cannot find variable for name: %v\n", path)
|
||||
err = fmt.Errorf("Path - Get method call error: Cannot find variable for name: %v\n", name)
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
|
@ -577,7 +584,7 @@ func (p *Path) Entry(name string) (retVal Entry) {
|
|||
return Entry{
|
||||
name: name,
|
||||
variables: p.varstore.variables,
|
||||
path: &p,
|
||||
path: *p,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -608,12 +615,14 @@ func (e *Entry) OrVarCopy(tensor ts.Tensor) (retVal ts.Tensor) {
|
|||
}
|
||||
v := e.OrZeros(size)
|
||||
|
||||
retValErr, err := ts.NoGrad(ts.Copy_(v, tensor))
|
||||
retValErr, err := ts.NoGrad(func() {
|
||||
ts.Copy_(v, tensor)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
log.Fatal(err)
|
||||
}
|
||||
if retValErr != nil {
|
||||
return retValErr.(error)
|
||||
log.Fatal(retValErr)
|
||||
}
|
||||
|
||||
return v
|
||||
|
|
|
@ -248,7 +248,7 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
|||
// `spec` is a function type implements `TensorIndexer`
|
||||
for _, spec := range indexSpec {
|
||||
|
||||
fmt.Printf("spec type: %v\n", reflect.TypeOf(spec).Name())
|
||||
// fmt.Printf("spec type: %v\n", reflect.TypeOf(spec).Name())
|
||||
|
||||
switch reflect.TypeOf(spec).Name() {
|
||||
case "InsertNewAxis":
|
||||
|
@ -291,8 +291,6 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
|||
nextIdx = currIdx + 1
|
||||
} // end of switch
|
||||
|
||||
currTensor.Print()
|
||||
|
||||
currTensor = nextTensor
|
||||
currIdx = nextIdx
|
||||
}
|
||||
|
|
|
@ -8,11 +8,11 @@ import (
|
|||
"log"
|
||||
"unsafe"
|
||||
|
||||
gt "github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch"
|
||||
lib "github.com/sugarme/gotch/libtch"
|
||||
)
|
||||
|
||||
func (ts Tensor) To(device gt.Device) (retVal Tensor, err error) {
|
||||
func (ts Tensor) To(device gotch.Device) (retVal Tensor, err error) {
|
||||
|
||||
// TODO: how to get pointer to CUDA memory???
|
||||
// C.cuMemAlloc((*C.ulonglong)(cudaPtr), 1) // 0 byte is invalid
|
||||
|
@ -28,7 +28,7 @@ func (ts Tensor) To(device gt.Device) (retVal Tensor, err error) {
|
|||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustTo(device gt.Device) (retVal Tensor) {
|
||||
func (ts Tensor) MustTo(device gotch.Device) (retVal Tensor) {
|
||||
var err error
|
||||
retVal, err = ts.To(device)
|
||||
if err != nil {
|
||||
|
@ -271,10 +271,10 @@ func (ts Tensor) MustAddG(other Tensor) {
|
|||
}
|
||||
|
||||
// Totype casts type of tensor to a new tensor with specified DType
|
||||
func (ts Tensor) Totype(dtype gt.DType) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Totype(dtype gotch.DType) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
cint, err := gt.DType2CInt(dtype)
|
||||
cint, err := gotch.DType2CInt(dtype)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
@ -291,7 +291,7 @@ func (ts Tensor) Totype(dtype gt.DType) (retVal Tensor, err error) {
|
|||
|
||||
// Totype casts type of tensor to a new tensor with specified DType. It will
|
||||
// panic if error
|
||||
func (ts Tensor) MustTotype(dtype gt.DType) (retVal Tensor) {
|
||||
func (ts Tensor) MustTotype(dtype gotch.DType) (retVal Tensor) {
|
||||
retVal, err := ts.Totype(dtype)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -361,6 +361,14 @@ func (ts Tensor) IndexSelect(dim int64, index Tensor) (retVal Tensor, err error)
|
|||
|
||||
return retVal, nil
|
||||
}
|
||||
func (ts Tensor) MustIndexSelect(dim int64, index Tensor) (retVal Tensor) {
|
||||
retVal, err := ts.IndexSelect(dim, index)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func Zeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
@ -697,3 +705,26 @@ func (ts Tensor) MustDiv1(other Scalar) (retVal Tensor) {
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func Randperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgRandperm(ptr, n, optionKind.CInt(), optionDevice.CInt())
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustRandperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (retVal Tensor) {
|
||||
retVal, err := Randperm(n, optionKind, optionDevice)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package vision
|
|||
// A simple dataset structure shared by various computer vision datasets.
|
||||
|
||||
import (
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
|
@ -14,4 +15,16 @@ type Dataset struct {
|
|||
Labels int64
|
||||
}
|
||||
|
||||
// TODO: implement methods
|
||||
// Dataset Methods:
|
||||
//=================
|
||||
|
||||
// TrainIter creates an iterator of Iter type for train images and labels
|
||||
func (ds Dataset) TrainIter(batchSize int64) (retVal nn.Iter2) {
|
||||
return nn.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchSize)
|
||||
|
||||
}
|
||||
|
||||
// TestIter creates an iterator of Iter type for test images and labels
|
||||
func (ds Dataset) TestIter(batchSize int64) (retVal nn.Iter2) {
|
||||
return nn.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user