fixed ts.OfSlice() not supporting []int data type

This commit is contained in:
sugarme 2022-02-13 22:46:50 +11:00
parent b738c52d5b
commit 961080760f
4 changed files with 27 additions and 0 deletions

View File

@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- added new API `ConstantPadNdWithVal` `ato_constant_pad_nd` with padding value. - added new API `ConstantPadNdWithVal` `ato_constant_pad_nd` with padding value.
- fixed "nn/rnn NewLSTM() clashed weight names" - fixed "nn/rnn NewLSTM() clashed weight names"
- fixed some old API at `vision/aug/function.go` - fixed some old API at `vision/aug/function.go`
- fixed `tensor.OfSlice()` not supporting `[]int` data type
## [Nofix] ## [Nofix]
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box. - ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.

View File

@ -155,6 +155,10 @@ func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
// OfSlice creates tensor from a slice data // OfSlice creates tensor from a slice data
func OfSlice(data interface{}) (*Tensor, error) { func OfSlice(data interface{}) (*Tensor, error) {
// convert []int -> int32. `binary.Write()` can't write `[]int` because it's not fixed-size!
if reflect.TypeOf(data).String() == "[]int" {
data = sliceIntToInt32(data.([]int))
}
typ, dataLen, err := DataCheck(data) typ, dataLen, err := DataCheck(data)
if err != nil { if err != nil {

View File

@ -118,3 +118,16 @@ func TestOnehot(t *testing.T) {
* vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0] * vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]
* ); * );
* assert_eq!(onehot.size(), vec![4, 4]) */ * assert_eq!(onehot.size(), vec![4, 4]) */
func TestOfSlice(t *testing.T) {
data := []int{1, 2, 3, 4, 5}
x := ts.MustOfSlice(data)
want := gotch.Int
got := x.DType()
if !reflect.DeepEqual(want, got) {
t.Errorf("Expected dtype: %v\n", want)
t.Errorf("Got dtype: %v\n", got)
}
}

View File

@ -7,6 +7,7 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
// "log" // "log"
"reflect" "reflect"
"unsafe" "unsafe"
@ -481,3 +482,11 @@ func Must(ts Tensor, err error) (retVal Tensor) {
} }
return ts return ts
} }
func sliceIntToInt32(input []int) []int32 {
out := make([]int32, len(input))
for i := 0; i < len(input); i++ {
out[i] = int32(input[i])
}
return out
}