feat(dtype), WIP(wrapper), example/tensor
This commit is contained in:
parent
773f423fff
commit
6b0d6105ae
250
dtype.go
Normal file
250
dtype.go
Normal file
|
@ -0,0 +1,250 @@
|
||||||
|
package gotch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
// "log"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CInt is equal to C type int. Go type is int32
|
||||||
|
type CInt = int32
|
||||||
|
|
||||||
|
// DType represents different kind of element that a tensor can hold.
|
||||||
|
// It has an embedded `reflect.Type` for type reflection.
|
||||||
|
type DType struct {
|
||||||
|
reflect.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* // Custom-made Float16 as not exist in Go
|
||||||
|
* // Ref: https://github.com/golang/go/issues/32022
|
||||||
|
* type GoFloat16 = int16 // not implemented yet
|
||||||
|
* type GoComplexHalf = interface{} // not implemented yet!
|
||||||
|
* */
|
||||||
|
|
||||||
|
// TODO: double check these Torch DType to Go type
|
||||||
|
var (
|
||||||
|
Uint8 DType = DType{reflect.TypeOf(uint8(1))} // 0
|
||||||
|
Int8 DType = DType{reflect.TypeOf(int8(1))} // 1
|
||||||
|
Int16 DType = DType{reflect.TypeOf(int16(1))} // 2
|
||||||
|
Int DType = DType{reflect.TypeOf(int32(1))} // 3
|
||||||
|
Int64 DType = DType{reflect.TypeOf(int64(1))} // 4
|
||||||
|
// Half DType = DType{reflect.TypeOf(GoFloat16(1))} // 5
|
||||||
|
Float DType = DType{reflect.TypeOf(float32(1))} // 6
|
||||||
|
Double DType = DType{reflect.TypeOf(float64(1))} // 7
|
||||||
|
// ComplexHalf DType = DType{reflect.TypeOf(GoComplexHalf(1))} // 8
|
||||||
|
// ComplexFloat DType = DType{reflect.TypeOf(complex64(1))} // 9
|
||||||
|
// ComplexDouble DType = DType{reflect.TypeOf(complex128(1))} // 10
|
||||||
|
Bool DType = DType{reflect.TypeOf(true)} // 11
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
* // ToCInt converts DType to CInt type value which is `C int`
|
||||||
|
* func (dt DType) ToCInt() CInt {
|
||||||
|
* switch dt.Kind() {
|
||||||
|
* case reflect.Uint8:
|
||||||
|
* return 0
|
||||||
|
* case reflect.Int8:
|
||||||
|
* return 1
|
||||||
|
* case reflect.Int16:
|
||||||
|
* return 2
|
||||||
|
* case reflect.Int32:
|
||||||
|
* return 3
|
||||||
|
* case reflect.Int64:
|
||||||
|
* return 4
|
||||||
|
* case reflect.Float32:
|
||||||
|
* return 6
|
||||||
|
* case reflect.Float64:
|
||||||
|
* return 7
|
||||||
|
* case reflect.Bool:
|
||||||
|
* return 11
|
||||||
|
* default:
|
||||||
|
* log.Fatalf("Unsupported type.")
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* // unreachable
|
||||||
|
* return CInt(-1)
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* // OfCInt converts a value of type CInt to DType type value
|
||||||
|
* func (dt DType) OfCInt(v CInt) DType {
|
||||||
|
* switch v {
|
||||||
|
* case 0:
|
||||||
|
* return Uint8
|
||||||
|
* case 1:
|
||||||
|
* return Int8
|
||||||
|
* case 2:
|
||||||
|
* return Int16
|
||||||
|
* case 3:
|
||||||
|
* return Int
|
||||||
|
* case 4:
|
||||||
|
* return Int64
|
||||||
|
* case 6:
|
||||||
|
* return Float
|
||||||
|
* case 7:
|
||||||
|
* return Double
|
||||||
|
* case 8:
|
||||||
|
* case 11:
|
||||||
|
* return Bool
|
||||||
|
* default:
|
||||||
|
* log.Fatalf("Unexpected DType %v\n", v)
|
||||||
|
* }
|
||||||
|
* return DType{reflect.TypeOf(false)}
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* // EltSizeInBytes converts a DType value to number of bytes
|
||||||
|
* // This is a ELement Size In Bytes in Libtorch.
|
||||||
|
* // Has it been deprecated?
|
||||||
|
* func (dt DType) EltSizeInBytes() uint {
|
||||||
|
* switch dt.Kind() {
|
||||||
|
* case reflect.Uint8:
|
||||||
|
* return 1
|
||||||
|
* case reflect.Int8:
|
||||||
|
* return 1
|
||||||
|
* case reflect.Int16:
|
||||||
|
* return 2
|
||||||
|
* case reflect.Int:
|
||||||
|
* return 4
|
||||||
|
* case reflect.Int64:
|
||||||
|
* return 8
|
||||||
|
* case reflect.Float32:
|
||||||
|
* return 4
|
||||||
|
* case reflect.Float64:
|
||||||
|
* return 8
|
||||||
|
* case reflect.Bool:
|
||||||
|
* return 1
|
||||||
|
* default:
|
||||||
|
* log.Fatalf("Unsupported Type %v\n", dt.Type)
|
||||||
|
* }
|
||||||
|
* return uint(0)
|
||||||
|
* }
|
||||||
|
* */
|
||||||
|
|
||||||
|
// ToGoType converts DType to Go type
|
||||||
|
func (dt DType) ToGoType() reflect.Type {
|
||||||
|
return dt.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
var dtypeCInt = map[DType]CInt{
|
||||||
|
Uint8: 0,
|
||||||
|
Int8: 1,
|
||||||
|
Int16: 2,
|
||||||
|
Int: 3,
|
||||||
|
Int64: 4,
|
||||||
|
Float: 6,
|
||||||
|
Double: 7,
|
||||||
|
Bool: 11,
|
||||||
|
}
|
||||||
|
|
||||||
|
func DType2CInt(dt DType) CInt {
|
||||||
|
return dtypeCInt[dt]
|
||||||
|
}
|
||||||
|
|
||||||
|
func CInt2DType(v CInt) (dtype DType, err error) {
|
||||||
|
var found = false
|
||||||
|
for key, val := range dtypeCInt {
|
||||||
|
if val == v {
|
||||||
|
dtype = key
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
err = fmt.Errorf("Unsuported DType for CInt %v\n", v)
|
||||||
|
return DType{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return dtype, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// dtypeSize is a map of DType and its size in Bytes
|
||||||
|
var dtypeSize = map[DType]uint{
|
||||||
|
Uint8: 1,
|
||||||
|
Int8: 1,
|
||||||
|
Int16: 2,
|
||||||
|
Int: 4,
|
||||||
|
Int64: 8,
|
||||||
|
Float: 4,
|
||||||
|
Double: 8,
|
||||||
|
Bool: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
// DTypeSize returns DType size in Bytes
|
||||||
|
func DTypeSize(dt DType) uint {
|
||||||
|
return dtypeSize[dt]
|
||||||
|
}
|
||||||
|
|
||||||
|
type DTypeDevice struct {
|
||||||
|
DType DType
|
||||||
|
Device Device
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
FloatCPU DTypeDevice = DTypeDevice{Float, CPU}
|
||||||
|
DoubleCPU DTypeDevice = DTypeDevice{Double, CPU}
|
||||||
|
Int64CPU DTypeDevice = DTypeDevice{Int64, CPU}
|
||||||
|
|
||||||
|
FloatCUDA DTypeDevice = DTypeDevice{Float, CudaBuilder(0)}
|
||||||
|
DoubleCUDA DTypeDevice = DTypeDevice{Double, CudaBuilder(0)}
|
||||||
|
Int64CUDA DTypeDevice = DTypeDevice{Int64, CudaBuilder(0)}
|
||||||
|
)
|
||||||
|
|
||||||
|
// Type Inferring:
|
||||||
|
// ===============
|
||||||
|
|
||||||
|
// DataDType infers and returns data type of tensor data
|
||||||
|
func DataDType(v interface{}, shape []int64) (retVal DType, err error) {
|
||||||
|
// assuming that all elements in data have the same type
|
||||||
|
switch {
|
||||||
|
case len(shape) == 0:
|
||||||
|
retVal, err = ElementDType(v)
|
||||||
|
case len(shape) > 0:
|
||||||
|
return ElementDType(v.([]interface{})[0])
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
|
||||||
|
return DType{}, err
|
||||||
|
}
|
||||||
|
return DType{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ElementDType infers and returns its own tensor data type
|
||||||
|
func ElementDType(v interface{}) (retVal DType, err error) {
|
||||||
|
switch v.(type) {
|
||||||
|
case uint8:
|
||||||
|
retVal = Uint8
|
||||||
|
case int8:
|
||||||
|
retVal = Int8
|
||||||
|
case int16:
|
||||||
|
retVal = Int16
|
||||||
|
case int32:
|
||||||
|
retVal = Int
|
||||||
|
case int64:
|
||||||
|
retVal = Int64
|
||||||
|
case float32:
|
||||||
|
retVal = Float
|
||||||
|
case float64:
|
||||||
|
retVal = Double
|
||||||
|
case bool:
|
||||||
|
retVal = Bool
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeOf infers and returns element Go type from given tensor DType and shape
|
||||||
|
func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) {
|
||||||
|
typ := dt.ToGoType()
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(shape) == 0:
|
||||||
|
return typ, nil
|
||||||
|
case len(shape) > 0:
|
||||||
|
return reflect.SliceOf(typ), nil
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("Unsupported data type.")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,15 +1,25 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"log"
|
||||||
|
|
||||||
tensor "github.com/sugarme/gotch/tensor"
|
gotch "github.com/sugarme/gotch"
|
||||||
|
wrapper "github.com/sugarme/gotch/wrapper"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
_, err := tensor.FnOfSlice()
|
|
||||||
|
// TODO: Check Go type of data and tensor DType
|
||||||
|
// For. if data is []int and DType is Bool
|
||||||
|
// It is still running but get wrong result.
|
||||||
|
data := []bool{true, true, false}
|
||||||
|
dtype := gotch.Bool
|
||||||
|
|
||||||
|
ts := wrapper.NewTensor()
|
||||||
|
sliceTensor, err := ts.FOfSlice(data, dtype)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sliceTensor.Print()
|
||||||
}
|
}
|
||||||
|
|
135
kind.go
135
kind.go
|
@ -1,135 +0,0 @@
|
||||||
package gotch
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CInt is equal to C type int. Go type is int32
|
|
||||||
type CInt = int32
|
|
||||||
|
|
||||||
// Kind represents different kind of element that a tensor can hold.
|
|
||||||
// It has an embedded `reflect.Type` for type reflection.
|
|
||||||
type Kind struct {
|
|
||||||
reflect.Type
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: double check these Torch DType to Go type
|
|
||||||
var (
|
|
||||||
Uint8 = Kind{reflect.TypeOf(uint8(1))} // 0
|
|
||||||
Int8 = Kind{reflect.TypeOf(int8(1))} // 1
|
|
||||||
Int16 = Kind{reflect.TypeOf(int16(1))} // 2
|
|
||||||
Int = Kind{reflect.TypeOf(int(1))} // 3
|
|
||||||
Int64 = Kind{reflect.TypeOf(int64(1))} // 4
|
|
||||||
Half = Kind{reflect.TypeOf(float32(1))} // 5
|
|
||||||
Float = Kind{reflect.TypeOf(float64(1))} // 6
|
|
||||||
Double = Kind{reflect.TypeOf(float64(1))} // 7
|
|
||||||
ComplexHalf = kind{reflect.TypeOf(complex(1))} // 8
|
|
||||||
ComplexFloat = Kind{reflect.TypeOf(complex64(1))} // 9
|
|
||||||
ComplexDouble = kind{reflect.TypeOf(complex128(1))} // 10
|
|
||||||
Bool = kind{reflect.TypeOf(true)} // 11
|
|
||||||
)
|
|
||||||
|
|
||||||
// ToCInt converts Kind to CInt type value which is `C int`
|
|
||||||
func (k Kind) ToCInt() CInt {
|
|
||||||
switch {
|
|
||||||
case k.Kind() == uint8:
|
|
||||||
return 0
|
|
||||||
case k.Kind() == int8:
|
|
||||||
return 1
|
|
||||||
case k.Kind() == int16:
|
|
||||||
return 2
|
|
||||||
case k.Kind() == int:
|
|
||||||
return 3
|
|
||||||
case k.Kind() == int64:
|
|
||||||
return 4
|
|
||||||
case k.Kind() == float32:
|
|
||||||
return 5
|
|
||||||
default:
|
|
||||||
log.Fatalf("Unsupported type.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// unreachable
|
|
||||||
return CInt(-1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OfCInt converts a value of type CInt to Kind type value
|
|
||||||
func (k Kind) OfCInt(v CInt) Kind {
|
|
||||||
switch v {
|
|
||||||
case 0:
|
|
||||||
return Uint8
|
|
||||||
case 1:
|
|
||||||
return Int8
|
|
||||||
case 2:
|
|
||||||
return Int16
|
|
||||||
case 3:
|
|
||||||
return Int
|
|
||||||
case 4:
|
|
||||||
return Int64
|
|
||||||
case 5:
|
|
||||||
return Half
|
|
||||||
case 6:
|
|
||||||
return Float
|
|
||||||
case 7:
|
|
||||||
return Double
|
|
||||||
case 8:
|
|
||||||
return ComplexHalf
|
|
||||||
case 9:
|
|
||||||
return ComplexFloat
|
|
||||||
case 10:
|
|
||||||
return ComplexDouble
|
|
||||||
case 11:
|
|
||||||
return Bool
|
|
||||||
default:
|
|
||||||
log.Fatalf("Unexpected kind %v\n", v)
|
|
||||||
}
|
|
||||||
return Kind{reflect.TypeOf(false)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// EltSizeInBytes converts a Kind value to number of bytes
|
|
||||||
// This is a ELement Size In Byte in Libtorch.
|
|
||||||
// Has it been deprecated?
|
|
||||||
func (k Kind) EltSizeInBytes() uint {
|
|
||||||
switch {
|
|
||||||
case k.ToCInt() == int32(Uint8):
|
|
||||||
return 1
|
|
||||||
case k.ToCInt() == int32(Int8):
|
|
||||||
return 1
|
|
||||||
case k.ToCInt() == int32(Int16):
|
|
||||||
return 2
|
|
||||||
case k.ToCInt() == int32(Int):
|
|
||||||
return 4
|
|
||||||
case k.ToCInt() == int32(Int64):
|
|
||||||
return 8
|
|
||||||
case k.ToCInt() == int32(Half):
|
|
||||||
return 2
|
|
||||||
case k.ToCInt() == int32(Float):
|
|
||||||
return 4
|
|
||||||
case k.ToCInt() == int32(Double):
|
|
||||||
return 8
|
|
||||||
case k.ToCInt() == int32(ComplexHalf):
|
|
||||||
return 4
|
|
||||||
case k.ToCInt() == int32(ComplexDouble):
|
|
||||||
return 16
|
|
||||||
case k.ToCInt() == int32(Bool):
|
|
||||||
return 1
|
|
||||||
default:
|
|
||||||
log.Fatalf("Unreachable")
|
|
||||||
}
|
|
||||||
return uint(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type KindDevice struct {
|
|
||||||
Kind Kind
|
|
||||||
Device Device
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
FloatCPU KindDevice = KindDevice{Float, CPU}
|
|
||||||
DoubleCPU KindDevice = KindDevice{Double, CPU}
|
|
||||||
Int64CPU KindDevice = KindDevice{Int64, CPU}
|
|
||||||
|
|
||||||
FloatCUDA KindDevice = KindDevice{Float, CudaBuilder(0)}
|
|
||||||
DoubleCUDA KindDevice = KindDevice{Double, CudaBuilder(0)}
|
|
||||||
Int64CUDA KindDevice = KindDevice{Int64, CudaBuilder(0)}
|
|
||||||
)
|
|
|
@ -5,8 +5,8 @@ package libtch
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
// "fmt"
|
||||||
"reflect"
|
// "reflect"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -33,27 +33,27 @@ func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_byt
|
||||||
|
|
||||||
// t is of type `unsafe.Pointer` in Go and `*void` in C
|
// t is of type `unsafe.Pointer` in Go and `*void` in C
|
||||||
t := C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind)
|
t := C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind)
|
||||||
fmt.Printf("t type: %v\n", reflect.TypeOf(t).Kind())
|
// fmt.Printf("t type: %v\n", reflect.TypeOf(t).Kind())
|
||||||
fmt.Printf("1. C.tensor AtTensorOfData returned from C call: %v\n", t)
|
// fmt.Printf("1. C.tensor AtTensorOfData returned from C call: %v\n", t)
|
||||||
// Keep C pointer value tin Go struct
|
// Keep C pointer value tin Go struct
|
||||||
cTensorPtrVal := unsafe.Pointer(t)
|
cTensorPtrVal := unsafe.Pointer(t)
|
||||||
fmt.Printf("2. cTensorPtrVal: %v\n", cTensorPtrVal)
|
// fmt.Printf("2. cTensorPtrVal: %v\n", cTensorPtrVal)
|
||||||
|
|
||||||
var retVal *C_tensor
|
var retVal *C_tensor
|
||||||
retVal = &C_tensor{private: cTensorPtrVal}
|
retVal = &C_tensor{private: cTensorPtrVal}
|
||||||
fmt.Printf("3. C_tensor.private: %v\n", (*retVal).private)
|
// fmt.Printf("3. C_tensor.private: %v\n", (*retVal).private)
|
||||||
|
|
||||||
// test call C.at_print to print out tensor
|
// test call C.at_print to print out tensor
|
||||||
// C.at_print(*(*C.tensor)(unsafe.Pointer(&t)))
|
// C.at_print(*(*C.tensor)(unsafe.Pointer(&t)))
|
||||||
AtPrint(retVal)
|
// AtPrint(retVal)
|
||||||
|
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
func AtPrint(t *C_tensor) {
|
func AtPrint(t *C_tensor) {
|
||||||
fmt.Printf("4. C_tensor.private AtPrint: %v\n", (*t).private)
|
// fmt.Printf("4. C_tensor.private AtPrint: %v\n", (*t).private)
|
||||||
cTensor := (C.tensor)((*t).private)
|
cTensor := (C.tensor)((*t).private)
|
||||||
fmt.Printf("5. C.tensor AtPrint: %v\n", cTensor)
|
// fmt.Printf("5. C.tensor AtPrint: %v\n", cTensor)
|
||||||
|
|
||||||
C.at_print(cTensor)
|
C.at_print(cTensor)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,20 +1,18 @@
|
||||||
package wrapper
|
package wrapper
|
||||||
|
|
||||||
//#include <stdlib.h>
|
// #include <stdlib.h>
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
// "fmt"
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
|
gotch "github.com/sugarme/gotch"
|
||||||
lib "github.com/sugarme/gotch/libtch"
|
lib "github.com/sugarme/gotch/libtch"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Tensor struct {
|
type Tensor struct {
|
||||||
ctensor *t.C_tensor
|
ctensor *lib.C_tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTensor creates a new tensor
|
// NewTensor creates a new tensor
|
||||||
|
@ -24,45 +22,43 @@ func NewTensor() Tensor {
|
||||||
}
|
}
|
||||||
|
|
||||||
// FOfSlice creates tensor from a slice data
|
// FOfSlice creates tensor from a slice data
|
||||||
func(ts Tensor) FOfSlice(data []inteface{}) (retVal Tensor, err error) {
|
func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor, err error) {
|
||||||
|
|
||||||
data := []int{0, 0, 0, 0}
|
dataLen := reflect.ValueOf(data).Len()
|
||||||
shape := []int64{int64(len(data))}
|
shape := []int64{int64(dataLen)}
|
||||||
nflattened := numElements(shape)
|
elementNum := ElementCount(shape)
|
||||||
dtype := 3 // Kind.Int
|
// eltSizeInBytes := dtype.EltSizeInBytes() // Element Size in Byte for Int dtype
|
||||||
eltSizeInBytes := 4 // Element Size in Byte for Int dtype
|
eltSizeInBytes := gotch.DTypeSize(dtype)
|
||||||
|
|
||||||
nbytes := eltSizeInBytes * int(uintptr(nflattened))
|
nbytes := int(eltSizeInBytes) * int(elementNum)
|
||||||
|
|
||||||
// NOTE: dataPrt is type of `*void` in C or type of `unsafe.Pointer` in Go
|
dataPtr, buff := CMalloc(nbytes)
|
||||||
// data should be allocated to memory BY `C` side
|
|
||||||
dataPtr := C.malloc(C.size_t(nbytes))
|
|
||||||
|
|
||||||
// Recall: 1 << 30 = 1 * 2 * 30
|
if err = EncodeTensor(buff, reflect.ValueOf(data), shape); err != nil {
|
||||||
// Ref. See more at https://stackoverflow.com/questions/48756732
|
return nil, err
|
||||||
dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes]
|
|
||||||
|
|
||||||
buf := bytes.NewBuffer(dataSlice[:0:nbytes])
|
|
||||||
|
|
||||||
EncodeTensor(buf, reflect.ValueOf(data), shape)
|
|
||||||
|
|
||||||
c_tensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(dtype))
|
|
||||||
|
|
||||||
retVal = Tensor{c_tensor}
|
|
||||||
|
|
||||||
// Read back created tensor values by C libtorch
|
|
||||||
readDataPtr := lib.AtDataPtr(retVal.c_tensor)
|
|
||||||
readDataSlice := (*[1 << 30]byte)(readDataPtr)[:nbytes:nbytes]
|
|
||||||
// typ := typeOf(dtype, shape)
|
|
||||||
typ := reflect.TypeOf(int32(0)) // C. type `int` ~ Go type `int32`
|
|
||||||
val := reflect.New(typ)
|
|
||||||
if err := DecodeTensor(bytes.NewReader(readDataSlice), shape, typ, val); err != nil {
|
|
||||||
panic(fmt.Sprintf("unable to decode Tensor of type %v and shape %v - %v", dtype, shape, err))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorData := reflect.Indirect(val).Interface()
|
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(gotch.DType2CInt(dtype)))
|
||||||
|
|
||||||
fmt.Println("%v", tensorData)
|
retVal = &Tensor{ctensor}
|
||||||
|
|
||||||
|
// Read back created tensor values by C libtorch
|
||||||
|
// readDataPtr := lib.AtDataPtr(retVal.ctensor)
|
||||||
|
// readDataSlice := (*[1 << 30]byte)(readDataPtr)[:nbytes:nbytes]
|
||||||
|
// // typ := typeOf(dtype, shape)
|
||||||
|
// typ := reflect.TypeOf(int32(0)) // C. type `int` ~ Go type `int32`
|
||||||
|
// val := reflect.New(typ)
|
||||||
|
// if err := DecodeTensor(bytes.NewReader(readDataSlice), shape, typ, val); err != nil {
|
||||||
|
// panic(fmt.Sprintf("unable to decode Tensor of type %v and shape %v - %v", dtype, shape, err))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// tensorData := reflect.Indirect(val).Interface()
|
||||||
|
//
|
||||||
|
// fmt.Println("%v", tensorData)
|
||||||
|
|
||||||
return retVal, nil
|
return retVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Print() {
|
||||||
|
lib.AtPrint(ts.ctensor)
|
||||||
|
}
|
||||||
|
|
|
@ -1,16 +1,20 @@
|
||||||
package wrapper
|
package wrapper
|
||||||
|
|
||||||
|
// #include <stdlib.h>
|
||||||
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
// gotch "github.com/sugarme/gotch"
|
||||||
gotch "github.com/sugarme/gotch"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// nativeEndian is a ByteOrder for local platform.
|
||||||
|
// Ref. https://stackoverflow.com/a/53286786
|
||||||
|
// Ref. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/go/tensor.go#L488-L505
|
||||||
var nativeEndian binary.ByteOrder
|
var nativeEndian binary.ByteOrder
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -27,6 +31,36 @@ func init() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CMalloc allocates a given number of bytes to C side memory.
|
||||||
|
// It returns
|
||||||
|
// - dataPtr: a C pointer type of `*void` (`unsafe.Pointer` in Go).
|
||||||
|
// - buf : a Go pointer points to a given bytes of buffer (empty) in C memory
|
||||||
|
// allocated by C waiting for writing data to.
|
||||||
|
//
|
||||||
|
// NOTE:
|
||||||
|
// 1. Go pointer is a pointer to Go memory. C pointer is a pointer to C memory.
|
||||||
|
// 2. General rule is Go code can use C pointers. Go code may pass Go pointer to C
|
||||||
|
// provided that the Go memory to which it points does NOT contain any Go
|
||||||
|
// pointers. BUT C code must not store any Go pointers in Go memory, even
|
||||||
|
// temporarily.
|
||||||
|
// 3. Some Go values contain Go pointers IMPLICITLY: strings, slices, maps,
|
||||||
|
// channels and function values. Thus, pointers to these values should not be
|
||||||
|
// passed to C side. Instead, data should be allocated to C memory and return a
|
||||||
|
// C pointer to it using `C.malloc`.
|
||||||
|
// Ref: https://github.com/golang/proposal/blob/master/design/12416-cgo-pointers.md
|
||||||
|
func CMalloc(nbytes int) (dataPtr unsafe.Pointer, buf *bytes.Buffer) {
|
||||||
|
|
||||||
|
dataPtr = C.malloc(C.size_t(nbytes))
|
||||||
|
|
||||||
|
// Recall: 1 << 30 = 1 * 2 * 30
|
||||||
|
// Ref. See more at https://stackoverflow.com/questions/48756732
|
||||||
|
dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes]
|
||||||
|
buf = bytes.NewBuffer(dataSlice[:0:nbytes])
|
||||||
|
|
||||||
|
return dataPtr, buf
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeTensor loads tensor data to C memory and returns a C pointer.
|
||||||
func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
||||||
switch v.Kind() {
|
switch v.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
|
@ -37,7 +71,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
||||||
if err := w.WriteByte(b); err != nil {
|
if err := w.WriteByte(b); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||||
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
|
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -55,7 +89,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
||||||
// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
|
// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
|
||||||
if len(shape) == 1 && v.Len() > 0 {
|
if len(shape) == 1 && v.Len() > 0 {
|
||||||
switch v.Index(0).Kind() {
|
switch v.Index(0).Kind() {
|
||||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||||
return binary.Write(w, nativeEndian, v.Interface())
|
return binary.Write(w, nativeEndian, v.Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -74,8 +108,8 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecodeTensor decodes the Tensor from the buffer to ptr using the format
|
// DecodeTensor decodes tensor value from a C memory buffer given
|
||||||
// specified in c_api.h. Use stringDecoder for String tensors.
|
// C pointer, data type and shape and returns data value of type interface
|
||||||
func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
|
func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
|
||||||
switch typ.Kind() {
|
switch typ.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
|
@ -84,7 +118,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ptr.Elem().SetBool(b == 1)
|
ptr.Elem().SetBool(b == 1)
|
||||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||||
if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
|
if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -96,7 +130,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
|
||||||
// Optimization: if only one dimension is left we can use binary.Read() directly for this slice
|
// Optimization: if only one dimension is left we can use binary.Read() directly for this slice
|
||||||
if len(shape) == 1 && val.Len() > 0 {
|
if len(shape) == 1 && val.Len() > 0 {
|
||||||
switch val.Index(0).Kind() {
|
switch val.Index(0).Kind() {
|
||||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||||
return binary.Read(r, nativeEndian, val.Interface())
|
return binary.Read(r, nativeEndian, val.Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -113,47 +147,11 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func numElements(shape []int64) int64 {
|
// ElementCount counts number of element in the tensor given a shape
|
||||||
|
func ElementCount(shape []int64) int64 {
|
||||||
n := int64(1)
|
n := int64(1)
|
||||||
for _, d := range shape {
|
for _, d := range shape {
|
||||||
n *= d
|
n *= d
|
||||||
}
|
}
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetKind returns data type `Kind` (a element of tensor can hold)
|
|
||||||
// v - a value of a data element
|
|
||||||
func GetKind(v interface{}) (retVal gotch.Kind, err error) {
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case reflect.TypeOf(v) == int:
|
|
||||||
retVal = gotch.Int
|
|
||||||
case reflect.TypeOf(v) == uint8:
|
|
||||||
retVal = gotch.Uint8
|
|
||||||
|
|
||||||
default:
|
|
||||||
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return retVal, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// // TypeOf converts from a DType and Shape to the equivalent Go type.
|
|
||||||
// func TypeOf(dt DType, shape []int64) reflect.Type {
|
|
||||||
// var ret reflect.Type
|
|
||||||
// for _, t := range types {
|
|
||||||
// if dt == DType(t.dataType) {
|
|
||||||
// ret = t.typ
|
|
||||||
// break
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// if ret == nil {
|
|
||||||
// // TODO get tensor name
|
|
||||||
// panic(fmt.Sprintf("Unsupported DType %d", int(dt)))
|
|
||||||
// }
|
|
||||||
// for range shape {
|
|
||||||
// ret = reflect.SliceOf(ret)
|
|
||||||
// }
|
|
||||||
// return ret
|
|
||||||
// }
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user