fix(tensor/data): correct NewTextData func
This commit is contained in:
parent
490ea05de2
commit
c69229c43a
|
@ -77,13 +77,31 @@ func main() {
|
|||
break
|
||||
}
|
||||
|
||||
batchNarrow := batchTs.MustNarrow(1, 1, SeqLen, false)
|
||||
xsOnehot := batchNarrow.MustOneHot(labels, false)
|
||||
testTs := ts.MustOfSlice([]int64{0, 1, 2, 3, 4, 5, 6, 7, 8})
|
||||
testVal := testTs.Onehot(14).Float64Values()
|
||||
fmt.Printf("testVal: %v\n", testVal)
|
||||
|
||||
fmt.Printf("batchTs shape: %v\n", batchTs.MustSize())
|
||||
|
||||
batchNarrow := batchTs.MustNarrow(1, 0, SeqLen, false)
|
||||
fmt.Printf("batchNarrow shape: %v\n", batchNarrow.MustSize())
|
||||
xsOnehotTmp := batchNarrow.Onehot(labels)
|
||||
xsOnehot := xsOnehotTmp.MustTo(device, true) // shape: [256, 180, 65]
|
||||
ys := batchTs.MustNarrow(1, 1, SeqLen, false).MustTotype(gotch.Int64, false)
|
||||
lstmOut, _ := lstm.Seq(xsOnehot.MustTo(device, false))
|
||||
|
||||
// NOTE. WARNING occurred here...
|
||||
// Warning: RNN module weights are not part of single contiguous chunk of memory.
|
||||
// This means they need to be compacted at every call, possibly greatly increasing memory usage.
|
||||
// To compact weights again call flatten_parameters().
|
||||
// See: https://discuss.pytorch.org/t/rnn-module-weights-are-not-part-of-single-contiguous-chunk-of-memory/6011/21
|
||||
lstmOut, _ := lstm.Seq(xsOnehot)
|
||||
|
||||
panic("reached")
|
||||
|
||||
logits := linear.Forward(lstmOut)
|
||||
|
||||
lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, false)
|
||||
|
||||
loss := lossView.CrossEntropyForLogits(ys.MustTo(device, false)).MustView([]int64{BatchSize * SeqLen}, false)
|
||||
|
||||
opt.BackwardStepClip(loss, 0.5)
|
||||
|
|
|
@ -173,10 +173,10 @@ type TextDataIter struct {
|
|||
// NewTextData creates a text dataset from a file
|
||||
//
|
||||
// It reads text input from file to `[]byte` buffer
|
||||
// - Loops over each byte and counts its occurrence
|
||||
// - Loops over each byte
|
||||
// - first byte will be labelled `0`
|
||||
// - next byte if exist will be labelled same as previous, otherwise
|
||||
// will labelled `previous + 1`
|
||||
// - each following byte is labelled with its existing label (index) if it was seen before, otherwise
|
||||
// it is labelled with a new label (index)
|
||||
// Data: tensor of labels
|
||||
// CharForLabel: []rune (unique runes from text input)
|
||||
func NewTextData(filename string) (retVal TextData, err error) {
|
||||
|
@ -194,29 +194,27 @@ func NewTextData(filename string) (retVal TextData, err error) {
|
|||
|
||||
var labelForChar map[byte]uint8 = make(map[byte]uint8, 0)
|
||||
var charForLabel []rune
|
||||
var mutBuffer []byte
|
||||
var dataIndexes []uint8
|
||||
|
||||
for idx, runeVal := range buffer {
|
||||
if idx == 0 {
|
||||
mutBuffer = append(mutBuffer, 0)
|
||||
labelForChar[runeVal] = 1
|
||||
for _, runeVal := range buffer {
|
||||
if len(labelForChar) == 0 {
|
||||
labelForChar[runeVal] = 0
|
||||
dataIndexes = append(dataIndexes, 0)
|
||||
charForLabel = append(charForLabel, rune(runeVal))
|
||||
} else {
|
||||
label, ok := labelForChar[runeVal]
|
||||
pos := len(labelForChar)
|
||||
if !ok {
|
||||
mutBuffer = append(mutBuffer, uint8(pos))
|
||||
labelForChar[runeVal] = uint8(1)
|
||||
newLabel := uint8(len(labelForChar))
|
||||
labelForChar[runeVal] = newLabel
|
||||
dataIndexes = append(dataIndexes, newLabel)
|
||||
charForLabel = append(charForLabel, rune(runeVal))
|
||||
} else {
|
||||
labelForChar[runeVal] = label + uint8(1)
|
||||
mutBuffer = append(mutBuffer, uint8(pos-1))
|
||||
dataIndexes = append(dataIndexes, label)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
data := MustOfSlice(mutBuffer)
|
||||
data := MustOfSlice(dataIndexes)
|
||||
|
||||
return TextData{
|
||||
Data: data,
|
||||
|
@ -256,38 +254,38 @@ func (td TextData) IterShuffle(seqLen int64, batchSize int64) (retVal TextDataIt
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: implement iterator for TextDataIter
|
||||
// Next implements iterator for TextDataIter
|
||||
func (tdi *TextDataIter) Next() (retVal Tensor, ok bool) {
|
||||
start := tdi.BatchIndex * tdi.BatchSize
|
||||
size := tdi.BatchSize
|
||||
if (tdi.IndexesLen - start) < size {
|
||||
size = tdi.IndexesLen - start
|
||||
}
|
||||
size := min(tdi.BatchSize, tdi.IndexesLen-start)
|
||||
|
||||
if size < tdi.BatchSize {
|
||||
return retVal, false
|
||||
}
|
||||
|
||||
tdi.BatchIndex += 1
|
||||
|
||||
narrowIdx := NewNarrow(start, start+size)
|
||||
indexesTs := tdi.Indexes.Idx(narrowIdx)
|
||||
|
||||
values := indexesTs.Float64Values()
|
||||
var indexes []int64
|
||||
for _, v := range values {
|
||||
indexes = append(indexes, int64(v))
|
||||
}
|
||||
indexes := indexesTs.Int64Values()
|
||||
|
||||
var batch []Tensor
|
||||
|
||||
for _, idx := range indexes {
|
||||
narrowIdx := NewNarrow(idx, idx+tdi.SeqLen)
|
||||
idxTs := tdi.Indexes.Idx(narrowIdx)
|
||||
idxTs := tdi.Data.Idx(narrowIdx)
|
||||
batch = append(batch, idxTs)
|
||||
}
|
||||
|
||||
retVal = MustStack(batch, 0)
|
||||
|
||||
return retVal, true
|
||||
|
||||
}
|
||||
|
||||
// min returns the smaller of v1 and v2. When both are equal, v2 is returned.
func min(v1, v2 int64) (retVal int64) {
	if v1 < v2 {
		return v1
	}
	return v2
}
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
package tensor_test
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
// "fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func TestTextData_NewTextData(t *testing.T) {
|
||||
|
@ -53,3 +56,60 @@ func TestTextData_NewTextData(t *testing.T) {
|
|||
t.Errorf("Got Char: %q\n", gotChar)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextDataIter(t *testing.T) {
|
||||
|
||||
filename := "/tmp/test.txt"
|
||||
filePath, err := filepath.Abs(filename)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
txt := "01234567890123456789"
|
||||
// txt := `hello world`
|
||||
err = ioutil.WriteFile(filePath, []byte(txt), 0644)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
textData, err := ts.NewTextData(filename)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
iter := textData.IterShuffle(2, 5) // (seqLen, batchSize)
|
||||
// fmt.Printf("indexesLen: %v\n", iter.IndexesLen)
|
||||
// fmt.Printf("data: %v\n", iter.Data.Int64Values())
|
||||
// fmt.Printf("Indexes: %v\n", iter.Indexes.Int64Values())
|
||||
|
||||
for {
|
||||
xs, ok := iter.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
size := xs.MustSize()
|
||||
idxCol := ts.NewNarrow(0, size[0]) // column
|
||||
idxCol1 := ts.NewNarrow(0, 1) // first column
|
||||
// idxNextEl := ts.NewSelect(1)
|
||||
col1 := xs.Idx([]ts.TensorIndexer{idxCol, idxCol1})
|
||||
// nextEl := xs.Idx([]ts.TensorIndexer{idxNextEl})
|
||||
// col1PlusOne := ts.MustStack([]ts.Tensor{col1, nextEl}, 0)
|
||||
// col1Fmod := col1PlusOne.MustFmod(ts.IntScalar(10), false)
|
||||
col1Fmod := col1.MustFmod(ts.IntScalar(10), false)
|
||||
|
||||
// t.Errorf("col1 shape: %v\n", col1Fmod.MustSize())
|
||||
|
||||
idxCol2 := ts.NewNarrow(1, 2)
|
||||
col2 := xs.Idx([]ts.TensorIndexer{idxCol, idxCol2})
|
||||
// t.Errorf("col2 shape: %v\n", col2.MustSize())
|
||||
|
||||
pow := col1Fmod.MustSub(col2, true).MustPow(ts.IntScalar(2), true)
|
||||
sum := pow.MustSum(gotch.Float, true)
|
||||
|
||||
// Will pass if there's no panic
|
||||
vals := sum.Int64Values()
|
||||
t.Logf("sum: %v\n", vals)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1110,6 +1110,24 @@ func (ts Tensor) MustZeroPad2d(left, right, top, bottom int64, del bool) (retVal
|
|||
return retVal
|
||||
}
|
||||
|
||||
// Onehot converts a tensor to a one-hot encoded version.
|
||||
//
|
||||
// If the input has a size [N1, N2, ..., Nk], the returned tensor has a size
|
||||
// [N1, ..., Nk, labels]. The returned tensor uses float values.
|
||||
// Elements of the input vector are expected to be between 0 and labels-1.
|
||||
//
|
||||
// NOTE: There's other `ts.OneHot` and `ts.MustOneHot` generated from Atg C++ API
|
||||
func (ts Tensor) Onehot(labels int64) (retVal Tensor) {
|
||||
dims := ts.MustSize()
|
||||
dims = append(dims, labels)
|
||||
unsqueezeTs := ts.MustUnsqueeze(-1, false).MustTotype(gotch.Int64, true)
|
||||
zerosTs := MustZeros(dims, gotch.Float, gotch.CPU)
|
||||
fmt.Printf("zeroTs shape: %v\n", zerosTs.MustSize())
|
||||
fmt.Printf("unsqueezeTs shape: %v\n", unsqueezeTs.MustSize())
|
||||
zerosTs.MustScatter1_(-1, unsqueezeTs, FloatScalar(1.0))
|
||||
return zerosTs
|
||||
}
|
||||
|
||||
func (ts Tensor) Swish() (retVal Tensor) {
|
||||
sig := ts.MustSigmoid(false)
|
||||
retVal = ts.MustMul(sig, false)
|
||||
|
|
|
@ -95,3 +95,26 @@ func TestIter(t *testing.T) {
|
|||
t.Errorf("Got tensor values: %v\n", got1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnehot(t *testing.T) {
|
||||
xs := ts.MustOfSlice([]int64{0, 1, 2, 3}).MustView([]int64{2, 2}, true)
|
||||
onehot := xs.Onehot(4)
|
||||
|
||||
want := []float64{1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0}
|
||||
got := onehot.Float64Values()
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("Expected onehot tensor values: %v\n", want)
|
||||
t.Errorf("Got onehot tensor values: %v\n", got)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/*
|
||||
* let xs = Tensor::of_slice(&[0, 1, 2, 3]);
|
||||
* let onehot = xs.onehot(4);
|
||||
* assert_eq!(
|
||||
* Vec::<f64>::from(&onehot),
|
||||
* vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]
|
||||
* );
|
||||
* assert_eq!(onehot.size(), vec![4, 4]) */
|
||||
|
|
Loading…
Reference in New Issue
Block a user