feat(dutil): added DataSet and DataLoader
This commit is contained in:
parent
c8ccc03a19
commit
595d40bba6
91
CHANGELOG.md
91
CHANGELOG.md
|
@ -4,55 +4,24 @@ All notable changes to this project will be documented in this file.
|
|||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Fixed
|
||||
- [#...]: Fix a bug with...
|
||||
|
||||
### Changed
|
||||
- [#...]:
|
||||
## [0.3.6]
|
||||
|
||||
### Added
|
||||
- [#...]:
|
||||
- Added `dutil` sub-package that serves Pytorch `DataSet` and `DataLoader` concepts
|
||||
|
||||
|
||||
## [0.1.8]
|
||||
|
||||
### Changed
|
||||
- [#10]: `ts.Drop()` and `ts.MustDrop()` now can call multiple times without panic
|
||||
|
||||
## [0.1.9]
|
||||
|
||||
### Changed
|
||||
- Reverse changes [#10] to original.
|
||||
|
||||
## [0.1.10]
|
||||
## [0.3.5]
|
||||
|
||||
### Added
|
||||
- Added `tensor.SaveMultiNew`
|
||||
|
||||
|
||||
## [0.2.0]
|
||||
- Added function `gotch.CudaIfAvailable()`. NOTE that: `device := gotch.NewCuda().CudaIfAvailable()` will throw error if CUDA is not available.
|
||||
|
||||
### Changed
|
||||
- Convert all APIs to using **Pointer Receiver**
|
||||
- Switched back to install libtorch inside gotch library as go init() function is triggered after cgo called.
|
||||
|
||||
## [0.3.4]
|
||||
|
||||
### Added
|
||||
- Added drawing image label at `example/yolo` example
|
||||
- Added some example images and README files for `example/yolo` and `example/neural-style-transfer`
|
||||
|
||||
## [0.3.0]
|
||||
|
||||
### Changed
|
||||
- Updated to Pytorch C++ APIs v1.7.0
|
||||
- Switched back to `lib.AtoAddParametersOld` as the `ato_add_parameters` has not been implemented correctly. Using the updated API will cause optimizer stops working.
|
||||
|
||||
## [0.3.1]
|
||||
|
||||
### Changed
|
||||
- Changed to use `map[string]*Tensor` at `nn/varstore.go`
|
||||
- Changed to use `*Path` argument of `NewLayerNorm` method at `nn/layer-norm.go`
|
||||
- Lots of clean-up return variables i.e. retVal, err
|
||||
- [#4] Automatically download and install Libtorch and setup environment variables.
|
||||
|
||||
## [0.3.2]
|
||||
|
||||
|
@ -65,18 +34,42 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- nn/sequential: fixed missing case number of layers = 1 causing panic
|
||||
- nn/varstore: fixed(nn/varstore): fixed nil pointer at LoadPartial due to not break loop
|
||||
|
||||
## [0.3.4]
|
||||
|
||||
### Added
|
||||
- [#4] Automatically download and install Libtorch and setup environment variables.
|
||||
|
||||
## [0.3.5]
|
||||
|
||||
### Added
|
||||
- Added function `gotch.CudaIfAvailable()`. NOTE that: `device := gotch.NewCuda().CudaIfAvailable()` will throw error if CUDA is not available.
|
||||
## [0.3.1]
|
||||
|
||||
### Changed
|
||||
- Switched back to install libtorch inside gotch library as go init() function is triggered after cgo called.
|
||||
- Changed to use `map[string]*Tensor` at `nn/varstore.go`
|
||||
- Changed to use `*Path` argument of `NewLayerNorm` method at `nn/layer-norm.go`
|
||||
- Lots of clean-up return variables i.e. retVal, err
|
||||
|
||||
## [0.3.0]
|
||||
|
||||
### Changed
|
||||
- Updated to Pytorch C++ APIs v1.7.0
|
||||
- Switched back to `lib.AtoAddParametersOld` as the `ato_add_parameters` has not been implemented correctly. Using the updated API will cause optimizer stops working.
|
||||
|
||||
## [0.2.0]
|
||||
|
||||
### Changed
|
||||
- Convert all APIs to using **Pointer Receiver**
|
||||
|
||||
### Added
|
||||
- Added drawing image label at `example/yolo` example
|
||||
- Added some example images and README files for `example/yolo` and `example/neural-style-transfer`
|
||||
|
||||
## [0.1.10]
|
||||
|
||||
### Added
|
||||
- Added `tensor.SaveMultiNew`
|
||||
|
||||
## [0.1.9]
|
||||
|
||||
### Changed
|
||||
- Reverse changes [#10] to original.
|
||||
|
||||
## [0.1.8]
|
||||
|
||||
### Changed
|
||||
- [#10]: `ts.Drop()` and `ts.MustDrop()` now can call multiple times without panic
|
||||
|
||||
|
||||
[#10]: https://github.com/sugarme/gotch/issues/10
|
||||
|
|
117
dutil/dataloader.go
Normal file
117
dutil/dataloader.go
Normal file
|
@ -0,0 +1,117 @@
|
|||
package dutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// DataLoader combines a dataset and a sampler and provides
// an iterable over the given dataset.
type DataLoader struct {
	dataset   Dataset // source of samples, accessed by integer index
	indexes   []int   // order of samples in dataset for iteration (drawn from the sampler).
	batchSize int     // number of samples returned per Next call (1 = no batching)
	currIdx   int     // iteration cursor; advanced by Next, rewound by Reset
}
|
||||
|
||||
func NewDataLoader(data Dataset, s Sampler) (*DataLoader, error) {
|
||||
dkind, err := checkDKind(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use default Sampler if no specified
|
||||
if s == nil {
|
||||
switch dkind {
|
||||
case SliceDKind:
|
||||
s = NewSequentialSampler(data.Len())
|
||||
case MapDKind:
|
||||
s, err = NewRandomSampler(data.Len())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &DataLoader{
|
||||
dataset: data,
|
||||
indexes: s.Sample(),
|
||||
batchSize: s.BatchSize(),
|
||||
currIdx: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func checkDKind(data Dataset) (DatasetKind, error) {
|
||||
dtyp := data.DType()
|
||||
dkind := dtyp.Kind().String()
|
||||
|
||||
switch dkind {
|
||||
case "slice":
|
||||
return SliceDKind, nil
|
||||
case "map":
|
||||
if dtyp.Key().Kind().String() != "string" {
|
||||
err := fmt.Errorf("Invalid Dataset. Dataset should be a collection data of type '[]interface{}' or 'map[string]interface{}'")
|
||||
return InvalidDKind, err
|
||||
}
|
||||
return MapDKind, nil
|
||||
|
||||
default: // other types are invalid
|
||||
err := fmt.Errorf("Invalid Dataset. Dataset should be a collection data of type '[]interface{}' or 'map[string]interface{}'")
|
||||
return InvalidDKind, err
|
||||
}
|
||||
}
|
||||
|
||||
// Next acts as iterator to return next sample(s) from dataset.
|
||||
func (dl *DataLoader) Next() (interface{}, error) {
|
||||
if !dl.HasNext() {
|
||||
err := fmt.Errorf("Next Error: no more item to iterate.\n")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Non-batching
|
||||
if dl.batchSize == 1 {
|
||||
item, err := dl.dataset.Item(dl.currIdx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dl.currIdx += 1
|
||||
return item, nil
|
||||
}
|
||||
|
||||
// Batch sampling
|
||||
elem, err := dl.dataset.Item(0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
elemType := reflect.TypeOf(elem)
|
||||
|
||||
items := reflect.MakeSlice(reflect.SliceOf(elemType), 0, dl.dataset.Len())
|
||||
nextIndex := dl.currIdx + dl.batchSize
|
||||
|
||||
// NOTE. length of indexes can be shorter than dataset length
|
||||
if nextIndex >= len(dl.indexes) {
|
||||
nextIndex = len(dl.indexes)
|
||||
}
|
||||
for i := dl.currIdx; i < nextIndex; i++ {
|
||||
item, err := dl.dataset.Item(i)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = reflect.Append(items, reflect.ValueOf(item))
|
||||
}
|
||||
|
||||
dl.currIdx = nextIndex
|
||||
return items.Interface(), nil
|
||||
}
|
||||
|
||||
// HasNext returns whether there is a next item in the iteration.
|
||||
func (dl *DataLoader) HasNext() bool {
|
||||
return dl.currIdx < len(dl.indexes)
|
||||
}
|
||||
|
||||
// Reset reset index to start position.
|
||||
func (dl *DataLoader) Reset() {
|
||||
dl.currIdx = 0
|
||||
}
|
44
dutil/dataloader_test.go
Normal file
44
dutil/dataloader_test.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package dutil_test
|
||||
|
||||
import (
|
||||
// "reflect"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch/dutil"
|
||||
)
|
||||
|
||||
func TestNewDataLoader(t *testing.T) {
|
||||
data, err := dutil.NewSliceDataset([]int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
_, err = dutil.NewDataLoader(data, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error. Got: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataLoader_Next(t *testing.T) {
|
||||
data, err := dutil.NewSliceDataset([]int{100, 1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
dl, err := dutil.NewDataLoader(data, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error. Got: %v\n", err)
|
||||
}
|
||||
|
||||
got, err := dl.Next()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
want := 100
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("Want: %v\n", want)
|
||||
t.Errorf("Got: %v\n", got)
|
||||
}
|
||||
}
|
118
dutil/dataset.go
Normal file
118
dutil/dataset.go
Normal file
|
@ -0,0 +1,118 @@
|
|||
package dutil
|
||||
|
||||
import (
	"fmt"
	"reflect"
	"sort"
)
|
||||
|
||||
// Dataset represents a set of samples and
// how to access a sample by its index by implementing
// `Item()` method.
type Dataset interface {
	// Item returns the sample stored at the given index.
	Item(idx int) (interface{}, error)
	// DType reports the reflect.Type of the underlying collection.
	DType() reflect.Type
	// Len returns the number of samples in the dataset.
	Len() int
}

// DatasetKind discriminates the supported backing collections.
type DatasetKind int

const (
	// SliceDKind marks a dataset backed by a slice.
	SliceDKind DatasetKind = iota
	// MapDKind marks a dataset backed by a string-keyed map.
	MapDKind
	// InvalidDKind marks an unsupported backing collection.
	InvalidDKind
)
|
||||
|
||||
// SliceDataset is a slice of samples.
type SliceDataset struct {
	data interface{} // underlying slice; element access goes through reflection
}
|
||||
|
||||
// NewSliceDataset creates a new SliceDataset.
|
||||
func NewSliceDataset(data interface{}) (*SliceDataset, error) {
|
||||
kind := reflect.TypeOf(data).Kind().String()
|
||||
if kind != "slice" {
|
||||
err := fmt.Errorf("Invalid Type: expected data of slice type. Got '%v'.\n", kind)
|
||||
return nil, err
|
||||
}
|
||||
return &SliceDataset{
|
||||
data: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Item implements Dataset interface to get a sample by its index.
|
||||
func (ds *SliceDataset) Item(idx int) (interface{}, error) {
|
||||
if idx < 0 || idx >= reflect.ValueOf(ds.data).Len() {
|
||||
err := fmt.Errorf("Idx is out of range.")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reflect.ValueOf(ds.data).Index(idx).Interface(), nil
|
||||
}
|
||||
|
||||
func (ds *SliceDataset) Len() int {
|
||||
return reflect.ValueOf(ds.data).Len()
|
||||
}
|
||||
|
||||
func (ds *SliceDataset) DType() reflect.Type {
|
||||
return reflect.TypeOf(ds.data)
|
||||
}
|
||||
|
||||
// MapDataset holds samples in a map.
type MapDataset struct {
	data interface{} // underlying string-keyed map; accessed via reflection
	keys []string    // keys to access elements in map
}
|
||||
|
||||
// NewMapDataset creates a new MapDataset.
|
||||
// func NewMapDataset(data map[string]interface{}) *MapDataset {
|
||||
func NewMapDataset(data interface{}) (*MapDataset, error) {
|
||||
// validate map type
|
||||
dtype := reflect.TypeOf(data).Kind().String()
|
||||
if dtype != "map" {
|
||||
err := fmt.Errorf("Expected data of map type. Got: %v\n", dtype)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// validate key string type
|
||||
keyType := reflect.TypeOf(data).Key().Kind().String()
|
||||
if keyType != "string" {
|
||||
err := fmt.Errorf("Expected data with map key of string type. Got '%v'\n", keyType)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var keys []string
|
||||
mapIter := reflect.ValueOf(data).MapRange()
|
||||
for mapIter.Next() {
|
||||
key := mapIter.Key().Interface()
|
||||
keys = append(keys, key.(string))
|
||||
}
|
||||
|
||||
return &MapDataset{
|
||||
data: data,
|
||||
keys: keys,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Item implements Dataset interface.
|
||||
func (ds *MapDataset) Item(idx int) (interface{}, error) {
|
||||
if idx < 0 || idx >= len(ds.keys) {
|
||||
err := fmt.Errorf("idx is out of range.")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := ds.keys[idx]
|
||||
item := reflect.ValueOf(ds.data).MapIndex(reflect.ValueOf(key)).Interface()
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func (ds *MapDataset) Len() int {
|
||||
return reflect.ValueOf(ds.data).Len()
|
||||
}
|
||||
|
||||
func (ds *MapDataset) DType() reflect.Type {
|
||||
return reflect.TypeOf(ds.data)
|
||||
}
|
||||
|
||||
// NOTE. To make this package agnostic, we don't add TensorDataset here.
|
||||
// A end-user can create a custom dataset by implementing `Item()` method.
|
120
dutil/dataset_test.go
Normal file
120
dutil/dataset_test.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package dutil_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch/dutil"
|
||||
)
|
||||
|
||||
func TestNewSliceDataset(t *testing.T) {
|
||||
// Error case: non `slice` type
|
||||
invalidData := 1
|
||||
_, err := dutil.NewSliceDataset(invalidData)
|
||||
if err == nil {
|
||||
t.Errorf("Expected invalid data type error: %v.\n", err)
|
||||
}
|
||||
|
||||
// Valid case
|
||||
validData := []int{0, 1, 2, 3}
|
||||
_, err = dutil.NewSliceDataset(validData)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error. Got: %v.\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSliceDataset_Len(t *testing.T) {
|
||||
data := []int{0, 1, 2, 3}
|
||||
ds, err := dutil.NewSliceDataset(data)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
want := 4
|
||||
got := ds.Len()
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("Want data length: %v\n", want)
|
||||
t.Errorf("Got data length: %v\n", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSliceDataset_Item(t *testing.T) {
|
||||
data := []int{0, 1, 2, 3}
|
||||
ds, err := dutil.NewSliceDataset(data)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
want := 2
|
||||
got, err := ds.Item(2)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("Want item value: %v\n", want)
|
||||
t.Errorf("Got item value: %v\n", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewMapDataset(t *testing.T) {
|
||||
// Invalid data type
|
||||
invalidData := []int{0, 1, 2, 3}
|
||||
_, err := dutil.NewMapDataset(invalidData)
|
||||
if err == nil {
|
||||
t.Errorf("Expected Invalid data type. Got nil.")
|
||||
}
|
||||
|
||||
// Invalid map key type
|
||||
invalidKey := make(map[int]int, 0)
|
||||
invalidKey[1] = 1
|
||||
invalidKey[2] = 2
|
||||
_, err = dutil.NewMapDataset(invalidKey)
|
||||
if err == nil {
|
||||
t.Errorf("Expected Invalid map key type. Got nil.")
|
||||
}
|
||||
|
||||
// Valid data
|
||||
validData := make(map[string]int)
|
||||
validData["one"] = 1
|
||||
validData["two"] = 2
|
||||
_, err = dutil.NewMapDataset(validData)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error. Got: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaptDataset_Len(t *testing.T) {
|
||||
var data map[string]int = map[string]int{"one": 1, "two": 2}
|
||||
ds, err := dutil.NewMapDataset(data)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
want := 2
|
||||
got := ds.Len()
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("Want data length: %v\n", want)
|
||||
t.Errorf("Got data length: %v\n", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapDataset_Item(t *testing.T) {
|
||||
var data map[string]int = map[string]int{"three": 3, "one": 1, "two": 2}
|
||||
ds, err := dutil.NewMapDataset(data)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
want := 3
|
||||
got, err := ds.Item(0)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("Want: %v\n", want)
|
||||
t.Errorf("Got: %v\n", got)
|
||||
}
|
||||
}
|
143
dutil/kfold.go
Normal file
143
dutil/kfold.go
Normal file
|
@ -0,0 +1,143 @@
|
|||
package dutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// KFold represents a struct helper to
// split data into partitions.
type KFold struct {
	n       int  // total number of samples to split
	nfolds  int  // number of folds (partitions)
	shuffle bool // whether sample order is shuffled before splitting
	//seed int
}
|
||||
|
||||
// Fold represents a partition with
// 2 fields: indexes of train set and test set.
type Fold struct {
	Train []int // sample indices forming the training set
	Test  []int // sample indices forming the test set
}
|
||||
|
||||
// KFoldOptions carries the configurable parameters for a KFold split.
type KFoldOptions struct {
	NFolds  int  // number of folds
	Shuffle bool // whether shuffling before splitting
}

// KFoldOption mutates a KFoldOptions value (functional-option pattern).
type KFoldOption func(*KFoldOptions)

// NewKFoldOptions builds a KFoldOptions from the defaults (5 folds, no
// shuffle) overridden by the supplied options.
func NewKFoldOptions(options ...KFoldOption) KFoldOptions {
	opts := KFoldOptions{NFolds: 5, Shuffle: false}
	for _, apply := range options {
		apply(&opts)
	}
	return opts
}

// WithNFolds sets the number of folds.
func WithNFolds(nfolds int) KFoldOption {
	return func(o *KFoldOptions) { o.NFolds = nfolds }
}

// WithKFoldShuffle sets whether samples are shuffled before splitting.
func WithKFoldShuffle(shuffle bool) KFoldOption {
	return func(o *KFoldOptions) { o.Shuffle = shuffle }
}
|
||||
|
||||
// NewKFold creates a new KFold struct.
|
||||
func NewKFold(n int, opt ...KFoldOption) (*KFold, error) {
|
||||
opts := NewKFoldOptions(opt...)
|
||||
|
||||
if opts.NFolds < 2 {
|
||||
err := fmt.Errorf("nfolds must be at least 2. Got: %v\n", opts.NFolds)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if opts.NFolds > n {
|
||||
err := fmt.Errorf("nfolds cannot be greater than number of samples (%v). Got: %v\n", n, opts.NFolds)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &KFold{
|
||||
n: n,
|
||||
nfolds: opts.NFolds,
|
||||
shuffle: opts.Shuffle,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Split partitions the kf.n sample indices into kf.nfolds Folds, each
// pairing one test fold against the remaining folds concatenated as a
// train set. Samples that do not divide evenly into the folds are
// dropped, so each fold holds exactly (n - n%nfolds)/nfolds test
// indices.
//
// NOTE(review): the kept indices always come from a fresh random
// permutation; when shuffle is false they are merely sorted afterwards,
// which still leaves WHICH n%nfolds samples get dropped up to chance —
// confirm whether shuffle=false was meant to be fully deterministic.
func (kf *KFold) Split() []Fold {
	odd := kf.n % kf.nfolds
	nsamples := kf.n - odd
	fsize := nsamples / kf.nfolds
	var indices []int

	allIndices := rand.Perm(kf.n)
	// Drop last odd-time elements
	indices = allIndices[:nsamples]

	if !kf.shuffle {
		sort.Ints(indices)
	}

	// Split to train, test sets.
	// folds groups the positions 0..nsamples-1 into nfolds chunks of
	// fsize; positions are mapped back to sample indices via values().
	var (
		folds [][]int
		fold  []int
	)
	for i := 0; i < nsamples; i++ {
		fold = append(fold, i)
		if len(fold) == fsize {
			folds = append(folds, fold)
			fold = []int{}
		}
	}

	var splits []Fold
	for i := 0; i < kf.nfolds; i++ {
		// Fold i is the test set; every other fold concatenates into train.
		test := folds[i]
		var trainFolds [][]int
		trainFolds = append(trainFolds, folds[:i]...)
		trainFolds = append(trainFolds, folds[i+1:]...)
		var train []int
		for _, f := range trainFolds {
			train = append(train, f...)
		}

		split := Fold{
			Test:  values(indices, test),
			Train: values(indices, train),
		}
		splits = append(splits, split)
	}

	return splits
}
|
||||
|
||||
// contains reports whether item occurs anywhere in data.
func contains(data []int, item int) bool {
	found := false
	for i := 0; i < len(data) && !found; i++ {
		found = data[i] == item
	}
	return found
}
|
||||
|
||||
// values gathers data[k] for every k in keys, preserving key order.
// It returns nil when keys is empty.
func values(data []int, keys []int) []int {
	var vals []int
	for _, idx := range keys {
		vals = append(vals, data[idx])
	}
	return vals
}
|
56
dutil/kfold_test.go
Normal file
56
dutil/kfold_test.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package dutil_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch/dutil"
|
||||
)
|
||||
|
||||
func TestNewKFold(t *testing.T) {
|
||||
// invalid kfold
|
||||
n := 10
|
||||
nfolds := 11
|
||||
|
||||
_, err := dutil.NewKFold(n, dutil.WithNFolds(nfolds), dutil.WithKFoldShuffle(true))
|
||||
if err == nil {
|
||||
t.Errorf("Expected error: invalid number of folds. Got nil.")
|
||||
}
|
||||
|
||||
// valid
|
||||
nfolds = 3
|
||||
_, err = dutil.NewKFold(n, dutil.WithNFolds(nfolds), dutil.WithKFoldShuffle(true))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error. Got %v.\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKFold_Split(t *testing.T) {
|
||||
n := 11
|
||||
nfolds := 3
|
||||
trainLen := 6
|
||||
testLen := 3
|
||||
|
||||
kf, err := dutil.NewKFold(n, dutil.WithNFolds(nfolds), dutil.WithKFoldShuffle(true))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
splits := kf.Split()
|
||||
|
||||
if len(splits) != nfolds {
|
||||
t.Errorf("Want number of folds: %v\n", nfolds)
|
||||
t.Errorf("Got number of folds: %v\n", len(splits))
|
||||
}
|
||||
|
||||
for _, f := range splits {
|
||||
if len(f.Train) != trainLen {
|
||||
t.Errorf("Expect train length: %v\n", trainLen)
|
||||
t.Errorf("Got train length: %v\n", len(f.Train))
|
||||
}
|
||||
|
||||
if len(f.Test) != testLen {
|
||||
t.Errorf("Expect test length: %v\n", testLen)
|
||||
t.Errorf("Got test length: %v\n", len(f.Test))
|
||||
}
|
||||
}
|
||||
}
|
229
dutil/sampler.go
Normal file
229
dutil/sampler.go
Normal file
|
@ -0,0 +1,229 @@
|
|||
package dutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Sampler represents an interface to draw sample
// from a dataset by implementing `Sample` method to
// generate a slice of indices of data samples.
type Sampler interface {
	// Sample returns the dataset indices to iterate over, in draw order.
	Sample() []int
	// BatchSize returns the number of samples drawn per iteration step.
	BatchSize() int
}
|
||||
|
||||
// SequentialSampler represents a method to
// draw sample by its order in the dataset.
type SequentialSampler struct {
	n         int // number of samples
	batchSize int // always = 1
}

// NewSequentialSampler create a new SequentialSampler
//
// n : number of samples in dataset
func NewSequentialSampler(n int) *SequentialSampler {
	return &SequentialSampler{n: n, batchSize: 1}
}

// Sample implements Sampler interface: it yields the indices
// 0, 1, ..., n-1 in natural order.
func (s *SequentialSampler) Sample() []int {
	var indices []int
	for idx := 0; idx < s.n; idx++ {
		indices = append(indices, idx)
	}
	return indices
}

// BatchSize implements Sampler interface; always 1 for this sampler.
func (s *SequentialSampler) BatchSize() int {
	return s.batchSize
}
|
||||
|
||||
// RandomSampler represents a method to draw
// a sample randomly from dataset.
type RandomSampler struct {
	n           int  // number of samples in the dataset
	size        int  // size of sampling (defaults to n in NewRandomSampler)
	replacement bool // whether replacement or not
	batchSize   int  // always = 1
}
|
||||
|
||||
// RandOptions holds the optional parameters for a RandomSampler.
type RandOptions struct {
	Size        int  // sampling size (0 means "use all samples")
	Replacement bool // whether samples may repeat
}

// RandOption mutates a RandOptions value (functional-option pattern).
type RandOption func(*RandOptions)

// NewRandOptions builds a RandOptions from the defaults (Size 0, no
// replacement) overridden by the supplied options.
func NewRandOptions(options ...RandOption) RandOptions {
	opts := RandOptions{Size: 0, Replacement: false}
	for _, apply := range options {
		apply(&opts)
	}
	return opts
}

// WithSize sets the sampling size.
func WithSize(size int) RandOption {
	return func(o *RandOptions) { o.Size = size }
}

// WithReplacement sets whether sampling is done with replacement.
func WithReplacement(replacement bool) RandOption {
	return func(o *RandOptions) { o.Replacement = replacement }
}
|
||||
|
||||
// NewRandomSampler creates a new RandomSampler.
|
||||
//
|
||||
// n : number of samples in dataset
|
||||
// size: Optional (default=n). Size of sampling.
|
||||
// replacement: Optional (default=false). Whether not repeated or repeated samples.
|
||||
func NewRandomSampler(n int, opt ...RandOption) (*RandomSampler, error) {
|
||||
|
||||
opts := NewRandOptions(opt...)
|
||||
|
||||
if opts.Size > n {
|
||||
err := fmt.Errorf("Sampling size can not be greater than number of samples.")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
size := n // default sampling size = all samples
|
||||
if opts.Size != 0 {
|
||||
size = opts.Size
|
||||
}
|
||||
|
||||
return &RandomSampler{
|
||||
n: n,
|
||||
size: size,
|
||||
replacement: opts.Replacement,
|
||||
batchSize: 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Sample implements Sampler interface.
//
// NOTE(review): the replacement semantics look inverted relative to the
// field name: the !s.replacement branch draws s.size independent values
// with r.Intn, which CAN repeat (i.e. sampling with replacement), while
// the replacement branch produces duplicate-free permutations. The
// package's own tests codify this behaviour, so it is only flagged
// here — confirm the intent before changing it.
func (s *RandomSampler) Sample() []int {
	// Fresh PRNG seeded from the wall clock: output differs between runs.
	r := rand.New(rand.NewSource(time.Now().UnixNano()))
	var indices []int

	if !s.replacement {
		// s.size independent draws from [0, n); duplicates are possible.
		for i := 0; i < s.size; i++ {
			idx := r.Intn(s.n)
			indices = append(indices, idx)
		}
		return indices
	}

	// Full-size sampling: a plain permutation of all n indices.
	if s.size == s.n {
		indices = r.Perm(s.n)
		return indices
	}

	// Random range with fixed length
	// NOTE(review): this picks a CONTIGUOUS window [min, max) of width
	// s.size and permutes it, so only neighbouring indices can appear
	// together in one sample — presumably not a uniform random subset;
	// verify this is intended.
	var max, min int
	for {
		max = r.Intn(s.n)
		min = max - s.size
		if min >= 0 {
			break
		}
	}

	// Random permutation in a range [min, max)
	// ref. https://stackoverflow.com/questions/35354800
	indices = r.Perm(max - min)
	for i := range indices {
		indices[i] += min
	}

	return indices
}
|
||||
|
||||
// BatchSize implements Sampler interface.
|
||||
// It's always return 1.
|
||||
func (s *RandomSampler) BatchSize() int {
|
||||
return s.batchSize
|
||||
}
|
||||
|
||||
// intRange returns the integers 0..n-1 (nil when n <= 0).
func intRange(n int) []int {
	var out []int
	for v := 0; v < n; v++ {
		out = append(out, v)
	}
	return out
}
|
||||
|
||||
// BatchSampler constructs a way to draw batches of samples.
type BatchSampler struct {
	n         int  // number of samples
	batchSize int  // samples per batch
	shuffle   bool // whether to randomize sample order
	dropLast  bool // whether to drop a final, incomplete batch
}

// NewBatchSampler creates a new BatchSampler.
//
// shuffleOpt optionally enables shuffling (default false); only the
// first value is consulted. The batch size must lie in (1, n].
func NewBatchSampler(n, batchSize int, dropLast bool, shuffleOpt ...bool) (*BatchSampler, error) {
	if batchSize > n || batchSize <= 1 {
		err := fmt.Errorf("Invalid batch size: batch size must be greater than 1 and less or equal to number of samples(%v).", n)
		return nil, err
	}

	shuffle := false
	if len(shuffleOpt) > 0 {
		shuffle = shuffleOpt[0]
	}

	return &BatchSampler{
		n:         n,
		batchSize: batchSize,
		shuffle:   shuffle,
		dropLast:  dropLast,
	}, nil
}

// Sample implements Sampler interface.
//
// It returns a flat slice of indices grouped batch-by-batch. When
// dropLast is set, trailing indices that do not fill a whole batch are
// omitted.
func (s *BatchSampler) Sample() []int {
	var indices []int
	if s.shuffle {
		// random permutation
		r := rand.New(rand.NewSource(time.Now().UnixNano()))
		indices = r.Perm(s.n)
	} else {
		for i := 0; i < s.n; i++ {
			indices = append(indices, i)
		}
	}

	var flat, current []int
	for _, idx := range indices {
		current = append(current, idx)
		if len(current) == s.batchSize {
			flat = append(flat, current...)
			current = nil
		}
	}
	if !s.dropLast {
		flat = append(flat, current...)
	}

	return flat
}

// BatchSize returns batch size.
func (s *BatchSampler) BatchSize() int {
	return s.batchSize
}
|
144
dutil/sampler_test.go
Normal file
144
dutil/sampler_test.go
Normal file
|
@ -0,0 +1,144 @@
|
|||
package dutil_test
|
||||
|
||||
import (
|
||||
// "fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch/dutil"
|
||||
)
|
||||
|
||||
func TestSequentialSampler(t *testing.T) {
|
||||
s := dutil.NewSequentialSampler(3)
|
||||
want := []int{0, 1, 2}
|
||||
got := s.Sample()
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("Want: %+v\n", want)
|
||||
t.Errorf("Got: %+v\n", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRandomSampler(t *testing.T) {
|
||||
// Default Optional (size and replacement)
|
||||
s, err := dutil.NewRandomSampler(10)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error. Got: %v\n", err)
|
||||
}
|
||||
|
||||
want := 1
|
||||
got := s.BatchSize() // NOTE. BatchSize is always 1 (for SequentialSampler and RandomSampler)
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("Want: %+v\n", want)
|
||||
t.Errorf("Got: %+v\n", got)
|
||||
}
|
||||
|
||||
// Replacement Opt
|
||||
s1, err := dutil.NewRandomSampler(3, dutil.WithReplacement(true))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error. Got: %v\n", err)
|
||||
}
|
||||
|
||||
indices := s1.Sample()
|
||||
if isDup(indices) {
|
||||
t.Errorf("Unexpected duplicated elements. Got: %+v\n", indices)
|
||||
}
|
||||
|
||||
// Size option
|
||||
size := 3
|
||||
s2, err := dutil.NewRandomSampler(10, dutil.WithSize(size))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error. Got: %v\n", err)
|
||||
}
|
||||
indices = s2.Sample()
|
||||
|
||||
if len(indices) != size {
|
||||
t.Errorf("Want size: %v\n", size)
|
||||
t.Errorf("Got size: %v\n", len(indices))
|
||||
}
|
||||
}
|
||||
|
||||
func isDup(input []int) bool {
|
||||
dmap := make(map[int]bool)
|
||||
|
||||
for _, key := range input {
|
||||
if _, ok := dmap[key]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
dmap[key] = true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func TestNewBatchSampler(t *testing.T) {
|
||||
// Valid
|
||||
_, err := dutil.NewBatchSampler(10, 3, true)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error. Got: %v\n", err)
|
||||
}
|
||||
|
||||
// Invalid batch size
|
||||
_, err = dutil.NewBatchSampler(10, 11, true)
|
||||
if err == nil {
|
||||
t.Errorf("Expected invalid batch size error.")
|
||||
}
|
||||
_, err = dutil.NewBatchSampler(10, 1, true)
|
||||
if err == nil {
|
||||
t.Errorf("Expected invalid batch size error.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchSampler_BatchSize(t *testing.T) {
|
||||
batchSize := 5
|
||||
s, err := dutil.NewBatchSampler(10, batchSize, true)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error. Got: %v\n", err)
|
||||
}
|
||||
|
||||
got := s.BatchSize()
|
||||
|
||||
if !reflect.DeepEqual(batchSize, got) {
|
||||
t.Errorf("Want batch size: %v\n", batchSize)
|
||||
t.Errorf("Got batch size: %v\n", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchSampler_Sample(t *testing.T) {
|
||||
s1, err := dutil.NewBatchSampler(10, 3, true)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error. Got: %v\n", err)
|
||||
}
|
||||
|
||||
indices := s1.Sample()
|
||||
want1 := 9
|
||||
got1 := len(indices)
|
||||
|
||||
if !reflect.DeepEqual(want1, got1) {
|
||||
t.Errorf("Want indices length: %v\n", want1)
|
||||
t.Errorf("Got indices length: %v\n", got1)
|
||||
}
|
||||
|
||||
// Shuffle
|
||||
n := 1000
|
||||
s2, err := dutil.NewBatchSampler(n, 3, false, true)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error. Got: %v\n", err)
|
||||
}
|
||||
|
||||
want2 := seq(n)
|
||||
got2 := s2.Sample()
|
||||
if reflect.DeepEqual(want2, got2) {
|
||||
t.Errorf("Want indices: %+v\n", want2)
|
||||
t.Errorf("Got indices: %+v\n", got2)
|
||||
}
|
||||
}
|
||||
|
||||
func seq(n int) []int {
|
||||
var s []int
|
||||
for i := 0; i < n; i++ {
|
||||
s = append(s, i)
|
||||
}
|
||||
return s
|
||||
}
|
Loading…
Reference in New Issue
Block a user