feat(tensor/data_test): added some unit tests for TextData

This commit is contained in:
sugarme 2020-07-25 17:42:15 +10:00
parent d1967817ae
commit cddf4c8a77
2 changed files with 55 additions and 2 deletions

55
tensor/data_test.go Normal file
View File

@ -0,0 +1,55 @@
package tensor_test
import (
ts "github.com/sugarme/gotch/tensor"
"io/ioutil"
"log"
"path/filepath"
"reflect"
"testing"
)
func TestTextData_NewTextData(t *testing.T) {
// Create text file to test
filename := "/tmp/test.txt"
filePath, err := filepath.Abs(filename)
if err != nil {
log.Fatal(err)
}
txt := `héllo`
// txt := "h\xC3\xA9llo"
err = ioutil.WriteFile(filePath, []byte(txt), 0644)
if err != nil {
log.Fatal(err)
}
textData, err := ts.NewTextData(filename)
if err != nil {
log.Fatal(err)
}
wantData := []float64{0, 1, 2, 3, 3, 4}
gotData := textData.CloneData().Float64Values()
if !reflect.DeepEqual(wantData, gotData) {
t.Errorf("Want data: %v\n", wantData)
t.Errorf("Got data: %v\n", gotData)
}
wantLabelLen := int64(5)
gotLabelLen := textData.Labels()
if !reflect.DeepEqual(wantLabelLen, gotLabelLen) {
t.Errorf("Want label len: %v\n", wantLabelLen)
t.Errorf("Got label len: %v\n", gotLabelLen)
}
wantChar := rune(195)
gotChar := textData.LabelForChar(int64(1))
if !reflect.DeepEqual(wantChar, gotChar) {
t.Errorf("Want Char: %q\n", wantChar)
t.Errorf("Got Char: %q\n", gotChar)
}
}

View File

@ -11,8 +11,6 @@ import (
func TestTensorInit(t *testing.T) {
tensor := ts.MustArange1(ts.IntScalar(1), ts.IntScalar(5), gotch.Int64, gotch.CPU)
tensor.Print()
want := []float64{1, 2, 3, 4}
got := tensor.Float64Values()