feat(dtype), WIP(wrapper), example/tensor
This commit is contained in:
parent
773f423fff
commit
6b0d6105ae
250
dtype.go
Normal file
250
dtype.go
Normal file
|
@ -0,0 +1,250 @@
|
|||
package gotch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
// "log"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// CInt is equal to C type int. Go type is int32
|
||||
type CInt = int32
|
||||
|
||||
// DType represents different kind of element that a tensor can hold.
|
||||
// It has an embedded `reflect.Type` for type reflection.
|
||||
type DType struct {
|
||||
reflect.Type
|
||||
}
|
||||
|
||||
/*
|
||||
* // Custom-made Float16 as not exist in Go
|
||||
* // Ref: https://github.com/golang/go/issues/32022
|
||||
* type GoFloat16 = int16 // not implemented yet
|
||||
* type GoComplexHalf = interface{} // not implemented yet!
|
||||
* */
|
||||
|
||||
// TODO: double check these Torch DType to Go type
|
||||
var (
|
||||
Uint8 DType = DType{reflect.TypeOf(uint8(1))} // 0
|
||||
Int8 DType = DType{reflect.TypeOf(int8(1))} // 1
|
||||
Int16 DType = DType{reflect.TypeOf(int16(1))} // 2
|
||||
Int DType = DType{reflect.TypeOf(int32(1))} // 3
|
||||
Int64 DType = DType{reflect.TypeOf(int64(1))} // 4
|
||||
// Half DType = DType{reflect.TypeOf(GoFloat16(1))} // 5
|
||||
Float DType = DType{reflect.TypeOf(float32(1))} // 6
|
||||
Double DType = DType{reflect.TypeOf(float64(1))} // 7
|
||||
// ComplexHalf DType = DType{reflect.TypeOf(GoComplexHalf(1))} // 8
|
||||
// ComplexFloat DType = DType{reflect.TypeOf(complex64(1))} // 9
|
||||
// ComplexDouble DType = DType{reflect.TypeOf(complex128(1))} // 10
|
||||
Bool DType = DType{reflect.TypeOf(true)} // 11
|
||||
)
|
||||
|
||||
/*
|
||||
* // ToCInt converts DType to CInt type value which is `C int`
|
||||
* func (dt DType) ToCInt() CInt {
|
||||
* switch dt.Kind() {
|
||||
* case reflect.Uint8:
|
||||
* return 0
|
||||
* case reflect.Int8:
|
||||
* return 1
|
||||
* case reflect.Int16:
|
||||
* return 2
|
||||
* case reflect.Int32:
|
||||
* return 3
|
||||
* case reflect.Int64:
|
||||
* return 4
|
||||
* case reflect.Float32:
|
||||
* return 6
|
||||
* case reflect.Float64:
|
||||
* return 7
|
||||
* case reflect.Bool:
|
||||
* return 11
|
||||
* default:
|
||||
* log.Fatalf("Unsupported type.")
|
||||
* }
|
||||
*
|
||||
* // unreachable
|
||||
* return CInt(-1)
|
||||
* }
|
||||
*
|
||||
* // OfCInt converts a value of type CInt to DType type value
|
||||
* func (dt DType) OfCInt(v CInt) DType {
|
||||
* switch v {
|
||||
* case 0:
|
||||
* return Uint8
|
||||
* case 1:
|
||||
* return Int8
|
||||
* case 2:
|
||||
* return Int16
|
||||
* case 3:
|
||||
* return Int
|
||||
* case 4:
|
||||
* return Int64
|
||||
* case 6:
|
||||
* return Float
|
||||
* case 7:
|
||||
* return Double
|
||||
* case 8:
|
||||
* case 11:
|
||||
* return Bool
|
||||
* default:
|
||||
* log.Fatalf("Unexpected DType %v\n", v)
|
||||
* }
|
||||
* return DType{reflect.TypeOf(false)}
|
||||
* }
|
||||
*
|
||||
* // EltSizeInBytes converts a DType value to number of bytes
|
||||
* // This is a ELement Size In Bytes in Libtorch.
|
||||
* // Has it been deprecated?
|
||||
* func (dt DType) EltSizeInBytes() uint {
|
||||
* switch dt.Kind() {
|
||||
* case reflect.Uint8:
|
||||
* return 1
|
||||
* case reflect.Int8:
|
||||
* return 1
|
||||
* case reflect.Int16:
|
||||
* return 2
|
||||
* case reflect.Int:
|
||||
* return 4
|
||||
* case reflect.Int64:
|
||||
* return 8
|
||||
* case reflect.Float32:
|
||||
* return 4
|
||||
* case reflect.Float64:
|
||||
* return 8
|
||||
* case reflect.Bool:
|
||||
* return 1
|
||||
* default:
|
||||
* log.Fatalf("Unsupported Type %v\n", dt.Type)
|
||||
* }
|
||||
* return uint(0)
|
||||
* }
|
||||
* */
|
||||
|
||||
// ToGoType converts DType to Go type
|
||||
func (dt DType) ToGoType() reflect.Type {
|
||||
return dt.Type
|
||||
}
|
||||
|
||||
var dtypeCInt = map[DType]CInt{
|
||||
Uint8: 0,
|
||||
Int8: 1,
|
||||
Int16: 2,
|
||||
Int: 3,
|
||||
Int64: 4,
|
||||
Float: 6,
|
||||
Double: 7,
|
||||
Bool: 11,
|
||||
}
|
||||
|
||||
func DType2CInt(dt DType) CInt {
|
||||
return dtypeCInt[dt]
|
||||
}
|
||||
|
||||
func CInt2DType(v CInt) (dtype DType, err error) {
|
||||
var found = false
|
||||
for key, val := range dtypeCInt {
|
||||
if val == v {
|
||||
dtype = key
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
err = fmt.Errorf("Unsuported DType for CInt %v\n", v)
|
||||
return DType{}, err
|
||||
}
|
||||
|
||||
return dtype, nil
|
||||
|
||||
}
|
||||
|
||||
// dtypeSize is a map of DType and its size in Bytes
|
||||
var dtypeSize = map[DType]uint{
|
||||
Uint8: 1,
|
||||
Int8: 1,
|
||||
Int16: 2,
|
||||
Int: 4,
|
||||
Int64: 8,
|
||||
Float: 4,
|
||||
Double: 8,
|
||||
Bool: 1,
|
||||
}
|
||||
|
||||
// DTypeSize returns DType size in Bytes
|
||||
func DTypeSize(dt DType) uint {
|
||||
return dtypeSize[dt]
|
||||
}
|
||||
|
||||
type DTypeDevice struct {
|
||||
DType DType
|
||||
Device Device
|
||||
}
|
||||
|
||||
var (
|
||||
FloatCPU DTypeDevice = DTypeDevice{Float, CPU}
|
||||
DoubleCPU DTypeDevice = DTypeDevice{Double, CPU}
|
||||
Int64CPU DTypeDevice = DTypeDevice{Int64, CPU}
|
||||
|
||||
FloatCUDA DTypeDevice = DTypeDevice{Float, CudaBuilder(0)}
|
||||
DoubleCUDA DTypeDevice = DTypeDevice{Double, CudaBuilder(0)}
|
||||
Int64CUDA DTypeDevice = DTypeDevice{Int64, CudaBuilder(0)}
|
||||
)
|
||||
|
||||
// Type Inferring:
|
||||
// ===============
|
||||
|
||||
// DataDType infers and returns data type of tensor data
|
||||
func DataDType(v interface{}, shape []int64) (retVal DType, err error) {
|
||||
// assuming that all elements in data have the same type
|
||||
switch {
|
||||
case len(shape) == 0:
|
||||
retVal, err = ElementDType(v)
|
||||
case len(shape) > 0:
|
||||
return ElementDType(v.([]interface{})[0])
|
||||
default:
|
||||
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
|
||||
return DType{}, err
|
||||
}
|
||||
return DType{}, nil
|
||||
}
|
||||
|
||||
// ElementDType infers and returns its own tensor data type
|
||||
func ElementDType(v interface{}) (retVal DType, err error) {
|
||||
switch v.(type) {
|
||||
case uint8:
|
||||
retVal = Uint8
|
||||
case int8:
|
||||
retVal = Int8
|
||||
case int16:
|
||||
retVal = Int16
|
||||
case int32:
|
||||
retVal = Int
|
||||
case int64:
|
||||
retVal = Int64
|
||||
case float32:
|
||||
retVal = Float
|
||||
case float64:
|
||||
retVal = Double
|
||||
case bool:
|
||||
retVal = Bool
|
||||
default:
|
||||
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// TypeOf infers and returns element Go type from given tensor DType and shape
|
||||
func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) {
|
||||
typ := dt.ToGoType()
|
||||
|
||||
switch {
|
||||
case len(shape) == 0:
|
||||
return typ, nil
|
||||
case len(shape) > 0:
|
||||
return reflect.SliceOf(typ), nil
|
||||
default:
|
||||
err = fmt.Errorf("Unsupported data type.")
|
||||
return nil, err
|
||||
}
|
||||
}
|
|
@ -1,15 +1,25 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
tensor "github.com/sugarme/gotch/tensor"
|
||||
gotch "github.com/sugarme/gotch"
|
||||
wrapper "github.com/sugarme/gotch/wrapper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
_, err := tensor.FnOfSlice()
|
||||
|
||||
// TODO: Check Go type of data and tensor DType
|
||||
// For. if data is []int and DType is Bool
|
||||
// It is still running but get wrong result.
|
||||
data := []bool{true, true, false}
|
||||
dtype := gotch.Bool
|
||||
|
||||
ts := wrapper.NewTensor()
|
||||
sliceTensor, err := ts.FOfSlice(data, dtype)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
sliceTensor.Print()
|
||||
}
|
||||
|
|
135
kind.go
135
kind.go
|
@ -1,135 +0,0 @@
|
|||
package gotch
|
||||
|
||||
import (
|
||||
"log"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// CInt is equal to C type int. Go type is int32
|
||||
type CInt = int32
|
||||
|
||||
// Kind represents different kind of element that a tensor can hold.
|
||||
// It has an embedded `reflect.Type` for type reflection.
|
||||
type Kind struct {
|
||||
reflect.Type
|
||||
}
|
||||
|
||||
// TODO: double check these Torch DType to Go type
|
||||
var (
|
||||
Uint8 = Kind{reflect.TypeOf(uint8(1))} // 0
|
||||
Int8 = Kind{reflect.TypeOf(int8(1))} // 1
|
||||
Int16 = Kind{reflect.TypeOf(int16(1))} // 2
|
||||
Int = Kind{reflect.TypeOf(int(1))} // 3
|
||||
Int64 = Kind{reflect.TypeOf(int64(1))} // 4
|
||||
Half = Kind{reflect.TypeOf(float32(1))} // 5
|
||||
Float = Kind{reflect.TypeOf(float64(1))} // 6
|
||||
Double = Kind{reflect.TypeOf(float64(1))} // 7
|
||||
ComplexHalf = kind{reflect.TypeOf(complex(1))} // 8
|
||||
ComplexFloat = Kind{reflect.TypeOf(complex64(1))} // 9
|
||||
ComplexDouble = kind{reflect.TypeOf(complex128(1))} // 10
|
||||
Bool = kind{reflect.TypeOf(true)} // 11
|
||||
)
|
||||
|
||||
// ToCInt converts Kind to CInt type value which is `C int`
|
||||
func (k Kind) ToCInt() CInt {
|
||||
switch {
|
||||
case k.Kind() == uint8:
|
||||
return 0
|
||||
case k.Kind() == int8:
|
||||
return 1
|
||||
case k.Kind() == int16:
|
||||
return 2
|
||||
case k.Kind() == int:
|
||||
return 3
|
||||
case k.Kind() == int64:
|
||||
return 4
|
||||
case k.Kind() == float32:
|
||||
return 5
|
||||
default:
|
||||
log.Fatalf("Unsupported type.")
|
||||
}
|
||||
|
||||
// unreachable
|
||||
return CInt(-1)
|
||||
}
|
||||
|
||||
// OfCInt converts a value of type CInt to Kind type value
|
||||
func (k Kind) OfCInt(v CInt) Kind {
|
||||
switch v {
|
||||
case 0:
|
||||
return Uint8
|
||||
case 1:
|
||||
return Int8
|
||||
case 2:
|
||||
return Int16
|
||||
case 3:
|
||||
return Int
|
||||
case 4:
|
||||
return Int64
|
||||
case 5:
|
||||
return Half
|
||||
case 6:
|
||||
return Float
|
||||
case 7:
|
||||
return Double
|
||||
case 8:
|
||||
return ComplexHalf
|
||||
case 9:
|
||||
return ComplexFloat
|
||||
case 10:
|
||||
return ComplexDouble
|
||||
case 11:
|
||||
return Bool
|
||||
default:
|
||||
log.Fatalf("Unexpected kind %v\n", v)
|
||||
}
|
||||
return Kind{reflect.TypeOf(false)}
|
||||
}
|
||||
|
||||
// EltSizeInBytes converts a Kind value to number of bytes
|
||||
// This is a ELement Size In Byte in Libtorch.
|
||||
// Has it been deprecated?
|
||||
func (k Kind) EltSizeInBytes() uint {
|
||||
switch {
|
||||
case k.ToCInt() == int32(Uint8):
|
||||
return 1
|
||||
case k.ToCInt() == int32(Int8):
|
||||
return 1
|
||||
case k.ToCInt() == int32(Int16):
|
||||
return 2
|
||||
case k.ToCInt() == int32(Int):
|
||||
return 4
|
||||
case k.ToCInt() == int32(Int64):
|
||||
return 8
|
||||
case k.ToCInt() == int32(Half):
|
||||
return 2
|
||||
case k.ToCInt() == int32(Float):
|
||||
return 4
|
||||
case k.ToCInt() == int32(Double):
|
||||
return 8
|
||||
case k.ToCInt() == int32(ComplexHalf):
|
||||
return 4
|
||||
case k.ToCInt() == int32(ComplexDouble):
|
||||
return 16
|
||||
case k.ToCInt() == int32(Bool):
|
||||
return 1
|
||||
default:
|
||||
log.Fatalf("Unreachable")
|
||||
}
|
||||
return uint(0)
|
||||
}
|
||||
|
||||
type KindDevice struct {
|
||||
Kind Kind
|
||||
Device Device
|
||||
}
|
||||
|
||||
var (
|
||||
FloatCPU KindDevice = KindDevice{Float, CPU}
|
||||
DoubleCPU KindDevice = KindDevice{Double, CPU}
|
||||
Int64CPU KindDevice = KindDevice{Int64, CPU}
|
||||
|
||||
FloatCUDA KindDevice = KindDevice{Float, CudaBuilder(0)}
|
||||
DoubleCUDA KindDevice = KindDevice{Double, CudaBuilder(0)}
|
||||
Int64CUDA KindDevice = KindDevice{Int64, CudaBuilder(0)}
|
||||
)
|
|
@ -5,8 +5,8 @@ package libtch
|
|||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
// "fmt"
|
||||
// "reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
|
@ -33,27 +33,27 @@ func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_byt
|
|||
|
||||
// t is of type `unsafe.Pointer` in Go and `*void` in C
|
||||
t := C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind)
|
||||
fmt.Printf("t type: %v\n", reflect.TypeOf(t).Kind())
|
||||
fmt.Printf("1. C.tensor AtTensorOfData returned from C call: %v\n", t)
|
||||
// fmt.Printf("t type: %v\n", reflect.TypeOf(t).Kind())
|
||||
// fmt.Printf("1. C.tensor AtTensorOfData returned from C call: %v\n", t)
|
||||
// Keep C pointer value tin Go struct
|
||||
cTensorPtrVal := unsafe.Pointer(t)
|
||||
fmt.Printf("2. cTensorPtrVal: %v\n", cTensorPtrVal)
|
||||
// fmt.Printf("2. cTensorPtrVal: %v\n", cTensorPtrVal)
|
||||
|
||||
var retVal *C_tensor
|
||||
retVal = &C_tensor{private: cTensorPtrVal}
|
||||
fmt.Printf("3. C_tensor.private: %v\n", (*retVal).private)
|
||||
// fmt.Printf("3. C_tensor.private: %v\n", (*retVal).private)
|
||||
|
||||
// test call C.at_print to print out tensor
|
||||
// C.at_print(*(*C.tensor)(unsafe.Pointer(&t)))
|
||||
AtPrint(retVal)
|
||||
// AtPrint(retVal)
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func AtPrint(t *C_tensor) {
|
||||
fmt.Printf("4. C_tensor.private AtPrint: %v\n", (*t).private)
|
||||
// fmt.Printf("4. C_tensor.private AtPrint: %v\n", (*t).private)
|
||||
cTensor := (C.tensor)((*t).private)
|
||||
fmt.Printf("5. C.tensor AtPrint: %v\n", cTensor)
|
||||
// fmt.Printf("5. C.tensor AtPrint: %v\n", cTensor)
|
||||
|
||||
C.at_print(cTensor)
|
||||
}
|
||||
|
|
|
@ -1,20 +1,18 @@
|
|||
package wrapper
|
||||
|
||||
//#include <stdlib.h>
|
||||
// #include <stdlib.h>
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
// "fmt"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
|
||||
gotch "github.com/sugarme/gotch"
|
||||
lib "github.com/sugarme/gotch/libtch"
|
||||
)
|
||||
|
||||
type Tensor struct {
|
||||
ctensor *t.C_tensor
|
||||
ctensor *lib.C_tensor
|
||||
}
|
||||
|
||||
// NewTensor creates a new tensor
|
||||
|
@ -24,45 +22,43 @@ func NewTensor() Tensor {
|
|||
}
|
||||
|
||||
// FOfSlice creates tensor from a slice data
|
||||
func(ts Tensor) FOfSlice(data []inteface{}) (retVal Tensor, err error) {
|
||||
func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor, err error) {
|
||||
|
||||
data := []int{0, 0, 0, 0}
|
||||
shape := []int64{int64(len(data))}
|
||||
nflattened := numElements(shape)
|
||||
dtype := 3 // Kind.Int
|
||||
eltSizeInBytes := 4 // Element Size in Byte for Int dtype
|
||||
dataLen := reflect.ValueOf(data).Len()
|
||||
shape := []int64{int64(dataLen)}
|
||||
elementNum := ElementCount(shape)
|
||||
// eltSizeInBytes := dtype.EltSizeInBytes() // Element Size in Byte for Int dtype
|
||||
eltSizeInBytes := gotch.DTypeSize(dtype)
|
||||
|
||||
nbytes := eltSizeInBytes * int(uintptr(nflattened))
|
||||
nbytes := int(eltSizeInBytes) * int(elementNum)
|
||||
|
||||
// NOTE: dataPrt is type of `*void` in C or type of `unsafe.Pointer` in Go
|
||||
// data should be allocated to memory BY `C` side
|
||||
dataPtr := C.malloc(C.size_t(nbytes))
|
||||
dataPtr, buff := CMalloc(nbytes)
|
||||
|
||||
// Recall: 1 << 30 = 1 * 2 * 30
|
||||
// Ref. See more at https://stackoverflow.com/questions/48756732
|
||||
dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes]
|
||||
|
||||
buf := bytes.NewBuffer(dataSlice[:0:nbytes])
|
||||
|
||||
EncodeTensor(buf, reflect.ValueOf(data), shape)
|
||||
|
||||
c_tensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(dtype))
|
||||
|
||||
retVal = Tensor{c_tensor}
|
||||
|
||||
// Read back created tensor values by C libtorch
|
||||
readDataPtr := lib.AtDataPtr(retVal.c_tensor)
|
||||
readDataSlice := (*[1 << 30]byte)(readDataPtr)[:nbytes:nbytes]
|
||||
// typ := typeOf(dtype, shape)
|
||||
typ := reflect.TypeOf(int32(0)) // C. type `int` ~ Go type `int32`
|
||||
val := reflect.New(typ)
|
||||
if err := DecodeTensor(bytes.NewReader(readDataSlice), shape, typ, val); err != nil {
|
||||
panic(fmt.Sprintf("unable to decode Tensor of type %v and shape %v - %v", dtype, shape, err))
|
||||
if err = EncodeTensor(buff, reflect.ValueOf(data), shape); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tensorData := reflect.Indirect(val).Interface()
|
||||
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(gotch.DType2CInt(dtype)))
|
||||
|
||||
fmt.Println("%v", tensorData)
|
||||
retVal = &Tensor{ctensor}
|
||||
|
||||
// Read back created tensor values by C libtorch
|
||||
// readDataPtr := lib.AtDataPtr(retVal.ctensor)
|
||||
// readDataSlice := (*[1 << 30]byte)(readDataPtr)[:nbytes:nbytes]
|
||||
// // typ := typeOf(dtype, shape)
|
||||
// typ := reflect.TypeOf(int32(0)) // C. type `int` ~ Go type `int32`
|
||||
// val := reflect.New(typ)
|
||||
// if err := DecodeTensor(bytes.NewReader(readDataSlice), shape, typ, val); err != nil {
|
||||
// panic(fmt.Sprintf("unable to decode Tensor of type %v and shape %v - %v", dtype, shape, err))
|
||||
// }
|
||||
//
|
||||
// tensorData := reflect.Indirect(val).Interface()
|
||||
//
|
||||
// fmt.Println("%v", tensorData)
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) Print() {
|
||||
lib.AtPrint(ts.ctensor)
|
||||
}
|
||||
|
|
|
@ -1,16 +1,20 @@
|
|||
package wrapper
|
||||
|
||||
// #include <stdlib.h>
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
|
||||
gotch "github.com/sugarme/gotch"
|
||||
// gotch "github.com/sugarme/gotch"
|
||||
)
|
||||
|
||||
// nativeEndian is a ByteOrder for local platform.
|
||||
// Ref. https://stackoverflow.com/a/53286786
|
||||
// Ref. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/go/tensor.go#L488-L505
|
||||
var nativeEndian binary.ByteOrder
|
||||
|
||||
func init() {
|
||||
|
@ -27,6 +31,36 @@ func init() {
|
|||
}
|
||||
}
|
||||
|
||||
// CMalloc allocates a given number of bytes to C side memory.
|
||||
// It returns
|
||||
// - dataPtr: a C pointer type of `*void` (`unsafe.Pointer` in Go).
|
||||
// - buf : a Go pointer points to a given bytes of buffer (empty) in C memory
|
||||
// allocated by C waiting for writing data to.
|
||||
//
|
||||
// NOTE:
|
||||
// 1. Go pointer is a pointer to Go memory. C pointer is a pointer to C memory.
|
||||
// 2. General rule is Go code can use C pointers. Go code may pass Go pointer to C
|
||||
// provided that the Go memory to which it points does NOT contain any Go
|
||||
// pointers. BUT C code must not store any Go pointers in Go memory, even
|
||||
// temporarily.
|
||||
// 3. Some Go values contain Go pointers IMPLICITLY: strings, slices, maps,
|
||||
// channels and function values. Thus, pointers to these values should not be
|
||||
// passed to C side. Instead, data should be allocated to C memory and return a
|
||||
// C pointer to it using `C.malloc`.
|
||||
// Ref: https://github.com/golang/proposal/blob/master/design/12416-cgo-pointers.md
|
||||
func CMalloc(nbytes int) (dataPtr unsafe.Pointer, buf *bytes.Buffer) {
|
||||
|
||||
dataPtr = C.malloc(C.size_t(nbytes))
|
||||
|
||||
// Recall: 1 << 30 = 1 * 2 * 30
|
||||
// Ref. See more at https://stackoverflow.com/questions/48756732
|
||||
dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes]
|
||||
buf = bytes.NewBuffer(dataSlice[:0:nbytes])
|
||||
|
||||
return dataPtr, buf
|
||||
}
|
||||
|
||||
// EncodeTensor loads tensor data to C memory and returns a C pointer.
|
||||
func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
||||
switch v.Kind() {
|
||||
case reflect.Bool:
|
||||
|
@ -37,7 +71,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
|||
if err := w.WriteByte(b); err != nil {
|
||||
return err
|
||||
}
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -55,7 +89,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
|||
// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
|
||||
if len(shape) == 1 && v.Len() > 0 {
|
||||
switch v.Index(0).Kind() {
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||
return binary.Write(w, nativeEndian, v.Interface())
|
||||
}
|
||||
}
|
||||
|
@ -74,8 +108,8 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// DecodeTensor decodes the Tensor from the buffer to ptr using the format
|
||||
// specified in c_api.h. Use stringDecoder for String tensors.
|
||||
// DecodeTensor decodes tensor value from a C memory buffer given
|
||||
// C pointer, data type and shape and returns data value of type interface
|
||||
func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
|
||||
switch typ.Kind() {
|
||||
case reflect.Bool:
|
||||
|
@ -84,7 +118,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
|
|||
return err
|
||||
}
|
||||
ptr.Elem().SetBool(b == 1)
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||
if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -96,7 +130,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
|
|||
// Optimization: if only one dimension is left we can use binary.Read() directly for this slice
|
||||
if len(shape) == 1 && val.Len() > 0 {
|
||||
switch val.Index(0).Kind() {
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||
return binary.Read(r, nativeEndian, val.Interface())
|
||||
}
|
||||
}
|
||||
|
@ -113,47 +147,11 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
|
|||
return nil
|
||||
}
|
||||
|
||||
func numElements(shape []int64) int64 {
|
||||
// ElementCount counts number of element in the tensor given a shape
|
||||
func ElementCount(shape []int64) int64 {
|
||||
n := int64(1)
|
||||
for _, d := range shape {
|
||||
n *= d
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// GetKind returns data type `Kind` (a element of tensor can hold)
|
||||
// v - a value of a data element
|
||||
func GetKind(v interface{}) (retVal gotch.Kind, err error) {
|
||||
|
||||
switch {
|
||||
case reflect.TypeOf(v) == int:
|
||||
retVal = gotch.Int
|
||||
case reflect.TypeOf(v) == uint8:
|
||||
retVal = gotch.Uint8
|
||||
|
||||
default:
|
||||
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// // TypeOf converts from a DType and Shape to the equivalent Go type.
|
||||
// func TypeOf(dt DType, shape []int64) reflect.Type {
|
||||
// var ret reflect.Type
|
||||
// for _, t := range types {
|
||||
// if dt == DType(t.dataType) {
|
||||
// ret = t.typ
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
// if ret == nil {
|
||||
// // TODO get tensor name
|
||||
// panic(fmt.Sprintf("Unsupported DType %d", int(dt)))
|
||||
// }
|
||||
// for range shape {
|
||||
// ret = reflect.SliceOf(ret)
|
||||
// }
|
||||
// return ret
|
||||
// }
|
||||
|
|
Loading…
Reference in New Issue
Block a user