gotch/ts/data_test.go
Goncalves Henriques, Andre (UG - Computer Science) 9257404edd Move the name of the module
2024-04-21 15:15:00 +01:00

115 lines
2.6 KiB
Go

package ts_test
import (
// "fmt"
"io/ioutil"
"log"
"path/filepath"
"reflect"
"testing"
"git.andr3h3nriqu3s.com/andr3/gotch"
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
)
func TestTextData_NewTextData(t *testing.T) {
// Create text file to test
filename := "/tmp/test.txt"
filePath, err := filepath.Abs(filename)
if err != nil {
log.Fatal(err)
}
// txt := `héllo`
txt := "h\xC3\xA9llo"
err = ioutil.WriteFile(filePath, []byte(txt), 0644)
if err != nil {
log.Fatal(err)
}
textData, err := ts.NewTextData(filename)
if err != nil {
log.Fatal(err)
}
wantData := []float64{0, 1, 2, 3, 3, 4}
gotData := textData.Data.Float64Values()
if !reflect.DeepEqual(wantData, gotData) {
t.Errorf("Want data: %v\n", wantData)
t.Errorf("Got data: %v\n", gotData)
}
wantLabelLen := int64(5)
gotLabelLen := textData.Labels()
if !reflect.DeepEqual(wantLabelLen, gotLabelLen) {
t.Errorf("Want label len: %v\n", wantLabelLen)
t.Errorf("Got label len: %v\n", gotLabelLen)
}
wantChar := rune(195)
gotChar := textData.LabelForChar(int64(1))
if !reflect.DeepEqual(wantChar, gotChar) {
t.Errorf("Want Char: %q\n", wantChar)
t.Errorf("Got Char: %q\n", gotChar)
}
}
// TestTextDataIter is a smoke test for the shuffled batch iterator returned by
// TextData.IterShuffle. It drains the iterator and runs a handful of tensor
// operations on every batch. Nothing is asserted about the resulting values;
// the test fails only if one of the calls panics.
func TestTextDataIter(t *testing.T) {
	// Keep the fixture in a per-test temporary directory: it cannot collide
	// with files written by other tests and is removed when the test ends.
	filePath := filepath.Join(t.TempDir(), "test.txt")

	txt := "01234567890123456789"
	if err := ioutil.WriteFile(filePath, []byte(txt), 0644); err != nil {
		// NOTE(review): log.Fatal is kept so the file-level "log" import
		// stays referenced; prefer t.Fatal once that import can be dropped.
		log.Fatal(err)
	}

	// Load from the same path that was just written.
	textData, err := ts.NewTextData(filePath)
	if err != nil {
		log.Fatal(err)
	}

	iter := textData.IterShuffle(2, 5) // (seqLen, batchSize)

	for {
		xs, ok := iter.Next()
		if !ok {
			break
		}

		size := xs.MustSize()

		// Presumably selects the range [0, size[0]) along the first
		// dimension, i.e. every row of the batch — confirm against NewNarrow.
		allRows := ts.NewNarrow(0, size[0])

		// First column of the batch, reduced modulo 10.
		firstCol := ts.NewNarrow(0, 1)
		col1 := xs.Idx([]ts.TensorIndexer{allRows, firstCol})
		col1Fmod := col1.MustFmod(ts.IntScalar(10), false)

		// Second column of the batch.
		secondCol := ts.NewNarrow(1, 2)
		col2 := xs.Idx([]ts.TensorIndexer{allRows, secondCol})

		// Difference of the two columns raised to the power 2, then summed.
		pow := col1Fmod.MustSub(col2, true).MustPowTensorScalar(ts.IntScalar(2), true)
		sum := pow.MustSum(gotch.Float, true)

		// Will pass if there's no panic
		vals := sum.Int64Values()
		t.Logf("sum: %v\n", vals)
	}
}