feat(dutil): added DataSet and DataLoader

This commit is contained in:
sugarme 2020-12-31 12:48:12 +11:00
parent c8ccc03a19
commit 595d40bba6
10 changed files with 1014 additions and 50 deletions

View File

@ -4,55 +4,24 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Fixed
- [#...]: Fix a bug with...
### Changed
- [#...]:
## [0.3.6]
### Added
- [#...]:
- Added `dutil` sub-package that serves Pytorch `DataSet` and `DataLoader` concepts
## [0.1.8]
### Changed
- [#10]: `ts.Drop()` and `ts.MustDrop()` now can call multiple times without panic
## [0.1.9]
### Changed
- Reverse changes [#10] to original.
## [0.1.10]
## [0.3.5]
### Added
- Added `tensor.SaveMultiNew`
## [0.2.0]
- Added function `gotch.CudaIfAvailable()`. NOTE that: `device := gotch.NewCuda().CudaIfAvailable()` will throw error if CUDA is not available.
### Changed
- Convert all APIs to using **Pointer Receiver**
- Switched back to install libtorch inside gotch library as go init() function is triggered after cgo called.
## [0.3.4]
### Added
- Added drawing image label at `example/yolo` example
- Added some example images and README files for `example/yolo` and `example/neural-style-transfer`
## [0.3.0]
### Changed
- Updated to Pytorch C++ APIs v1.7.0
- Switched back to `lib.AtoAddParametersOld` as the `ato_add_parameters` has not been implemented correctly. Using the updated API will cause optimizer stops working.
## [0.3.1]
### Changed
- Changed to use `map[string]*Tensor` at `nn/varstore.go`
- Changed to use `*Path` argument of `NewLayerNorm` method at `nn/layer-norm.go`
- Lots of clean-up return variables i.e. retVal, err
- [#4] Automatically download and install Libtorch and setup environment variables.
## [0.3.2]
@ -65,18 +34,42 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- nn/sequential: fixed missing case number of layers = 1 causing panic
- nn/varstore: fixed(nn/varstore): fixed nil pointer at LoadPartial due to not break loop
## [0.3.4]
### Added
- [#4] Automatically download and install Libtorch and setup environment variables.
## [0.3.5]
### Added
- Added function `gotch.CudaIfAvailable()`. NOTE that: `device := gotch.NewCuda().CudaIfAvailable()` will throw error if CUDA is not available.
## [0.3.1]
### Changed
- Switched back to install libtorch inside gotch library as go init() function is triggered after cgo called.
- Changed to use `map[string]*Tensor` at `nn/varstore.go`
- Changed to use `*Path` argument of `NewLayerNorm` method at `nn/layer-norm.go`
- Lots of clean-up return variables i.e. retVal, err
## [0.3.0]
### Changed
- Updated to Pytorch C++ APIs v1.7.0
- Switched back to `lib.AtoAddParametersOld` as the `ato_add_parameters` has not been implemented correctly. Using the updated API will cause optimizer stops working.
## [0.2.0]
### Changed
- Convert all APIs to using **Pointer Receiver**
### Added
- Added drawing image label at `example/yolo` example
- Added some example images and README files for `example/yolo` and `example/neural-style-transfer`
## [0.1.10]
### Added
- Added `tensor.SaveMultiNew`
## [0.1.9]
### Changed
- Reverse changes [#10] to original.
## [0.1.8]
### Changed
- [#10]: `ts.Drop()` and `ts.MustDrop()` now can call multiple times without panic
[#10]: https://github.com/sugarme/gotch/issues/10

117
dutil/dataloader.go Normal file
View File

@ -0,0 +1,117 @@
package dutil
import (
"fmt"
"reflect"
)
// DataLoader combines a dataset and a sampler and provides
// an iterable over the given dataset.
type DataLoader struct {
    dataset   Dataset
    indexes   []int // order of samples in dataset for iteration.
    batchSize int   // number of samples returned per Next() call
    currIdx   int   // cursor into indexes; advanced by Next()
}
// NewDataLoader creates a DataLoader over the given dataset.
// When s is nil a default sampler is chosen from the dataset kind:
// sequential for slice datasets, random for map datasets.
func NewDataLoader(data Dataset, s Sampler) (*DataLoader, error) {
    kind, err := checkDKind(data)
    if err != nil {
        return nil, err
    }

    // Fall back to a default sampler when none was supplied.
    if s == nil {
        switch kind {
        case SliceDKind:
            s = NewSequentialSampler(data.Len())
        case MapDKind:
            rs, rerr := NewRandomSampler(data.Len())
            if rerr != nil {
                return nil, rerr
            }
            s = rs
        }
    }

    dl := &DataLoader{
        dataset:   data,
        indexes:   s.Sample(),
        batchSize: s.BatchSize(),
        currIdx:   0,
    }
    return dl, nil
}
// checkDKind classifies the dataset's underlying collection type.
// Supported kinds are slices ([]T) and maps with string keys
// (map[string]T); anything else yields InvalidDKind and an error.
func checkDKind(data Dataset) (DatasetKind, error) {
    // Single error value instead of the previously duplicated message.
    errInvalid := fmt.Errorf("Invalid Dataset. Dataset should be a collection data of type '[]interface{}' or 'map[string]interface{}'")

    dtyp := data.DType()
    // Compare reflect.Kind values directly rather than their string names.
    switch dtyp.Kind() {
    case reflect.Slice:
        return SliceDKind, nil
    case reflect.Map:
        // Only string-keyed maps work with the MapDataset key list.
        if dtyp.Key().Kind() != reflect.String {
            return InvalidDKind, errInvalid
        }
        return MapDKind, nil
    default: // other types are invalid
        return InvalidDKind, errInvalid
    }
}
// Next acts as iterator to return next sample(s) from dataset.
//
// Samples are visited in the order given by dl.indexes (the sampler's
// output). In batch mode a typed slice of up to batchSize items is
// returned; the final batch may be shorter when indexes run out.
func (dl *DataLoader) Next() (interface{}, error) {
    if !dl.HasNext() {
        err := fmt.Errorf("Next Error: no more item to iterate.\n")
        return nil, err
    }

    // Non-batching: return a single item.
    if dl.batchSize == 1 {
        // FIX: look the sample up through the sampler's index order.
        // Previously dl.currIdx was passed to Item() directly, which
        // silently ignored any shuffling/sampling the sampler produced.
        item, err := dl.dataset.Item(dl.indexes[dl.currIdx])
        if err != nil {
            return nil, err
        }
        dl.currIdx++
        return item, nil
    }

    // Batch sampling: probe one element to learn the element type so a
    // typed slice can be built via reflection.
    elem, err := dl.dataset.Item(0)
    if err != nil {
        return nil, err
    }
    elemType := reflect.TypeOf(elem)

    nextIndex := dl.currIdx + dl.batchSize
    // NOTE. length of indexes can be shorter than dataset length
    // (e.g. BatchSampler with dropLast), so clamp to the index list.
    if nextIndex >= len(dl.indexes) {
        nextIndex = len(dl.indexes)
    }

    // Capacity is the batch size actually produced, not the whole dataset.
    items := reflect.MakeSlice(reflect.SliceOf(elemType), 0, nextIndex-dl.currIdx)
    for i := dl.currIdx; i < nextIndex; i++ {
        // FIX: i is a position in dl.indexes, not a dataset index.
        item, err := dl.dataset.Item(dl.indexes[i])
        if err != nil {
            return nil, err
        }
        items = reflect.Append(items, reflect.ValueOf(item))
    }
    dl.currIdx = nextIndex

    return items.Interface(), nil
}
// HasNext returns whether there is a next item in the iteration.
func (dl *DataLoader) HasNext() bool {
    return dl.currIdx < len(dl.indexes)
}

// Reset resets the iteration cursor to the start position so the
// loader can be iterated again.
func (dl *DataLoader) Reset() {
    dl.currIdx = 0
}

44
dutil/dataloader_test.go Normal file
View File

@ -0,0 +1,44 @@
package dutil_test
import (
// "reflect"
"reflect"
"testing"
"github.com/sugarme/gotch/dutil"
)
// TestNewDataLoader checks that a loader can be built from a slice
// dataset with the default (nil) sampler.
func TestNewDataLoader(t *testing.T) {
    ds, err := dutil.NewSliceDataset([]int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
    if err != nil {
        t.Error(err)
    }

    if _, err = dutil.NewDataLoader(ds, nil); err != nil {
        t.Errorf("Unexpected error. Got: %v\n", err)
    }
}
// TestDataLoader_Next checks that the first item drawn from a
// sequentially-sampled slice dataset is the first element.
func TestDataLoader_Next(t *testing.T) {
    ds, err := dutil.NewSliceDataset([]int{100, 1, 2, 3, 4, 5, 6, 7, 8, 9})
    if err != nil {
        t.Error(err)
    }

    dl, err := dutil.NewDataLoader(ds, nil)
    if err != nil {
        t.Errorf("Unexpected error. Got: %v\n", err)
    }

    got, err := dl.Next()
    if err != nil {
        t.Error(err)
    }

    want := 100
    if !reflect.DeepEqual(want, got) {
        t.Errorf("Want: %v\n", want)
        t.Errorf("Got: %v\n", got)
    }
}

118
dutil/dataset.go Normal file
View File

@ -0,0 +1,118 @@
package dutil
import (
"fmt"
"reflect"
)
// Dataset represents a set of samples and
// how to access a sample by its index by implementing
// `Item()` method.
type Dataset interface {
    // Item returns the sample stored at the given index.
    Item(idx int) (interface{}, error)
    // DType returns the reflect.Type of the underlying collection.
    DType() reflect.Type
    // Len returns the number of samples.
    Len() int
}

// DatasetKind enumerates the supported dataset container kinds.
type DatasetKind int

const (
    SliceDKind   DatasetKind = iota // slice-backed dataset ([]T)
    MapDKind                        // map-backed dataset (map[string]T)
    InvalidDKind                    // unsupported container type
)
// SliceDataset is a slice of samples.
type SliceDataset struct {
    data interface{} // underlying slice, accessed via reflection
}

// NewSliceDataset creates a new SliceDataset.
//
// data must be a non-nil slice ([]T of any element type); any other
// input yields an error.
func NewSliceDataset(data interface{}) (*SliceDataset, error) {
    // Guard against nil: reflect.TypeOf(nil) returns a nil Type and
    // calling Kind() on it would panic.
    if data == nil {
        return nil, fmt.Errorf("Invalid Type: expected data of slice type. Got '%v'.\n", "nil")
    }

    // Compare reflect.Kind values directly rather than their string names.
    if kind := reflect.TypeOf(data).Kind(); kind != reflect.Slice {
        return nil, fmt.Errorf("Invalid Type: expected data of slice type. Got '%v'.\n", kind)
    }

    return &SliceDataset{data: data}, nil
}
// Item implements Dataset interface to get a sample by its index.
func (ds *SliceDataset) Item(idx int) (interface{}, error) {
    // Reflect the backing slice once and reuse it.
    v := reflect.ValueOf(ds.data)
    if idx < 0 || idx >= v.Len() {
        err := fmt.Errorf("Idx is out of range.")
        return nil, err
    }
    return v.Index(idx).Interface(), nil
}

// Len returns the number of samples in the backing slice.
func (ds *SliceDataset) Len() int {
    return reflect.ValueOf(ds.data).Len()
}

// DType returns the reflect.Type of the backing slice.
func (ds *SliceDataset) DType() reflect.Type {
    return reflect.TypeOf(ds.data)
}
// MapDataset holds samples in a map.
type MapDataset struct {
    data interface{} // underlying map with string keys, accessed via reflection
    keys []string    // keys to access elements in map; captured in (randomized) map-iteration order
}

// NewMapDataset creates a new MapDataset.
//
// data must be a non-nil map with string keys (map[string]T).
// NOTE: the key list is captured in Go's map-iteration order, which is
// randomized, so Item(i) is not deterministic across constructions.
func NewMapDataset(data interface{}) (*MapDataset, error) {
    // Guard against nil: reflect.TypeOf(nil) returns a nil Type and
    // calling Kind() on it would panic.
    if data == nil {
        return nil, fmt.Errorf("Expected data of map type. Got: %v\n", "nil")
    }

    // validate map type (compare reflect.Kind values, not their names)
    dtyp := reflect.TypeOf(data)
    if dtyp.Kind() != reflect.Map {
        return nil, fmt.Errorf("Expected data of map type. Got: %v\n", dtyp.Kind())
    }

    // validate key string type
    if keyKind := dtyp.Key().Kind(); keyKind != reflect.String {
        return nil, fmt.Errorf("Expected data with map key of string type. Got '%v'\n", keyKind)
    }

    // Collect the keys once so map entries become index-addressable.
    v := reflect.ValueOf(data)
    keys := make([]string, 0, v.Len())
    mapIter := v.MapRange()
    for mapIter.Next() {
        keys = append(keys, mapIter.Key().Interface().(string))
    }

    return &MapDataset{data: data, keys: keys}, nil
}
// Item implements Dataset interface.
// The key list fixes the index order for the underlying map.
func (ds *MapDataset) Item(idx int) (interface{}, error) {
    if idx < 0 || idx >= len(ds.keys) {
        err := fmt.Errorf("idx is out of range.")
        return nil, err
    }

    k := reflect.ValueOf(ds.keys[idx])
    return reflect.ValueOf(ds.data).MapIndex(k).Interface(), nil
}

// Len returns the number of entries in the underlying map.
func (ds *MapDataset) Len() int {
    return reflect.ValueOf(ds.data).Len()
}

// DType returns the reflect.Type of the underlying map.
func (ds *MapDataset) DType() reflect.Type {
    return reflect.TypeOf(ds.data)
}
// NOTE. To make this package agnostic, we don't add TensorDataset here.
// An end-user can create a custom dataset by implementing the `Item()` method.

120
dutil/dataset_test.go Normal file
View File

@ -0,0 +1,120 @@
package dutil_test
import (
"reflect"
"testing"
"github.com/sugarme/gotch/dutil"
)
// TestNewSliceDataset verifies the constructor's type validation.
func TestNewSliceDataset(t *testing.T) {
    // Error case: non `slice` type
    if _, err := dutil.NewSliceDataset(1); err == nil {
        t.Errorf("Expected invalid data type error: %v.\n", err)
    }

    // Valid case
    if _, err := dutil.NewSliceDataset([]int{0, 1, 2, 3}); err != nil {
        t.Errorf("Unexpected error. Got: %v.\n", err)
    }
}
// TestSliceDataset_Len checks Len on a four-element slice.
func TestSliceDataset_Len(t *testing.T) {
    ds, err := dutil.NewSliceDataset([]int{0, 1, 2, 3})
    if err != nil {
        t.Error(err)
    }

    want, got := 4, ds.Len()
    if !reflect.DeepEqual(want, got) {
        t.Errorf("Want data length: %v\n", want)
        t.Errorf("Got data length: %v\n", got)
    }
}
// TestSliceDataset_Item checks retrieval of a sample by index.
func TestSliceDataset_Item(t *testing.T) {
    ds, err := dutil.NewSliceDataset([]int{0, 1, 2, 3})
    if err != nil {
        t.Error(err)
    }

    got, err := ds.Item(2)
    if err != nil {
        t.Error(err)
    }

    want := 2
    if !reflect.DeepEqual(want, got) {
        t.Errorf("Want item value: %v\n", want)
        t.Errorf("Got item value: %v\n", got)
    }
}
// TestNewMapDataset covers the three constructor paths: wrong container
// type, wrong key type, and a valid string-keyed map.
func TestNewMapDataset(t *testing.T) {
    // Invalid data type
    if _, err := dutil.NewMapDataset([]int{0, 1, 2, 3}); err == nil {
        t.Errorf("Expected Invalid data type. Got nil.")
    }

    // Invalid map key type
    badKeys := map[int]int{1: 1, 2: 2}
    if _, err := dutil.NewMapDataset(badKeys); err == nil {
        t.Errorf("Expected Invalid map key type. Got nil.")
    }

    // Valid data
    okData := map[string]int{"one": 1, "two": 2}
    if _, err := dutil.NewMapDataset(okData); err != nil {
        t.Errorf("Unexpected error. Got: %v\n", err)
    }
}
// TestMapDataset_Len checks Len on a two-entry map.
// (Renamed from TestMaptDataset_Len — typo in the test name.)
func TestMapDataset_Len(t *testing.T) {
    data := map[string]int{"one": 1, "two": 2}
    ds, err := dutil.NewMapDataset(data)
    if err != nil {
        t.Error(err)
    }

    want := 2
    got := ds.Len()
    if !reflect.DeepEqual(want, got) {
        t.Errorf("Want data length: %v\n", want)
        t.Errorf("Got data length: %v\n", got)
    }
}
// TestMapDataset_Item verifies that indexing covers every map value.
//
// FIX: the original asserted Item(0) == 3, but MapDataset captures keys
// in Go's randomized map-iteration order, so the value at a specific
// index is nondeterministic and the test was flaky. Compare the set of
// values returned across all indices instead.
func TestMapDataset_Item(t *testing.T) {
    data := map[string]int{"three": 3, "one": 1, "two": 2}
    ds, err := dutil.NewMapDataset(data)
    if err != nil {
        t.Error(err)
    }

    got := make(map[int]bool)
    for i := 0; i < ds.Len(); i++ {
        item, err := ds.Item(i)
        if err != nil {
            t.Error(err)
        }
        got[item.(int)] = true
    }

    want := map[int]bool{1: true, 2: true, 3: true}
    if !reflect.DeepEqual(want, got) {
        t.Errorf("Want: %v\n", want)
        t.Errorf("Got: %v\n", got)
    }
}

143
dutil/kfold.go Normal file
View File

@ -0,0 +1,143 @@
package dutil
import (
"fmt"
"math/rand"
"sort"
)
// KFold represents a struct helper to
// split data into partitions.
type KFold struct {
    n      int  // total number of samples to split
    nfolds int  // number of folds (partitions)
    shuffle bool // whether to shuffle indices before splitting
    //seed int
}

// Fold represents a partition with
// 2 fields: indexes of train set and test set.
type Fold struct {
    Train []int // sample indices of the training subset
    Test  []int // sample indices of the test subset
}
// KFoldOptions holds configurable parameters for NewKFold.
type KFoldOptions struct {
    NFolds  int  // number of folds
    Shuffle bool // whether shuffling before splitting
}

// KFoldOption mutates a KFoldOptions value (functional-option pattern).
type KFoldOption func(*KFoldOptions)

// NewKFoldOptions builds KFoldOptions from the given options, starting
// from the defaults (NFolds=5, Shuffle=false).
func NewKFoldOptions(options ...KFoldOption) KFoldOptions {
    opts := KFoldOptions{
        NFolds:  5,
        Shuffle: false,
    }

    for _, o := range options {
        o(&opts)
    }

    return opts
}

// WithNFolds sets the number of folds.
func WithNFolds(nfolds int) KFoldOption {
    return func(o *KFoldOptions) {
        o.NFolds = nfolds
    }
}

// WithKFoldShuffle sets whether indices are shuffled before splitting.
func WithKFoldShuffle(shuffle bool) KFoldOption {
    return func(o *KFoldOptions) {
        o.Shuffle = shuffle
    }
}
// NewKFold creates a new KFold struct.
// It validates that 2 <= NFolds <= n before constructing the splitter.
func NewKFold(n int, opt ...KFoldOption) (*KFold, error) {
    opts := NewKFoldOptions(opt...)

    switch {
    case opts.NFolds < 2:
        return nil, fmt.Errorf("nfolds must be at least 2. Got: %v\n", opts.NFolds)
    case opts.NFolds > n:
        return nil, fmt.Errorf("nfolds cannot be greater than number of samples (%v). Got: %v\n", n, opts.NFolds)
    }

    kf := &KFold{
        n:       n,
        nfolds:  opts.NFolds,
        shuffle: opts.Shuffle,
    }
    return kf, nil
}
// Split partitions sample indices into nfolds Folds, each pairing one
// fold as Test with the remaining folds concatenated as Train.
//
// Samples that do not divide evenly into nfolds (n % nfolds of them)
// are dropped. When shuffle is false the retained indices are sorted;
// otherwise they keep the random order from rand.Perm.
func (kf *KFold) Split() []Fold {
    // Drop the remainder so every fold has an equal size.
    odd := kf.n % kf.nfolds
    nsamples := kf.n - odd
    fsize := nsamples / kf.nfolds

    var indices []int
    allIndices := rand.Perm(kf.n)
    // Drop last odd-time elements
    indices = allIndices[:nsamples]
    if !kf.shuffle {
        // NOTE: even when sorted, indices are a random SUBSET of [0, n)
        // — which n%nfolds elements were dropped is still random.
        sort.Ints(indices)
    }

    // Split to train, test sets
    var (
        folds [][]int // chunk positions 0..nsamples-1, fsize per chunk
        fold  []int
    )
    for i := 0; i < nsamples; i++ {
        fold = append(fold, i)
        if len(fold) == fsize {
            folds = append(folds, fold)
            fold = []int{}
        }
    }

    var splits []Fold
    for i := 0; i < kf.nfolds; i++ {
        // Fold i is the test set; all remaining folds form the train set.
        test := folds[i]

        var trainFolds [][]int
        trainFolds = append(trainFolds, folds[:i]...)
        trainFolds = append(trainFolds, folds[i+1:]...)
        var train []int
        for _, f := range trainFolds {
            train = append(train, f...)
        }

        // Map chunk positions back to actual sample indices.
        split := Fold{
            Test:  values(indices, test),
            Train: values(indices, train),
        }

        splits = append(splits, split)
    }

    return splits
}
// contains reports whether item occurs in data.
func contains(data []int, item int) bool {
    found := false
    for i := 0; i < len(data) && !found; i++ {
        if data[i] == item {
            found = true
        }
    }
    return found
}
func values(data []int, keys []int) []int {
var vals []int
for _, k := range keys {
v := data[k]
vals = append(vals, v)
}
return vals
}

56
dutil/kfold_test.go Normal file
View File

@ -0,0 +1,56 @@
package dutil_test
import (
"testing"
"github.com/sugarme/gotch/dutil"
)
// TestNewKFold covers an invalid fold count (nfolds > n) and a valid one.
func TestNewKFold(t *testing.T) {
    n := 10

    // invalid kfold
    if _, err := dutil.NewKFold(n, dutil.WithNFolds(11), dutil.WithKFoldShuffle(true)); err == nil {
        t.Errorf("Expected error: invalid number of folds. Got nil.")
    }

    // valid
    if _, err := dutil.NewKFold(n, dutil.WithNFolds(3), dutil.WithKFoldShuffle(true)); err != nil {
        t.Errorf("Unexpected error. Got %v.\n", err)
    }
}
// TestKFold_Split verifies fold count and train/test sizes for n=11,
// nfolds=3 (the 2 leftover samples are dropped).
func TestKFold_Split(t *testing.T) {
    const (
        n        = 11
        nfolds   = 3
        trainLen = 6
        testLen  = 3
    )

    kf, err := dutil.NewKFold(n, dutil.WithNFolds(nfolds), dutil.WithKFoldShuffle(true))
    if err != nil {
        t.Error(err)
    }

    splits := kf.Split()
    if len(splits) != nfolds {
        t.Errorf("Want number of folds: %v\n", nfolds)
        t.Errorf("Got number of folds: %v\n", len(splits))
    }

    for _, f := range splits {
        if len(f.Train) != trainLen {
            t.Errorf("Expect train length: %v\n", trainLen)
            t.Errorf("Got train length: %v\n", len(f.Train))
        }
        if len(f.Test) != testLen {
            t.Errorf("Expect test length: %v\n", testLen)
            t.Errorf("Got test length: %v\n", len(f.Test))
        }
    }
}

229
dutil/sampler.go Normal file
View File

@ -0,0 +1,229 @@
package dutil
import (
"fmt"
"math/rand"
"time"
)
// Sampler represents an interface to draw sample
// from a dataset by implementing `Sample` method to
// generate a slice of indices of data samples.
type Sampler interface {
    // Sample generates the ordered list of dataset indices to visit.
    Sample() []int
    // BatchSize reports how many indices make up one batch.
    BatchSize() int
}

// SequentialSampler represents a method to
// draw sample by its order in the dataset.
type SequentialSampler struct {
    n         int // number of samples
    batchSize int // always = 1
}

// NewSequentialSampler creates a new SequentialSampler
//
// n : number of samples in dataset
func NewSequentialSampler(n int) *SequentialSampler {
    return &SequentialSampler{n, 1}
}
// Sample implements Sampler interface.
// It returns the indices 0..n-1 in ascending order.
func (s *SequentialSampler) Sample() []int {
    var indices []int
    for count := 0; count < s.n; count++ {
        indices = append(indices, count)
    }
    return indices
}

// BatchSize returns the batch size (always 1 for SequentialSampler).
func (s *SequentialSampler) BatchSize() int {
    return s.batchSize
}
// RandomSampler represents a method to draw
// a sample randomly from dataset.
type RandomSampler struct {
    n           int  // number of samples
    size        int  // size of sampling
    replacement bool // whether replacement or not
    batchSize   int  // always = 1
}

// RandOptions holds optional parameters for NewRandomSampler.
type RandOptions struct {
    Size        int  // sampling size (0 means "use all n samples")
    Replacement bool // whether to sample with replacement
}

// RandOption mutates a RandOptions value (functional-option pattern).
type RandOption func(*RandOptions)

// NewRandOptions builds RandOptions from the given options, starting
// from the defaults (Size=0, Replacement=false).
func NewRandOptions(options ...RandOption) RandOptions {
    opts := RandOptions{
        Size:        0,
        Replacement: false,
    }
    for _, apply := range options {
        apply(&opts)
    }
    return opts
}

// WithSize sets the sampling size.
func WithSize(size int) RandOption {
    return func(o *RandOptions) { o.Size = size }
}

// WithReplacement sets whether sampling uses replacement.
func WithReplacement(replacement bool) RandOption {
    return func(o *RandOptions) { o.Replacement = replacement }
}
// NewRandomSampler creates a new RandomSampler.
//
// n : number of samples in dataset
// size: Optional (default=n). Size of sampling.
// replacement: Optional (default=false). Whether not repeated or repeated samples.
func NewRandomSampler(n int, opt ...RandOption) (*RandomSampler, error) {
    opts := NewRandOptions(opt...)

    if opts.Size > n {
        err := fmt.Errorf("Sampling size can not be greater than number of samples.")
        return nil, err
    }
    // FIX: a negative size previously slipped through validation and
    // produced an empty or nonsensical sample; reject it explicitly.
    if opts.Size < 0 {
        err := fmt.Errorf("Sampling size can not be negative.")
        return nil, err
    }

    size := n // default sampling size = all samples
    if opts.Size != 0 {
        size = opts.Size
    }

    return &RandomSampler{
        n:           n,
        size:        size,
        replacement: opts.Replacement,
        batchSize:   1,
    }, nil
}
// Sample implements Sampler interface.
//
// NOTE(review): the replacement semantics here look inverted relative
// to the conventional meaning and to the NewRandomSampler doc — the
// !s.replacement branch draws with r.Intn and so CAN repeat indices
// (i.e. sampling WITH replacement), while the replacement==true paths
// use permutations and never repeat. The existing tests codify this
// behavior, so it is flagged rather than changed here; confirm the
// intended semantics before fixing.
func (s *RandomSampler) Sample() []int {
    // Fresh time-seeded generator: each call yields a new ordering.
    r := rand.New(rand.NewSource(time.Now().UnixNano()))

    var indices []int
    if !s.replacement {
        // s.size independent draws in [0, n) — duplicates are possible.
        for i := 0; i < s.size; i++ {
            idx := r.Intn(s.n)
            indices = append(indices, idx)
        }
        return indices
    }

    if s.size == s.n {
        // Full permutation: every index appears exactly once.
        indices = r.Perm(s.n)
        return indices
    }

    // Random range with fixed length
    // NOTE(review): since min = max - s.size, this produces a random
    // permutation of ONE contiguous window [min, max) of length s.size,
    // not a uniform random subset of [0, n) — verify this is intended.
    var max, min int
    for {
        max = r.Intn(s.n)
        min = max - s.size
        if min >= 0 {
            break
        }
    }

    // Random permutation in a range [min, max)
    // ref. https://stackoverflow.com/questions/35354800
    indices = r.Perm(max - min)
    for i := range indices {
        indices[i] += min
    }

    return indices
}

// BatchSize implements Sampler interface.
// It always returns 1 for RandomSampler.
func (s *RandomSampler) BatchSize() int {
    return s.batchSize
}
func intRange(n int) []int {
var r []int
for i := 0; i < n; i++ {
r = append(r, i)
}
return r
}
// BatchSampler constructs a way to draw batches of samples.
type BatchSampler struct {
    n         int  // number of samples
    batchSize int  // samples per batch
    shuffle   bool // whether to shuffle before batching
    dropLast  bool // whether to drop a trailing partial batch
}

// NewBatchSampler creates a new BatchSampler.
// shuffleOpt is an optional flag; only its first value is used
// (default false).
func NewBatchSampler(n, batchSize int, dropLast bool, shuffleOpt ...bool) (*BatchSampler, error) {
    if batchSize > n || batchSize <= 1 {
        err := fmt.Errorf("Invalid batch size: batch size must be greater than 1 and less or equal to number of samples(%v).", n)
        return nil, err
    }

    shuffle := false
    if len(shuffleOpt) > 0 {
        shuffle = shuffleOpt[0]
    }

    s := &BatchSampler{
        n:         n,
        batchSize: batchSize,
        shuffle:   shuffle,
        dropLast:  dropLast,
    }
    return s, nil
}
// Sample implements Sampler interface
//
// It returns a flat list of sample indices grouped into consecutive
// batches of batchSize. With dropLast, indices that do not fill a
// complete final batch are omitted.
func (s *BatchSampler) Sample() []int {
    // Build the base index order: sequential, or a random permutation
    // when shuffling was requested. An if/else replaces the previous
    // un-idiomatic switch over a bool, and the sequential slice is
    // preallocated.
    var indices []int
    if s.shuffle {
        // random permutation
        r := rand.New(rand.NewSource(time.Now().UnixNano()))
        indices = r.Perm(s.n)
    } else {
        indices = make([]int, s.n)
        for i := range indices {
            indices[i] = i
        }
    }

    var (
        batch   []int // current batch being filled
        batches []int // flattened output
    )
    for _, i := range indices {
        batch = append(batch, i)
        if len(batch) == s.batchSize {
            batches = append(batches, batch...)
            batch = []int{}
        }
    }
    // Keep the trailing partial batch unless dropLast was requested.
    if !s.dropLast {
        batches = append(batches, batch...)
    }

    return batches
}

// BatchSize returns batch size.
func (s *BatchSampler) BatchSize() int {
    return s.batchSize
}

144
dutil/sampler_test.go Normal file
View File

@ -0,0 +1,144 @@
package dutil_test
import (
// "fmt"
"reflect"
"testing"
"github.com/sugarme/gotch/dutil"
)
// TestSequentialSampler checks that indices come back in order 0..n-1.
func TestSequentialSampler(t *testing.T) {
    sampler := dutil.NewSequentialSampler(3)

    want := []int{0, 1, 2}
    got := sampler.Sample()
    if !reflect.DeepEqual(want, got) {
        t.Errorf("Want: %+v\n", want)
        t.Errorf("Got: %+v\n", got)
    }
}
// TestRandomSampler exercises the default options, the Replacement
// option and the Size option.
//
// NOTE(review): the isDup assertion below expects NO duplicates when
// WithReplacement(true) is set, which matches the current (apparently
// inverted) implementation of RandomSampler.Sample — under conventional
// semantics, replacement=true is exactly the case where duplicates are
// allowed. Revisit together with the sampler implementation.
func TestRandomSampler(t *testing.T) {
    // Default Optional (size and replacement)
    s, err := dutil.NewRandomSampler(10)
    if err != nil {
        t.Errorf("Unexpected error. Got: %v\n", err)
    }

    want := 1
    got := s.BatchSize() // NOTE. BatchSize is always 1 (for SequentialSampler and RandomSampler)
    if !reflect.DeepEqual(want, got) {
        t.Errorf("Want: %+v\n", want)
        t.Errorf("Got: %+v\n", got)
    }

    // Replacement Opt
    s1, err := dutil.NewRandomSampler(3, dutil.WithReplacement(true))
    if err != nil {
        t.Errorf("Unexpected error. Got: %v\n", err)
    }

    indices := s1.Sample()
    if isDup(indices) {
        t.Errorf("Unexpected duplicated elements. Got: %+v\n", indices)
    }

    // Size option
    size := 3
    s2, err := dutil.NewRandomSampler(10, dutil.WithSize(size))
    if err != nil {
        t.Errorf("Unexpected error. Got: %v\n", err)
    }

    indices = s2.Sample()
    if len(indices) != size {
        t.Errorf("Want size: %v\n", size)
        t.Errorf("Got size: %v\n", len(indices))
    }
}
// isDup reports whether input contains any repeated value.
func isDup(input []int) bool {
    seen := make(map[int]bool, len(input))
    for _, v := range input {
        if seen[v] {
            return true
        }
        seen[v] = true
    }
    return false
}
// TestNewBatchSampler covers a valid construction plus the two invalid
// batch-size cases (batchSize > n and batchSize <= 1).
func TestNewBatchSampler(t *testing.T) {
    // Valid
    if _, err := dutil.NewBatchSampler(10, 3, true); err != nil {
        t.Errorf("Unexpected error. Got: %v\n", err)
    }

    // Invalid batch size
    if _, err := dutil.NewBatchSampler(10, 11, true); err == nil {
        t.Errorf("Expected invalid batch size error.")
    }
    if _, err := dutil.NewBatchSampler(10, 1, true); err == nil {
        t.Errorf("Expected invalid batch size error.")
    }
}

// TestBatchSampler_BatchSize checks that BatchSize echoes the
// constructor argument.
func TestBatchSampler_BatchSize(t *testing.T) {
    batchSize := 5

    sampler, err := dutil.NewBatchSampler(10, batchSize, true)
    if err != nil {
        t.Errorf("Unexpected error. Got: %v\n", err)
    }

    got := sampler.BatchSize()
    if !reflect.DeepEqual(batchSize, got) {
        t.Errorf("Want batch size: %v\n", batchSize)
        t.Errorf("Got batch size: %v\n", got)
    }
}
// TestBatchSampler_Sample checks the dropLast length math and that
// shuffling produces a non-sequential order.
func TestBatchSampler_Sample(t *testing.T) {
    dropper, err := dutil.NewBatchSampler(10, 3, true)
    if err != nil {
        t.Errorf("Unexpected error. Got: %v\n", err)
    }

    // 10 samples, batches of 3, dropLast => 3 full batches = 9 indices.
    indices := dropper.Sample()
    want1 := 9
    got1 := len(indices)
    if !reflect.DeepEqual(want1, got1) {
        t.Errorf("Want indices length: %v\n", want1)
        t.Errorf("Got indices length: %v\n", got1)
    }

    // Shuffle: with n=1000 a random permutation virtually never equals
    // the identity sequence.
    n := 1000
    shuffled, err := dutil.NewBatchSampler(n, 3, false, true)
    if err != nil {
        t.Errorf("Unexpected error. Got: %v\n", err)
    }

    want2 := seq(n)
    got2 := shuffled.Sample()
    if reflect.DeepEqual(want2, got2) {
        t.Errorf("Want indices: %+v\n", want2)
        t.Errorf("Got indices: %+v\n", got2)
    }
}
func seq(n int) []int {
var s []int
for i := 0; i < n; i++ {
s = append(s, i)
}
return s
}

View File

@ -1,6 +1,6 @@
#!/bin/bash
GOTCH_VERSION="${GOTCH_VER:-v0.3.5}"
GOTCH_VERSION="${GOTCH_VER:-v0.3.6}"
LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}"
CUDA_VERSION="${CUDA_VER:-10.1}"