feat(wrapper/tensor): CopyData function

This commit is contained in:
sugarme 2020-06-08 14:33:19 +10:00
parent fe6c76a2b8
commit 67fd311ed2
4 changed files with 20 additions and 14 deletions

View File

@ -198,7 +198,7 @@ func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
return goType, total, nil return goType, total, nil
case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
total++ total++
if goType.String() != "invalid" { if goType.String() != "invalid" {
goType = v.Type() goType = v.Type()

View File

@ -12,14 +12,14 @@ func main() {
// TODO: Check Go type of data and tensor DType // TODO: Check Go type of data and tensor DType
// For. if data is []int and DType is Bool // For. if data is []int and DType is Bool
// It is still running but get wrong result. // It is still running but get wrong result.
// data := [][]int16{ data := [][]int64{
// {1, 1, 1, 2, 2, 2, 3, 3}, {1, 1, 1, 2, 2, 2, 3, 3},
// {1, 1, 1, 2, 2, 2, 4, 4}, {1, 1, 1, 2, 2, 2, 4, 4},
// } }
// shape := []int64{2, 8} shape := []int64{2, 8}
data := []int16{1, 1, 1, 2, 2, 2, 3, 3} // data := []int16{1, 1, 1, 2, 2, 2, 3, 3}
shape := []int64{1, 8} // shape := []int64{1, 8}
ts, err := wrapper.NewTensorFromData(data, shape) ts, err := wrapper.NewTensorFromData(data, shape)
if err != nil { if err != nil {
@ -28,9 +28,9 @@ func main() {
ts.Print() ts.Print()
numel := uint(11) numel := uint(6)
// dst := make([]uint8, numel) // dst := make([]uint8, numel)
var dst = make([]uint8, 1) var dst = make([]int64, 6)
ts.MustCopyData(dst, numel) ts.MustCopyData(dst, numel)

View File

@ -495,7 +495,12 @@ func (ts Tensor) MustCopyDataUint8(dst []uint8, numel uint) {
// NOTE: `dst` located in Go memory. Should it be? // NOTE: `dst` located in Go memory. Should it be?
func (ts Tensor) CopyData(dst interface{}, numel uint) (err error) { func (ts Tensor) CopyData(dst interface{}, numel uint) (err error) {
dtype, dlen, err := DataCheck(dst) gotype, dlen, err := DataCheck(dst)
if err != nil {
return err
}
dtype, err := gotch.ToDType(gotype)
if err != nil { if err != nil {
return err return err
} }
@ -533,7 +538,7 @@ func (ts Tensor) CopyData(dst interface{}, numel uint) (err error) {
return err return err
} }
elt_size_in_bytes, err := gotch.DTypeSize(dtype.(gotch.DType)) elt_size_in_bytes, err := gotch.DTypeSize(dtype)
if err != nil { if err != nil {
return err return err
} }

View File

@ -162,6 +162,7 @@ func ElementCount(shape []int64) int64 {
func DataDim(data interface{}) (retVal int, err error) { func DataDim(data interface{}) (retVal int, err error) {
_, count, err := dataCheck(reflect.ValueOf(data).Interface(), 0) _, count, err := dataCheck(reflect.ValueOf(data).Interface(), 0)
return count, err return count, err
} }
@ -196,7 +197,7 @@ func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
return goType, total, nil return goType, total, nil
case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
total++ total++
if goType.String() != "invalid" { if goType.String() != "invalid" {
goType = v.Type() goType = v.Type()
@ -357,7 +358,7 @@ func flattenData(data interface{}, round int, flat []interface{}) (f []interface
return flatData, nil return flatData, nil
case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
flatData = append(flatData, data) flatData = append(flatData, data)
} }