fixed ts.OfSlice() not supporting []int data type

This commit is contained in:
sugarme 2022-02-13 22:46:50 +11:00
parent b738c52d5b
commit 961080760f
4 changed files with 27 additions and 0 deletions

View File

@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- added new API `ConstantPadNdWithVal` `ato_constant_pad_nd` with padding value.
- fixed "nn/rnn NewLSTM() clashed weight names"
- fixed some old API at `vision/aug/function.go`
- fixed `tensor.OfSlice()` not supporting `[]int` data type
## [Nofix]
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.

View File

@ -155,6 +155,10 @@ func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
// OfSlice creates tensor from a slice data
func OfSlice(data interface{}) (*Tensor, error) {
// convert []int -> int32. `binary.Write()` can't write `[]int` because it's not fixed-size!
if reflect.TypeOf(data).String() == "[]int" {
data = sliceIntToInt32(data.([]int))
}
typ, dataLen, err := DataCheck(data)
if err != nil {

View File

@ -118,3 +118,16 @@ func TestOnehot(t *testing.T) {
* vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]
* );
* assert_eq!(onehot.size(), vec![4, 4]) */
func TestOfSlice(t *testing.T) {
data := []int{1, 2, 3, 4, 5}
x := ts.MustOfSlice(data)
want := gotch.Int
got := x.DType()
if !reflect.DeepEqual(want, got) {
t.Errorf("Expected dtype: %v\n", want)
t.Errorf("Got dtype: %v\n", got)
}
}

View File

@ -7,6 +7,7 @@ import (
"bytes"
"encoding/binary"
"fmt"
// "log"
"reflect"
"unsafe"
@ -481,3 +482,11 @@ func Must(ts Tensor, err error) (retVal Tensor) {
}
return ts
}
func sliceIntToInt32(input []int) []int32 {
out := make([]int32, len(input))
for i := 0; i < len(input); i++ {
out[i] = int32(input[i])
}
return out
}