feat(wrapper/tensor): CopyData function
This commit is contained in:
parent
fe6c76a2b8
commit
67fd311ed2
2
dtype.go
2
dtype.go
|
@ -198,7 +198,7 @@ func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
|
|||
|
||||
return goType, total, nil
|
||||
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||
total++
|
||||
if goType.String() != "invalid" {
|
||||
goType = v.Type()
|
||||
|
|
|
@ -12,14 +12,14 @@ func main() {
|
|||
// TODO: Check Go type of data and tensor DType
|
||||
// For. if data is []int and DType is Bool
|
||||
// It is still running but get wrong result.
|
||||
// data := [][]int16{
|
||||
// {1, 1, 1, 2, 2, 2, 3, 3},
|
||||
// {1, 1, 1, 2, 2, 2, 4, 4},
|
||||
// }
|
||||
// shape := []int64{2, 8}
|
||||
data := [][]int64{
|
||||
{1, 1, 1, 2, 2, 2, 3, 3},
|
||||
{1, 1, 1, 2, 2, 2, 4, 4},
|
||||
}
|
||||
shape := []int64{2, 8}
|
||||
|
||||
data := []int16{1, 1, 1, 2, 2, 2, 3, 3}
|
||||
shape := []int64{1, 8}
|
||||
// data := []int16{1, 1, 1, 2, 2, 2, 3, 3}
|
||||
// shape := []int64{1, 8}
|
||||
|
||||
ts, err := wrapper.NewTensorFromData(data, shape)
|
||||
if err != nil {
|
||||
|
@ -28,9 +28,9 @@ func main() {
|
|||
|
||||
ts.Print()
|
||||
|
||||
numel := uint(11)
|
||||
numel := uint(6)
|
||||
// dst := make([]uint8, numel)
|
||||
var dst = make([]uint8, 1)
|
||||
var dst = make([]int64, 6)
|
||||
|
||||
ts.MustCopyData(dst, numel)
|
||||
|
||||
|
|
|
@ -495,7 +495,12 @@ func (ts Tensor) MustCopyDataUint8(dst []uint8, numel uint) {
|
|||
// NOTE: `dst` located in Go memory. Should it be?
|
||||
func (ts Tensor) CopyData(dst interface{}, numel uint) (err error) {
|
||||
|
||||
dtype, dlen, err := DataCheck(dst)
|
||||
gotype, dlen, err := DataCheck(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dtype, err := gotch.ToDType(gotype)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -533,7 +538,7 @@ func (ts Tensor) CopyData(dst interface{}, numel uint) (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
elt_size_in_bytes, err := gotch.DTypeSize(dtype.(gotch.DType))
|
||||
elt_size_in_bytes, err := gotch.DTypeSize(dtype)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -162,6 +162,7 @@ func ElementCount(shape []int64) int64 {
|
|||
func DataDim(data interface{}) (retVal int, err error) {
|
||||
|
||||
_, count, err := dataCheck(reflect.ValueOf(data).Interface(), 0)
|
||||
|
||||
return count, err
|
||||
}
|
||||
|
||||
|
@ -196,7 +197,7 @@ func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
|
|||
|
||||
return goType, total, nil
|
||||
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||
total++
|
||||
if goType.String() != "invalid" {
|
||||
goType = v.Type()
|
||||
|
@ -357,7 +358,7 @@ func flattenData(data interface{}, round int, flat []interface{}) (f []interface
|
|||
|
||||
return flatData, nil
|
||||
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||
flatData = append(flatData, data)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user