fix(tensor/data): correct NewTextData func

This commit is contained in:
sugarme 2020-07-27 11:31:42 +10:00
parent 490ea05de2
commit c69229c43a
5 changed files with 148 additions and 31 deletions

View File

@ -77,13 +77,31 @@ func main() {
break
}
batchNarrow := batchTs.MustNarrow(1, 1, SeqLen, false)
xsOnehot := batchNarrow.MustOneHot(labels, false)
testTs := ts.MustOfSlice([]int64{0, 1, 2, 3, 4, 5, 6, 7, 8})
testVal := testTs.Onehot(14).Float64Values()
fmt.Printf("testVal: %v\n", testVal)
fmt.Printf("batchTs shape: %v\n", batchTs.MustSize())
batchNarrow := batchTs.MustNarrow(1, 0, SeqLen, false)
fmt.Printf("batchNarrow shape: %v\n", batchNarrow.MustSize())
xsOnehotTmp := batchNarrow.Onehot(labels)
xsOnehot := xsOnehotTmp.MustTo(device, true) // shape: [256, 180, 65]
ys := batchTs.MustNarrow(1, 1, SeqLen, false).MustTotype(gotch.Int64, false)
lstmOut, _ := lstm.Seq(xsOnehot.MustTo(device, false))
// NOTE. WARNING occurred here...
// Warning: RNN module weights are not part of single contiguous chunk of memory.
// This means they need to be compacted at every call, possibly greatly increasing memory usage.
// To compact weights again call flatten_parameters().
// See: https://discuss.pytorch.org/t/rnn-module-weights-are-not-part-of-single-contiguous-chunk-of-memory/6011/21
lstmOut, _ := lstm.Seq(xsOnehot)
panic("reached")
logits := linear.Forward(lstmOut)
lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, false)
loss := lossView.CrossEntropyForLogits(ys.MustTo(device, false)).MustView([]int64{BatchSize * SeqLen}, false)
opt.BackwardStepClip(loss, 0.5)

View File

@ -173,10 +173,10 @@ type TextDataIter struct {
// NewTextData creates a text dataset from a file
//
// It reads text input from file to `[]byte` buffer
- Loops over each byte
- first byte will be labelled `0`
- each subsequent byte will be labelled with its existing label (index) if seen
before, otherwise it will be labelled with a new label (index)
// Data: tensor of labels
// CharForLabel: []rune (unique runes from text input)
func NewTextData(filename string) (retVal TextData, err error) {
@ -194,29 +194,27 @@ func NewTextData(filename string) (retVal TextData, err error) {
var labelForChar map[byte]uint8 = make(map[byte]uint8, 0)
var charForLabel []rune
var mutBuffer []byte
var dataIndexes []uint8
for idx, runeVal := range buffer {
if idx == 0 {
mutBuffer = append(mutBuffer, 0)
labelForChar[runeVal] = 1
for _, runeVal := range buffer {
if len(labelForChar) == 0 {
labelForChar[runeVal] = 0
dataIndexes = append(dataIndexes, 0)
charForLabel = append(charForLabel, rune(runeVal))
} else {
label, ok := labelForChar[runeVal]
pos := len(labelForChar)
if !ok {
mutBuffer = append(mutBuffer, uint8(pos))
labelForChar[runeVal] = uint8(1)
newLabel := uint8(len(labelForChar))
labelForChar[runeVal] = newLabel
dataIndexes = append(dataIndexes, newLabel)
charForLabel = append(charForLabel, rune(runeVal))
} else {
labelForChar[runeVal] = label + uint8(1)
mutBuffer = append(mutBuffer, uint8(pos-1))
dataIndexes = append(dataIndexes, label)
}
}
}
data := MustOfSlice(mutBuffer)
data := MustOfSlice(dataIndexes)
return TextData{
Data: data,
@ -256,38 +254,38 @@ func (td TextData) IterShuffle(seqLen int64, batchSize int64) (retVal TextDataIt
}
}
// TODO: implement iterator for TextDataIter
// Next implements iterator for TextDataIter.
//
// It returns the next batch of sequences stacked into a single tensor of
// shape [batchSize, seqLen], and ok=false once fewer than a full batch of
// start indexes remain.
func (tdi *TextDataIter) Next() (retVal Tensor, ok bool) {
	start := tdi.BatchIndex * tdi.BatchSize
	size := min(tdi.BatchSize, tdi.IndexesLen-start)
	// Drop the final partial batch: callers rely on a fixed batch size.
	if size < tdi.BatchSize {
		return retVal, false
	}

	tdi.BatchIndex += 1

	// Pick the next `size` (shuffled) sequence-start indexes.
	narrowIdx := NewNarrow(start, start+size)
	indexesTs := tdi.Indexes.Idx(narrowIdx)
	indexes := indexesTs.Int64Values()

	// Slice a seqLen-long window out of the label data for each start index.
	var batch []Tensor
	for _, idx := range indexes {
		narrowIdx := NewNarrow(idx, idx+tdi.SeqLen)
		idxTs := tdi.Data.Idx(narrowIdx)
		batch = append(batch, idxTs)
	}

	retVal = MustStack(batch, 0)

	return retVal, true
}
// min returns the smaller of two int64 values.
func min(v1, v2 int64) (retVal int64) {
	if v2 < v1 {
		return v2
	}
	return v1
}

View File

@ -1,12 +1,15 @@
package tensor_test
import (
ts "github.com/sugarme/gotch/tensor"
// "fmt"
"io/ioutil"
"log"
"path/filepath"
"reflect"
"testing"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
func TestTextData_NewTextData(t *testing.T) {
@ -53,3 +56,60 @@ func TestTextData_NewTextData(t *testing.T) {
t.Errorf("Got Char: %q\n", gotChar)
}
}
func TestTextDataIter(t *testing.T) {
filename := "/tmp/test.txt"
filePath, err := filepath.Abs(filename)
if err != nil {
log.Fatal(err)
}
txt := "01234567890123456789"
// txt := `hello world`
err = ioutil.WriteFile(filePath, []byte(txt), 0644)
if err != nil {
log.Fatal(err)
}
textData, err := ts.NewTextData(filename)
if err != nil {
log.Fatal(err)
}
iter := textData.IterShuffle(2, 5) // (seqLen, batchSize)
// fmt.Printf("indexesLen: %v\n", iter.IndexesLen)
// fmt.Printf("data: %v\n", iter.Data.Int64Values())
// fmt.Printf("Indexes: %v\n", iter.Indexes.Int64Values())
for {
xs, ok := iter.Next()
if !ok {
break
}
size := xs.MustSize()
idxCol := ts.NewNarrow(0, size[0]) // column
idxCol1 := ts.NewNarrow(0, 1) // first column
// idxNextEl := ts.NewSelect(1)
col1 := xs.Idx([]ts.TensorIndexer{idxCol, idxCol1})
// nextEl := xs.Idx([]ts.TensorIndexer{idxNextEl})
// col1PlusOne := ts.MustStack([]ts.Tensor{col1, nextEl}, 0)
// col1Fmod := col1PlusOne.MustFmod(ts.IntScalar(10), false)
col1Fmod := col1.MustFmod(ts.IntScalar(10), false)
// t.Errorf("col1 shape: %v\n", col1Fmod.MustSize())
idxCol2 := ts.NewNarrow(1, 2)
col2 := xs.Idx([]ts.TensorIndexer{idxCol, idxCol2})
// t.Errorf("col2 shape: %v\n", col2.MustSize())
pow := col1Fmod.MustSub(col2, true).MustPow(ts.IntScalar(2), true)
sum := pow.MustSum(gotch.Float, true)
// Will pass if there's no panic
vals := sum.Int64Values()
t.Logf("sum: %v\n", vals)
}
}

View File

@ -1110,6 +1110,24 @@ func (ts Tensor) MustZeroPad2d(left, right, top, bottom int64, del bool) (retVal
return retVal
}
// Onehot converts a tensor to a one-hot encoded version.
//
// If the input has a size [N1, N2, ..., Nk], the returned tensor has a size
// [N1, ..., Nk, labels]. The returned tensor uses float values.
// Elements of the input vector are expected to be between 0 and labels-1.
//
// NOTE: There's other `ts.OneHot` and `ts.MustOneHot` generated from Atg C++ API
func (ts Tensor) Onehot(labels int64) (retVal Tensor) {
	dims := ts.MustSize()
	dims = append(dims, labels)
	// Add a trailing dim of size 1 so the input can serve as the scatter index.
	unsqueezeTs := ts.MustUnsqueeze(-1, false).MustTotype(gotch.Int64, true)
	// Zeros of the one-hot shape; write 1.0 at each label position along the
	// last dim. (Removed leftover debug fmt.Printf calls that wrote to stdout.)
	zerosTs := MustZeros(dims, gotch.Float, gotch.CPU)
	zerosTs.MustScatter1_(-1, unsqueezeTs, FloatScalar(1.0))

	return zerosTs
}
func (ts Tensor) Swish() (retVal Tensor) {
sig := ts.MustSigmoid(false)
retVal = ts.MustMul(sig, false)

View File

@ -95,3 +95,26 @@ func TestIter(t *testing.T) {
t.Errorf("Got tensor values: %v\n", got1)
}
}
// TestOnehot verifies one-hot encoding of a 2x2 label tensor against the
// expected flattened [2, 2, 4] float values.
func TestOnehot(t *testing.T) {
	input := ts.MustOfSlice([]int64{0, 1, 2, 3}).MustView([]int64{2, 2}, true)
	encoded := input.Onehot(4)

	got := encoded.Float64Values()
	want := []float64{1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0}

	if !reflect.DeepEqual(want, got) {
		t.Errorf("Expected onehot tensor values: %v\n", want)
		t.Errorf("Got onehot tensor values: %v\n", got)
	}
}
/*
* let xs = Tensor::of_slice(&[0, 1, 2, 3]);
* let onehot = xs.onehot(4);
* assert_eq!(
* Vec::<f64>::from(&onehot),
* vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]
* );
* assert_eq!(onehot.size(), vec![4, 4]) */