From 961080760fa095dcf8df2ebf61673633260158e1 Mon Sep 17 00:00:00 2001 From: sugarme Date: Sun, 13 Feb 2022 22:46:50 +1100 Subject: [PATCH] fixed ts.OfSlice() not supporting []int data type --- CHANGELOG.md | 1 + tensor/tensor.go | 4 ++++ tensor/tensor_test.go | 13 +++++++++++++ tensor/util.go | 9 +++++++++ 4 files changed, 27 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 169a3b4..02f235a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - added new API `ConstantPadNdWithVal` `ato_constant_pad_nd` with padding value. - fixed "nn/rnn NewLSTM() clashed weight names" - fixed some old API at `vision/aug/function.go` +- fixed `tensor.OfSlice()` not supporting `[]int` data type ## [Nofix] - ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box. diff --git a/tensor/tensor.go b/tensor/tensor.go index 53dcb9a..71abe61 100644 --- a/tensor/tensor.go +++ b/tensor/tensor.go @@ -155,6 +155,10 @@ func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 { // OfSlice creates tensor from a slice data func OfSlice(data interface{}) (*Tensor, error) { + // convert []int -> int32. `binary.Write()` can't write `[]int` because it's not fixed-size! + if reflect.TypeOf(data).String() == "[]int" { + data = sliceIntToInt32(data.([]int)) + } typ, dataLen, err := DataCheck(data) if err != nil { diff --git a/tensor/tensor_test.go b/tensor/tensor_test.go index b5c4024..b8b58b0 100644 --- a/tensor/tensor_test.go +++ b/tensor/tensor_test.go @@ -118,3 +118,16 @@ func TestOnehot(t *testing.T) { * vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0] * ); * assert_eq!(onehot.size(), vec![4, 4]) */ + +func TestOfSlice(t *testing.T) { + data := []int{1, 2, 3, 4, 5} + x := ts.MustOfSlice(data) + + want := gotch.Int + got := x.DType() + + if !reflect.DeepEqual(want, got) { + t.Errorf("Expected dtype: %v\n", want) + t.Errorf("Got dtype: %v\n", got) + } +} diff --git a/tensor/util.go b/tensor/util.go index e5814da..da7d601 100644 --- a/tensor/util.go +++ b/tensor/util.go @@ -7,6 +7,7 @@ import ( "bytes" "encoding/binary" "fmt" + // "log" "reflect" "unsafe" @@ -481,3 +482,11 @@ func Must(ts Tensor, err error) (retVal Tensor) { } return ts } + +func sliceIntToInt32(input []int) []int32 { + out := make([]int32, len(input)) + for i := 0; i < len(input); i++ { + out[i] = int32(input[i]) + } + return out +}