feat(wrapper/tensor): CopyData function
This commit is contained in:
parent
fe6c76a2b8
commit
67fd311ed2
2
dtype.go
2
dtype.go
|
@ -198,7 +198,7 @@ func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
|
||||||
|
|
||||||
return goType, total, nil
|
return goType, total, nil
|
||||||
|
|
||||||
case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||||
total++
|
total++
|
||||||
if goType.String() != "invalid" {
|
if goType.String() != "invalid" {
|
||||||
goType = v.Type()
|
goType = v.Type()
|
||||||
|
|
|
@ -12,14 +12,14 @@ func main() {
|
||||||
// TODO: Check Go type of data and tensor DType
|
// TODO: Check Go type of data and tensor DType
|
||||||
// For. if data is []int and DType is Bool
|
// For. if data is []int and DType is Bool
|
||||||
// It is still running but get wrong result.
|
// It is still running but get wrong result.
|
||||||
// data := [][]int16{
|
data := [][]int64{
|
||||||
// {1, 1, 1, 2, 2, 2, 3, 3},
|
{1, 1, 1, 2, 2, 2, 3, 3},
|
||||||
// {1, 1, 1, 2, 2, 2, 4, 4},
|
{1, 1, 1, 2, 2, 2, 4, 4},
|
||||||
// }
|
}
|
||||||
// shape := []int64{2, 8}
|
shape := []int64{2, 8}
|
||||||
|
|
||||||
data := []int16{1, 1, 1, 2, 2, 2, 3, 3}
|
// data := []int16{1, 1, 1, 2, 2, 2, 3, 3}
|
||||||
shape := []int64{1, 8}
|
// shape := []int64{1, 8}
|
||||||
|
|
||||||
ts, err := wrapper.NewTensorFromData(data, shape)
|
ts, err := wrapper.NewTensorFromData(data, shape)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -28,9 +28,9 @@ func main() {
|
||||||
|
|
||||||
ts.Print()
|
ts.Print()
|
||||||
|
|
||||||
numel := uint(11)
|
numel := uint(6)
|
||||||
// dst := make([]uint8, numel)
|
// dst := make([]uint8, numel)
|
||||||
var dst = make([]uint8, 1)
|
var dst = make([]int64, 6)
|
||||||
|
|
||||||
ts.MustCopyData(dst, numel)
|
ts.MustCopyData(dst, numel)
|
||||||
|
|
||||||
|
|
|
@ -495,7 +495,12 @@ func (ts Tensor) MustCopyDataUint8(dst []uint8, numel uint) {
|
||||||
// NOTE: `dst` located in Go memory. Should it be?
|
// NOTE: `dst` located in Go memory. Should it be?
|
||||||
func (ts Tensor) CopyData(dst interface{}, numel uint) (err error) {
|
func (ts Tensor) CopyData(dst interface{}, numel uint) (err error) {
|
||||||
|
|
||||||
dtype, dlen, err := DataCheck(dst)
|
gotype, dlen, err := DataCheck(dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dtype, err := gotch.ToDType(gotype)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -533,7 +538,7 @@ func (ts Tensor) CopyData(dst interface{}, numel uint) (err error) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
elt_size_in_bytes, err := gotch.DTypeSize(dtype.(gotch.DType))
|
elt_size_in_bytes, err := gotch.DTypeSize(dtype)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -162,6 +162,7 @@ func ElementCount(shape []int64) int64 {
|
||||||
func DataDim(data interface{}) (retVal int, err error) {
|
func DataDim(data interface{}) (retVal int, err error) {
|
||||||
|
|
||||||
_, count, err := dataCheck(reflect.ValueOf(data).Interface(), 0)
|
_, count, err := dataCheck(reflect.ValueOf(data).Interface(), 0)
|
||||||
|
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -196,7 +197,7 @@ func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
|
||||||
|
|
||||||
return goType, total, nil
|
return goType, total, nil
|
||||||
|
|
||||||
case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||||
total++
|
total++
|
||||||
if goType.String() != "invalid" {
|
if goType.String() != "invalid" {
|
||||||
goType = v.Type()
|
goType = v.Type()
|
||||||
|
@ -357,7 +358,7 @@ func flattenData(data interface{}, round int, flat []interface{}) (f []interface
|
||||||
|
|
||||||
return flatData, nil
|
return flatData, nil
|
||||||
|
|
||||||
case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||||
flatData = append(flatData, data)
|
flatData = append(flatData, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user