115 lines
2.6 KiB
Go
115 lines
2.6 KiB
Go
package ts_test
|
|
|
|
import (
|
|
// "fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
"path/filepath"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"git.andr3h3nriqu3s.com/andr3/gotch"
|
|
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
|
|
)
|
|
|
|
func TestTextData_NewTextData(t *testing.T) {
|
|
// Create text file to test
|
|
filename := "/tmp/test.txt"
|
|
filePath, err := filepath.Abs(filename)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
// txt := `héllo`
|
|
txt := "h\xC3\xA9llo"
|
|
err = ioutil.WriteFile(filePath, []byte(txt), 0644)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
textData, err := ts.NewTextData(filename)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
wantData := []float64{0, 1, 2, 3, 3, 4}
|
|
gotData := textData.Data.Float64Values()
|
|
|
|
if !reflect.DeepEqual(wantData, gotData) {
|
|
t.Errorf("Want data: %v\n", wantData)
|
|
t.Errorf("Got data: %v\n", gotData)
|
|
}
|
|
|
|
wantLabelLen := int64(5)
|
|
gotLabelLen := textData.Labels()
|
|
|
|
if !reflect.DeepEqual(wantLabelLen, gotLabelLen) {
|
|
t.Errorf("Want label len: %v\n", wantLabelLen)
|
|
t.Errorf("Got label len: %v\n", gotLabelLen)
|
|
}
|
|
|
|
wantChar := rune(195)
|
|
gotChar := textData.LabelForChar(int64(1))
|
|
|
|
if !reflect.DeepEqual(wantChar, gotChar) {
|
|
t.Errorf("Want Char: %q\n", wantChar)
|
|
t.Errorf("Got Char: %q\n", gotChar)
|
|
}
|
|
}
|
|
|
|
func TestTextDataIter(t *testing.T) {
|
|
|
|
filename := "/tmp/test.txt"
|
|
filePath, err := filepath.Abs(filename)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
txt := "01234567890123456789"
|
|
// txt := `hello world`
|
|
err = ioutil.WriteFile(filePath, []byte(txt), 0644)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
textData, err := ts.NewTextData(filename)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
iter := textData.IterShuffle(2, 5) // (seqLen, batchSize)
|
|
// fmt.Printf("indexesLen: %v\n", iter.IndexesLen)
|
|
// fmt.Printf("data: %v\n", iter.Data.Int64Values())
|
|
// fmt.Printf("Indexes: %v\n", iter.Indexes.Int64Values())
|
|
|
|
for {
|
|
xs, ok := iter.Next()
|
|
if !ok {
|
|
break
|
|
}
|
|
|
|
size := xs.MustSize()
|
|
idxCol := ts.NewNarrow(0, size[0]) // column
|
|
idxCol1 := ts.NewNarrow(0, 1) // first column
|
|
// idxNextEl := ts.NewSelect(1)
|
|
col1 := xs.Idx([]ts.TensorIndexer{idxCol, idxCol1})
|
|
// nextEl := xs.Idx([]ts.TensorIndexer{idxNextEl})
|
|
// col1PlusOne := ts.MustStack([]ts.Tensor{col1, nextEl}, 0)
|
|
// col1Fmod := col1PlusOne.MustFmod(ts.IntScalar(10), false)
|
|
col1Fmod := col1.MustFmod(ts.IntScalar(10), false)
|
|
|
|
// t.Errorf("col1 shape: %v\n", col1Fmod.MustSize())
|
|
|
|
idxCol2 := ts.NewNarrow(1, 2)
|
|
col2 := xs.Idx([]ts.TensorIndexer{idxCol, idxCol2})
|
|
// t.Errorf("col2 shape: %v\n", col2.MustSize())
|
|
|
|
pow := col1Fmod.MustSub(col2, true).MustPowTensorScalar(ts.IntScalar(2), true)
|
|
sum := pow.MustSum(gotch.Float, true)
|
|
|
|
// Will pass if there's no panic
|
|
vals := sum.Int64Values()
|
|
t.Logf("sum: %v\n", vals)
|
|
}
|
|
}
|