diff --git a/dtype.go b/dtype.go index ef124b4..0d8f928 100644 --- a/dtype.go +++ b/dtype.go @@ -198,7 +198,7 @@ func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) { return goType, total, nil - case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: + case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: total++ if goType.String() != "invalid" { goType = v.Type() diff --git a/example/tensor-copy-data/main.go b/example/tensor-copy-data/main.go index 5e1db5f..b3bca48 100644 --- a/example/tensor-copy-data/main.go +++ b/example/tensor-copy-data/main.go @@ -12,14 +12,14 @@ func main() { // TODO: Check Go type of data and tensor DType // For. if data is []int and DType is Bool // It is still running but get wrong result. - // data := [][]int16{ - // {1, 1, 1, 2, 2, 2, 3, 3}, - // {1, 1, 1, 2, 2, 2, 4, 4}, - // } - // shape := []int64{2, 8} + data := [][]int64{ + {1, 1, 1, 2, 2, 2, 3, 3}, + {1, 1, 1, 2, 2, 2, 4, 4}, + } + shape := []int64{2, 8} - data := []int16{1, 1, 1, 2, 2, 2, 3, 3} - shape := []int64{1, 8} + // data := []int16{1, 1, 1, 2, 2, 2, 3, 3} + // shape := []int64{1, 8} ts, err := wrapper.NewTensorFromData(data, shape) if err != nil { @@ -28,9 +28,9 @@ func main() { ts.Print() - numel := uint(11) + numel := uint(6) // dst := make([]uint8, numel) - var dst = make([]uint8, 1) + var dst = make([]int64, 6) ts.MustCopyData(dst, numel) diff --git a/wrapper/tensor.go b/wrapper/tensor.go index baf6bef..a5aec7a 100644 --- a/wrapper/tensor.go +++ b/wrapper/tensor.go @@ -495,7 +495,12 @@ func (ts Tensor) MustCopyDataUint8(dst []uint8, numel uint) { // NOTE: `dst` located in Go memory. Should it be? func (ts Tensor) CopyData(dst interface{}, numel uint) (err error) { - dtype, dlen, err := DataCheck(dst) + gotype, dlen, err := DataCheck(dst) + if err != nil { + return err + } + + dtype, err := gotch.ToDType(gotype) if err != nil { return err } @@ -533,7 +538,7 @@ func (ts Tensor) CopyData(dst interface{}, numel uint) (err error) { return err } - elt_size_in_bytes, err := gotch.DTypeSize(dtype.(gotch.DType)) + elt_size_in_bytes, err := gotch.DTypeSize(dtype) if err != nil { return err } diff --git a/wrapper/util.go b/wrapper/util.go index 3a166f2..44b963a 100644 --- a/wrapper/util.go +++ b/wrapper/util.go @@ -162,6 +162,7 @@ func ElementCount(shape []int64) int64 { func DataDim(data interface{}) (retVal int, err error) { _, count, err := dataCheck(reflect.ValueOf(data).Interface(), 0) + return count, err } @@ -196,7 +197,7 @@ func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) { return goType, total, nil - case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: + case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: total++ if goType.String() != "invalid" { goType = v.Type() @@ -357,7 +358,7 @@ func flattenData(data interface{}, round int, flat []interface{}) (f []interface return flatData, nil - case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: + case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: flatData = append(flatData, data) }