fixed ts.OfSlice() not supporting []int data type
This commit is contained in:
parent
b738c52d5b
commit
961080760f
|
@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- added new API `ConstantPadNdWithVal` `ato_constant_pad_nd` with padding value.
|
||||
- fixed "nn/rnn NewLSTM() clashed weight names"
|
||||
- fixed some old API at `vision/aug/function.go`
|
||||
- fixed `tensor.OfSlice()` not supporting `[]int` data type
|
||||
|
||||
## [Nofix]
|
||||
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
|
||||
|
|
|
@ -155,6 +155,10 @@ func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
|
|||
|
||||
// OfSlice creates tensor from a slice data
|
||||
func OfSlice(data interface{}) (*Tensor, error) {
|
||||
// convert []int -> int32. `binary.Write()` can't write `[]int` because it's not fixed-size!
|
||||
if reflect.TypeOf(data).String() == "[]int" {
|
||||
data = sliceIntToInt32(data.([]int))
|
||||
}
|
||||
|
||||
typ, dataLen, err := DataCheck(data)
|
||||
if err != nil {
|
||||
|
|
|
@ -118,3 +118,16 @@ func TestOnehot(t *testing.T) {
|
|||
* vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]
|
||||
* );
|
||||
* assert_eq!(onehot.size(), vec![4, 4]) */
|
||||
|
||||
func TestOfSlice(t *testing.T) {
|
||||
data := []int{1, 2, 3, 4, 5}
|
||||
x := ts.MustOfSlice(data)
|
||||
|
||||
want := gotch.Int
|
||||
got := x.DType()
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("Expected dtype: %v\n", want)
|
||||
t.Errorf("Got dtype: %v\n", got)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
// "log"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
|
@ -481,3 +482,11 @@ func Must(ts Tensor, err error) (retVal Tensor) {
|
|||
}
|
||||
return ts
|
||||
}
|
||||
|
||||
func sliceIntToInt32(input []int) []int32 {
|
||||
out := make([]int32, len(input))
|
||||
for i := 0; i < len(input); i++ {
|
||||
out[i] = int32(input[i])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user