feat(tensor/data): TextData and TextDataIter; WIP(example/char-rnn)
This commit is contained in:
parent
b3529b3b1b
commit
af7655d3fc
23
example/char-rnn/main.go
Normal file
23
example/char-rnn/main.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/nn"
|
||||
)
|
||||
|
||||
// Settings for the char-rnn example. Nothing in the visible part of this
// work-in-progress file reads them yet, so the per-constant notes are
// inferred from the names only and should be confirmed once they are used.
const (
	LearningRate float64 = 0.01 // presumably the optimizer step size
	HiddenSize   int64   = 256  // presumably the LSTM hidden-state width
	SeqLen       int64   = 180  // presumably characters per training sequence
	BatchSize    int64   = 256  // presumably sequences per batch
	Epochs       int64   = 100  // presumably passes over the dataset
	SamplingLen  int64   = 1024 // presumably characters generated when sampling
)
|
||||
|
||||
func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Device) (retVal string) {
|
||||
|
||||
return
|
||||
}
|
142
tensor/data.go
142
tensor/data.go
|
@ -2,7 +2,10 @@ package tensor
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
)
|
||||
|
@ -149,3 +152,142 @@ func (it Iter2) Drop() {
|
|||
it.xs.MustDrop()
|
||||
it.ys.MustDrop()
|
||||
}
|
||||
|
||||
// TextData represent text data in tensor of runes (uint8)
|
||||
// and its corresponding string
|
||||
type TextData struct {
|
||||
Data Tensor // frequency (occurence) of byte value from input text
|
||||
CharForLabel []rune // unique rune values from input text
|
||||
}
|
||||
|
||||
// TextDataIter is a text data interator
|
||||
type TextDataIter struct {
|
||||
Data Tensor
|
||||
SeqLen int64
|
||||
BatchIndex int64
|
||||
BatchSize int64
|
||||
Indexes Tensor
|
||||
IndexesLen int64
|
||||
}
|
||||
|
||||
// NewTextData creates a text dataset from a file
|
||||
//
|
||||
// It reads text input from file to `[]byte` buffer
|
||||
// - Loops over each byte and counts its occurence
|
||||
// - first byte will be labelled `0`
|
||||
// - next byte if exist will be labelled same as previous, otherwise
|
||||
// will labelled `previous + 1`
|
||||
// Data: tensor of labels
|
||||
// CharForLabel: []rune (unique runes from text input)
|
||||
func NewTextData(filename string) (retVal TextData, err error) {
|
||||
filePath, err := filepath.Abs(filename)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
r, err := os.Open(filePath)
|
||||
|
||||
buffer, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
var labelForChar map[byte]uint8 = make(map[byte]uint8, 0)
|
||||
var charForLabel []rune
|
||||
var mutBuffer []byte
|
||||
|
||||
for idx, runeVal := range buffer {
|
||||
if idx == 0 {
|
||||
mutBuffer = append(mutBuffer, 0)
|
||||
labelForChar[runeVal] = 1
|
||||
charForLabel = append(charForLabel, rune(runeVal))
|
||||
} else {
|
||||
label, ok := labelForChar[runeVal]
|
||||
pos := len(labelForChar)
|
||||
if !ok {
|
||||
mutBuffer = append(mutBuffer, uint8(pos))
|
||||
labelForChar[runeVal] = uint8(1)
|
||||
charForLabel = append(charForLabel, rune(runeVal))
|
||||
} else {
|
||||
labelForChar[runeVal] = label + uint8(1)
|
||||
mutBuffer = append(mutBuffer, uint8(pos-1))
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
data := MustOfSlice(mutBuffer)
|
||||
|
||||
return TextData{
|
||||
Data: data,
|
||||
CharForLabel: charForLabel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Labels returns the number of different `character` (rune) used by the dataset.
|
||||
func (td TextData) Labels() (retVal int64) {
|
||||
return int64(len(td.CharForLabel))
|
||||
}
|
||||
|
||||
// Data returns a shallow copy of the data.
|
||||
func (td TextData) CloneData() (retVal Tensor) {
|
||||
return td.Data.MustShallowClone()
|
||||
}
|
||||
|
||||
// LabelForChar returns a corresponding `char` (rune) for
|
||||
// specified label input
|
||||
func (td TextData) LabelForChar(label int64) (retVal rune) {
|
||||
return td.CharForLabel[int(label)]
|
||||
}
|
||||
|
||||
// IterShuffle returns a batch iterator over the dataset.
|
||||
// Each sample is made of seq_len characters.
|
||||
func (td TextData) IterShuffle(seqLen int64, batchSize int64) (retVal TextDataIter) {
|
||||
|
||||
indexesLen := td.Data.MustSize()[0] - seqLen + 1
|
||||
|
||||
return TextDataIter{
|
||||
Data: td.Data.MustShallowClone(),
|
||||
SeqLen: seqLen,
|
||||
BatchIndex: 0,
|
||||
BatchSize: batchSize,
|
||||
Indexes: MustRandperm(indexesLen, gotch.Int64, gotch.CPU),
|
||||
IndexesLen: indexesLen,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: implement iterator for TextDataIter
|
||||
func (tdi *TextDataIter) Next() (retVal Tensor, ok bool) {
|
||||
start := tdi.BatchIndex * tdi.BatchSize
|
||||
size := tdi.BatchSize
|
||||
if (tdi.IndexesLen - start) < size {
|
||||
size = tdi.IndexesLen - start
|
||||
}
|
||||
|
||||
if size < tdi.BatchSize {
|
||||
return retVal, false
|
||||
}
|
||||
|
||||
tdi.BatchIndex += 1
|
||||
narrowIdx := NewNarrow(start, start+size)
|
||||
indexesTs := tdi.Indexes.Idx(narrowIdx)
|
||||
|
||||
values := indexesTs.Float64Values()
|
||||
var indexes []int64
|
||||
for _, v := range values {
|
||||
indexes = append(indexes, int64(v))
|
||||
}
|
||||
|
||||
var batch []Tensor
|
||||
|
||||
for _, idx := range indexes {
|
||||
narrowIdx := NewNarrow(idx, idx+tdi.SeqLen)
|
||||
idxTs := tdi.Indexes.Idx(narrowIdx)
|
||||
batch = append(batch, idxTs)
|
||||
}
|
||||
|
||||
retVal = MustStack(batch, 0)
|
||||
|
||||
return retVal, true
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user