feat(wrapper/tensor): CopyData function

This commit is contained in:
sugarme 2020-06-08 14:33:19 +10:00
parent fe6c76a2b8
commit 67fd311ed2
4 changed files with 20 additions and 14 deletions

View File

@ -198,7 +198,7 @@ func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
return goType, total, nil
case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
total++
if goType.String() != "invalid" {
goType = v.Type()

View File

@ -12,14 +12,14 @@ func main() {
// TODO: Check Go type of data and tensor DType
// For. if data is []int and DType is Bool
// It is still running but get wrong result.
// data := [][]int16{
// {1, 1, 1, 2, 2, 2, 3, 3},
// {1, 1, 1, 2, 2, 2, 4, 4},
// }
// shape := []int64{2, 8}
data := [][]int64{
{1, 1, 1, 2, 2, 2, 3, 3},
{1, 1, 1, 2, 2, 2, 4, 4},
}
shape := []int64{2, 8}
data := []int16{1, 1, 1, 2, 2, 2, 3, 3}
shape := []int64{1, 8}
// data := []int16{1, 1, 1, 2, 2, 2, 3, 3}
// shape := []int64{1, 8}
ts, err := wrapper.NewTensorFromData(data, shape)
if err != nil {
@ -28,9 +28,9 @@ func main() {
ts.Print()
numel := uint(11)
numel := uint(6)
// dst := make([]uint8, numel)
var dst = make([]uint8, 1)
var dst = make([]int64, 6)
ts.MustCopyData(dst, numel)

View File

@ -495,7 +495,12 @@ func (ts Tensor) MustCopyDataUint8(dst []uint8, numel uint) {
// NOTE: `dst` located in Go memory. Should it be?
func (ts Tensor) CopyData(dst interface{}, numel uint) (err error) {
dtype, dlen, err := DataCheck(dst)
gotype, dlen, err := DataCheck(dst)
if err != nil {
return err
}
dtype, err := gotch.ToDType(gotype)
if err != nil {
return err
}
@ -533,7 +538,7 @@ func (ts Tensor) CopyData(dst interface{}, numel uint) (err error) {
return err
}
elt_size_in_bytes, err := gotch.DTypeSize(dtype.(gotch.DType))
elt_size_in_bytes, err := gotch.DTypeSize(dtype)
if err != nil {
return err
}

View File

@ -162,6 +162,7 @@ func ElementCount(shape []int64) int64 {
func DataDim(data interface{}) (retVal int, err error) {
_, count, err := dataCheck(reflect.ValueOf(data).Interface(), 0)
return count, err
}
@ -196,7 +197,7 @@ func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
return goType, total, nil
case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
total++
if goType.String() != "invalid" {
goType = v.Type()
@ -357,7 +358,7 @@ func flattenData(data interface{}, round int, flat []interface{}) (f []interface
return flatData, nil
case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
flatData = append(flatData, data)
}