309 lines
7.4 KiB
Go
309 lines
7.4 KiB
Go
package ts
|
|
|
|
import (
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
|
|
"git.andr3h3nriqu3s.com/andr3/gotch"
|
|
)
|
|
|
|
// Iter2 is an iterator over a pair of tensors which have the same first dimension
|
|
// size.
|
|
// The typical use case is to iterate over batches. Each batch is a pair
|
|
// containing a (potentially random) slice of each of the two input
|
|
// tensors.
|
|
type Iter2 struct {
|
|
xs *Tensor
|
|
ys *Tensor
|
|
batchIndex int64
|
|
batchSize int64
|
|
totalSize int64
|
|
device gotch.Device
|
|
returnSmallLastBatch bool
|
|
}
|
|
|
|
// NewIter2 returns a new iterator.
|
|
//
|
|
// This takes as input two tensors which first dimension must match. The
|
|
// returned iterator can be used to range over mini-batches of data of
|
|
// specified size.
|
|
// An error is returned if `xs` and `ys` have different first dimension
|
|
// sizes.
|
|
//
|
|
// # Arguments
|
|
//
|
|
// * `xs` - the features to be used by the model.
|
|
// * `ys` - the targets that the model attempts to predict.
|
|
// * `batch_size` - the size of batches to be returned.
|
|
func NewIter2(xs, ys *Tensor, batchSize int64) (*Iter2, error) {
|
|
var (
|
|
iter *Iter2
|
|
err error
|
|
)
|
|
|
|
totalSize := xs.MustSize()[0]
|
|
if ys.MustSize()[0] != totalSize {
|
|
err = fmt.Errorf("Different dimension for the two inputs: %v - %v", xs.MustSize(), ys.MustSize())
|
|
return nil, err
|
|
}
|
|
|
|
// xsClone, err := xs.ZerosLike(false)
|
|
// if err != nil {
|
|
// log.Fatal(err)
|
|
// }
|
|
// xsClone.Copy_(xs)
|
|
//
|
|
// ysClone, err := ys.ZerosLike(false)
|
|
// if err != nil {
|
|
// log.Fatal(err)
|
|
// }
|
|
// ysClone.Copy_(ys)
|
|
|
|
iter = &Iter2{
|
|
xs: xs.MustShallowClone(),
|
|
ys: ys.MustShallowClone(),
|
|
// xs: xsClone,
|
|
// ys: ysClone,
|
|
batchIndex: 0,
|
|
batchSize: batchSize,
|
|
totalSize: totalSize,
|
|
returnSmallLastBatch: false,
|
|
}
|
|
|
|
return iter, nil
|
|
}
|
|
|
|
// MustNewIter2 returns a new iterator.
|
|
//
|
|
// This takes as input two tensors which first dimension must match. The
|
|
// returned iterator can be used to range over mini-batches of data of
|
|
// specified size.
|
|
// Panics if `xs` and `ys` have different first dimension sizes.
|
|
//
|
|
// # Arguments
|
|
//
|
|
// * `xs` - the features to be used by the model.
|
|
// * `ys` - the targets that the model attempts to predict.
|
|
// * `batch_size` - the size of batches to be returned.
|
|
func MustNewIter2(xs, ys *Tensor, batchSize int64) *Iter2 {
|
|
iter, err := NewIter2(xs, ys, batchSize)
|
|
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return iter
|
|
}
|
|
|
|
// Shuffle shuffles the dataset.
|
|
//
|
|
// The iterator would still run over the whole dataset but the order in
|
|
// which elements are grouped in mini-batches is randomized.
|
|
func (it *Iter2) Shuffle() {
|
|
index := MustRandperm(it.totalSize, gotch.Int64, gotch.CPU)
|
|
|
|
it.xs = it.xs.MustIndexSelect(0, index, true)
|
|
it.ys = it.ys.MustIndexSelect(0, index, true)
|
|
|
|
index.MustDrop()
|
|
}
|
|
|
|
// ToDevice transfers the mini-batches to a specified device.
|
|
func (it *Iter2) ToDevice(device gotch.Device) *Iter2 {
|
|
it.device = device
|
|
return it
|
|
}
|
|
|
|
// ReturnSmallLastBatch when set, returns the last batch even if smaller than the batch size.
|
|
func (it *Iter2) ReturnSmallLastBatch() *Iter2 {
|
|
it.returnSmallLastBatch = true
|
|
return it
|
|
}
|
|
|
|
type Iter2Item struct {
|
|
Data *Tensor
|
|
Label *Tensor
|
|
}
|
|
|
|
// Next implements iterator for Iter2
|
|
func (it *Iter2) Next() (item Iter2Item, ok bool) {
|
|
start := it.batchIndex * it.batchSize
|
|
size := it.batchSize
|
|
if it.totalSize-start < it.batchSize {
|
|
size = it.totalSize - start
|
|
}
|
|
|
|
if (size <= 0) || (!it.returnSmallLastBatch && size < it.batchSize) {
|
|
// err = fmt.Errorf("Last small batch error")
|
|
return item, false
|
|
} else {
|
|
it.batchIndex += 1
|
|
|
|
// Indexing
|
|
narrowIndex := NewNarrow(start, start+size)
|
|
|
|
return Iter2Item{
|
|
Data: it.xs.Idx(narrowIndex),
|
|
Label: it.ys.Idx(narrowIndex),
|
|
}, true
|
|
}
|
|
}
|
|
|
|
func (it *Iter2) Drop() {
|
|
it.xs.MustDrop()
|
|
it.ys.MustDrop()
|
|
}
|
|
|
|
// TextData represent text data in tensor of runes (uint8)
|
|
// and its corresponding string
|
|
type TextData struct {
|
|
Data *Tensor // frequency (occurence) of byte value from input text
|
|
CharForLabel []rune // unique rune values from input text
|
|
}
|
|
|
|
// TextDataIter is a text data interator
|
|
type TextDataIter struct {
|
|
Data *Tensor
|
|
SeqLen int64
|
|
BatchIndex int64
|
|
BatchSize int64
|
|
Indexes *Tensor
|
|
IndexesLen int64
|
|
}
|
|
|
|
// NewTextData creates a text dataset from a file
|
|
//
|
|
// It reads text input from file to `[]byte` buffer
|
|
// - Loops over each byte
|
|
// - first byte will be labelled `0`
|
|
// - next byte if exist will be labelled with existing label (index), otherwise
|
|
// will labelled with new label(index)
|
|
// Data: tensor of labels
|
|
// CharForLabel: []rune (unique runes from text input)
|
|
func NewTextData(filename string) (*TextData, error) {
|
|
filePath, err := filepath.Abs(filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
r, err := os.Open(filePath)
|
|
|
|
buffer, err := ioutil.ReadAll(r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var labelForChar map[byte]uint8 = make(map[byte]uint8, 0)
|
|
var charForLabel []rune
|
|
var dataIndexes []uint8
|
|
|
|
for _, runeVal := range buffer {
|
|
if len(labelForChar) == 0 {
|
|
labelForChar[runeVal] = 0
|
|
dataIndexes = append(dataIndexes, 0)
|
|
charForLabel = append(charForLabel, rune(runeVal))
|
|
} else {
|
|
label, ok := labelForChar[runeVal]
|
|
if !ok {
|
|
newLabel := uint8(len(labelForChar))
|
|
labelForChar[runeVal] = newLabel
|
|
dataIndexes = append(dataIndexes, newLabel)
|
|
charForLabel = append(charForLabel, rune(runeVal))
|
|
} else {
|
|
dataIndexes = append(dataIndexes, label)
|
|
}
|
|
}
|
|
}
|
|
|
|
data := MustOfSlice(dataIndexes)
|
|
|
|
return &TextData{
|
|
Data: data,
|
|
CharForLabel: charForLabel,
|
|
}, nil
|
|
}
|
|
|
|
func (tdi *TextDataIter) Progress() float32 {
|
|
startIndex := (tdi.BatchIndex * tdi.BatchSize)
|
|
availableIndices := tdi.IndexesLen
|
|
progress := float32(startIndex) / float32(availableIndices)
|
|
return progress
|
|
}
|
|
|
|
// Labels returns the number of different `character` (rune) used by the dataset.
|
|
func (td *TextData) Labels() (retVal int64) {
|
|
return int64(len(td.CharForLabel))
|
|
}
|
|
|
|
// Data returns a shallow copy of the data.
|
|
func (td *TextData) CloneData() *Tensor {
|
|
return td.Data.MustShallowClone()
|
|
}
|
|
|
|
// LabelForChar returns a corresponding `char` (rune) for
|
|
// specified label input
|
|
func (td *TextData) LabelForChar(label int64) rune {
|
|
return td.CharForLabel[int(label)]
|
|
}
|
|
|
|
// IterShuffle returns a batch iterator over the dataset.
|
|
// Each sample is made of seq_len characters.
|
|
func (td *TextData) IterShuffle(seqLen int64, batchSize int64) *TextDataIter {
|
|
|
|
indexesLen := td.Data.MustSize()[0] - seqLen + 1
|
|
|
|
return &TextDataIter{
|
|
Data: td.Data.MustShallowClone(),
|
|
SeqLen: seqLen,
|
|
BatchIndex: 0,
|
|
BatchSize: batchSize,
|
|
Indexes: MustRandperm(indexesLen, gotch.Int64, gotch.CPU),
|
|
IndexesLen: indexesLen,
|
|
}
|
|
}
|
|
|
|
// Next implements iterator for TextDataIter
|
|
func (tdi *TextDataIter) Next() (*Tensor, bool) {
|
|
start := tdi.BatchIndex * tdi.BatchSize
|
|
size := min(tdi.BatchSize, tdi.IndexesLen-start)
|
|
|
|
if size < tdi.BatchSize {
|
|
return nil, false
|
|
}
|
|
|
|
tdi.BatchIndex += 1
|
|
|
|
narrowIdx := NewNarrow(start, start+size)
|
|
indexesTs := tdi.Indexes.Idx(narrowIdx)
|
|
|
|
indexes := indexesTs.Int64Values()
|
|
indexesTs.MustDrop()
|
|
|
|
var batch []*Tensor
|
|
|
|
for _, idx := range indexes {
|
|
narrowIdx := NewNarrow(idx, idx+tdi.SeqLen)
|
|
idxTs := tdi.Data.Idx(narrowIdx)
|
|
batch = append(batch, idxTs)
|
|
}
|
|
|
|
retVal := MustStack(batch, 0)
|
|
|
|
// Delete intermediate tensors
|
|
for _, xs := range batch {
|
|
xs.MustDrop()
|
|
}
|
|
|
|
return retVal, true
|
|
}
|
|
|
|
// min returns the smaller of two int64 values.
//
// Being declared at package level, it takes precedence over the built-in
// min that Go 1.21 introduced.
func min(v1, v2 int64) int64 {
	if v2 <= v1 {
		return v2
	}

	return v1
}
|