feat(tensor/data): TextData and TextDataIter; WIP(example/char-rnn)

This commit is contained in:
sugarme 2020-07-25 13:42:15 +10:00
parent b3529b3b1b
commit af7655d3fc
2 changed files with 165 additions and 0 deletions

example/char-rnn/main.go Normal file

@@ -0,0 +1,23 @@
package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
)

const (
	LearningRate float64 = 0.01
	HiddenSize   int64   = 256
	SeqLen       int64   = 180
	BatchSize    int64   = 256
	Epochs       int64   = 100
	SamplingLen  int64   = 1024
)

// sample is intended to generate SamplingLen characters of text from the
// trained LSTM. It is still a WIP stub in this commit and currently returns
// an empty string.
func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Device) (retVal string) {
	return
}
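
// What follows is a minimal placeholder main (a hedged sketch, not part of
// this commit) so the WIP example builds. It only exercises the TextData API
// introduced in this change; the input path "data/input.txt" is a hypothetical
// placeholder.
func main() {
	// Load the dataset and label each byte of the input text.
	data, err := ts.NewTextData("data/input.txt")
	if err != nil {
		panic(err)
	}
	fmt.Printf("dataset loaded, %v distinct characters\n", data.Labels())

	// TODO(WIP): build the LSTM and linear layers, train for Epochs epochs over
	// data.IterShuffle(SeqLen, BatchSize) batches, then generate text with sample().
}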


@@ -2,7 +2,10 @@ package tensor
import (
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"path/filepath"

	"github.com/sugarme/gotch"
)
@@ -149,3 +152,142 @@ func (it Iter2) Drop() {
	it.xs.MustDrop()
	it.ys.MustDrop()
}
// TextData represents text data as a tensor of byte labels (uint8)
// together with the mapping from label back to character.
type TextData struct {
	Data         Tensor // tensor of labels; each byte of the input text is replaced by its label
	CharForLabel []rune // unique characters from the input text, indexed by label
}

// TextDataIter is a text data iterator.
type TextDataIter struct {
	Data       Tensor
	SeqLen     int64
	BatchIndex int64
	BatchSize  int64
	Indexes    Tensor
	IndexesLen int64
}
// NewTextData creates a text dataset from a file.
//
// It reads the file into a `[]byte` buffer and assigns a label to each byte:
// - the first time a byte value is seen it gets a new label (0, 1, 2, ... in
//   order of first appearance) and its rune is recorded in CharForLabel
// - every later occurrence of that byte reuses the same label
//
// Data: tensor of labels, one per input byte
// CharForLabel: []rune (unique runes from the input text, indexed by label)
func NewTextData(filename string) (retVal TextData, err error) {
	filePath, err := filepath.Abs(filename)
	if err != nil {
		return retVal, err
	}

	r, err := os.Open(filePath)
	if err != nil {
		return retVal, err
	}
	defer r.Close()

	buffer, err := ioutil.ReadAll(r)
	if err != nil {
		return retVal, err
	}

	// labelForChar maps a byte value to its label; labels are assigned in
	// order of first appearance.
	labelForChar := make(map[byte]uint8)
	var charForLabel []rune
	mutBuffer := make([]byte, 0, len(buffer))

	for _, b := range buffer {
		label, ok := labelForChar[b]
		if !ok {
			label = uint8(len(labelForChar))
			labelForChar[b] = label
			charForLabel = append(charForLabel, rune(b))
		}
		mutBuffer = append(mutBuffer, label)
	}

	data := MustOfSlice(mutBuffer)

	return TextData{
		Data:         data,
		CharForLabel: charForLabel,
	}, nil
}
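
// As a worked illustration of the labelling above (not part of the original
// change): for an input file containing the bytes "abacab", NewTextData
// produces
//
//	Data         = [0 1 0 2 0 1]  // 'a'->0, 'b'->1, 'c'->2
//	CharForLabel = ['a' 'b' 'c']
//
// so td.Labels() == 3 and td.LabelForChar(2) == 'c'.
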
// Labels returns the number of distinct characters (runes) used by the dataset.
func (td TextData) Labels() (retVal int64) {
	return int64(len(td.CharForLabel))
}

// CloneData returns a shallow copy of the data.
func (td TextData) CloneData() (retVal Tensor) {
	return td.Data.MustShallowClone()
}

// LabelForChar returns the character (rune) corresponding to the given label.
func (td TextData) LabelForChar(label int64) (retVal rune) {
	return td.CharForLabel[int(label)]
}
// IterShuffle returns a batch iterator over the dataset.
// Each sample is a sequence of seqLen consecutive labels, starting at a
// randomly shuffled position.
func (td TextData) IterShuffle(seqLen int64, batchSize int64) (retVal TextDataIter) {
	// Number of possible sequence start positions.
	indexesLen := td.Data.MustSize()[0] - seqLen + 1

	return TextDataIter{
		Data:       td.Data.MustShallowClone(),
		SeqLen:     seqLen,
		BatchIndex: 0,
		BatchSize:  batchSize,
		Indexes:    MustRandperm(indexesLen, gotch.Int64, gotch.CPU),
		IndexesLen: indexesLen,
	}
}
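
// For a rough sense of the sizes involved (illustration only, not part of the
// original change): with 1,000 labels, seqLen = 180 and batchSize = 256,
// indexesLen = 1000 - 180 + 1 = 821 shuffled start positions, which yields
// 3 full batches of shape [256, 180]; the remaining 53 positions are dropped
// by Next below, since it only returns complete batches.
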
// Next returns the next batch of sequences as a [BatchSize, SeqLen] tensor of
// labels, or ok = false once fewer than BatchSize start positions remain.
func (tdi *TextDataIter) Next() (retVal Tensor, ok bool) {
	start := tdi.BatchIndex * tdi.BatchSize

	// Stop once a full batch can no longer be formed.
	if tdi.IndexesLen-start < tdi.BatchSize {
		return retVal, false
	}
	size := tdi.BatchSize
	tdi.BatchIndex += 1

	// Pick the next `size` shuffled start positions.
	narrowIdx := NewNarrow(start, start+size)
	indexesTs := tdi.Indexes.Idx(narrowIdx)

	values := indexesTs.Float64Values()
	var indexes []int64
	for _, v := range values {
		indexes = append(indexes, int64(v))
	}

	// Slice a SeqLen-long window of labels from the data at each start position.
	var batch []Tensor
	for _, idx := range indexes {
		narrowIdx := NewNarrow(idx, idx+tdi.SeqLen)
		idxTs := tdi.Data.Idx(narrowIdx)
		batch = append(batch, idxTs)
	}

	retVal = MustStack(batch, 0)
	return retVal, true
}
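
// Illustration only (not part of the original change): the iterator is
// single-pass and only yields complete batches, so a typical training loop
// creates a fresh iterator per epoch:
//
//	iter := td.IterShuffle(seqLen, batchSize)
//	for {
//		batch, ok := iter.Next()
//		if !ok {
//			break // fewer than batchSize start positions left
//		}
//		// batch is a [batchSize, seqLen] tensor of labels
//		batch.MustDrop()
//	}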