WIP: restructure and WIP: kind
This commit is contained in:
parent
bbf8bface1
commit
773f423fff
135
kind.go
Normal file
135
kind.go
Normal file
|
@ -0,0 +1,135 @@
|
|||
package gotch
|
||||
|
||||
import (
|
||||
"log"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// CInt is equal to C type int. Go type is int32
|
||||
type CInt = int32
|
||||
|
||||
// Kind represents different kind of element that a tensor can hold.
|
||||
// It has an embedded `reflect.Type` for type reflection.
|
||||
type Kind struct {
|
||||
reflect.Type
|
||||
}
|
||||
|
||||
// TODO: double check these Torch DType to Go type
|
||||
var (
|
||||
Uint8 = Kind{reflect.TypeOf(uint8(1))} // 0
|
||||
Int8 = Kind{reflect.TypeOf(int8(1))} // 1
|
||||
Int16 = Kind{reflect.TypeOf(int16(1))} // 2
|
||||
Int = Kind{reflect.TypeOf(int(1))} // 3
|
||||
Int64 = Kind{reflect.TypeOf(int64(1))} // 4
|
||||
Half = Kind{reflect.TypeOf(float32(1))} // 5
|
||||
Float = Kind{reflect.TypeOf(float64(1))} // 6
|
||||
Double = Kind{reflect.TypeOf(float64(1))} // 7
|
||||
ComplexHalf = kind{reflect.TypeOf(complex(1))} // 8
|
||||
ComplexFloat = Kind{reflect.TypeOf(complex64(1))} // 9
|
||||
ComplexDouble = kind{reflect.TypeOf(complex128(1))} // 10
|
||||
Bool = kind{reflect.TypeOf(true)} // 11
|
||||
)
|
||||
|
||||
// ToCInt converts Kind to CInt type value which is `C int`
|
||||
func (k Kind) ToCInt() CInt {
|
||||
switch {
|
||||
case k.Kind() == uint8:
|
||||
return 0
|
||||
case k.Kind() == int8:
|
||||
return 1
|
||||
case k.Kind() == int16:
|
||||
return 2
|
||||
case k.Kind() == int:
|
||||
return 3
|
||||
case k.Kind() == int64:
|
||||
return 4
|
||||
case k.Kind() == float32:
|
||||
return 5
|
||||
default:
|
||||
log.Fatalf("Unsupported type.")
|
||||
}
|
||||
|
||||
// unreachable
|
||||
return CInt(-1)
|
||||
}
|
||||
|
||||
// OfCInt converts a value of type CInt to Kind type value
|
||||
func (k Kind) OfCInt(v CInt) Kind {
|
||||
switch v {
|
||||
case 0:
|
||||
return Uint8
|
||||
case 1:
|
||||
return Int8
|
||||
case 2:
|
||||
return Int16
|
||||
case 3:
|
||||
return Int
|
||||
case 4:
|
||||
return Int64
|
||||
case 5:
|
||||
return Half
|
||||
case 6:
|
||||
return Float
|
||||
case 7:
|
||||
return Double
|
||||
case 8:
|
||||
return ComplexHalf
|
||||
case 9:
|
||||
return ComplexFloat
|
||||
case 10:
|
||||
return ComplexDouble
|
||||
case 11:
|
||||
return Bool
|
||||
default:
|
||||
log.Fatalf("Unexpected kind %v\n", v)
|
||||
}
|
||||
return Kind{reflect.TypeOf(false)}
|
||||
}
|
||||
|
||||
// EltSizeInBytes converts a Kind value to number of bytes
|
||||
// This is a ELement Size In Byte in Libtorch.
|
||||
// Has it been deprecated?
|
||||
func (k Kind) EltSizeInBytes() uint {
|
||||
switch {
|
||||
case k.ToCInt() == int32(Uint8):
|
||||
return 1
|
||||
case k.ToCInt() == int32(Int8):
|
||||
return 1
|
||||
case k.ToCInt() == int32(Int16):
|
||||
return 2
|
||||
case k.ToCInt() == int32(Int):
|
||||
return 4
|
||||
case k.ToCInt() == int32(Int64):
|
||||
return 8
|
||||
case k.ToCInt() == int32(Half):
|
||||
return 2
|
||||
case k.ToCInt() == int32(Float):
|
||||
return 4
|
||||
case k.ToCInt() == int32(Double):
|
||||
return 8
|
||||
case k.ToCInt() == int32(ComplexHalf):
|
||||
return 4
|
||||
case k.ToCInt() == int32(ComplexDouble):
|
||||
return 16
|
||||
case k.ToCInt() == int32(Bool):
|
||||
return 1
|
||||
default:
|
||||
log.Fatalf("Unreachable")
|
||||
}
|
||||
return uint(0)
|
||||
}
|
||||
|
||||
type KindDevice struct {
|
||||
Kind Kind
|
||||
Device Device
|
||||
}
|
||||
|
||||
var (
|
||||
FloatCPU KindDevice = KindDevice{Float, CPU}
|
||||
DoubleCPU KindDevice = KindDevice{Double, CPU}
|
||||
Int64CPU KindDevice = KindDevice{Int64, CPU}
|
||||
|
||||
FloatCUDA KindDevice = KindDevice{Float, CudaBuilder(0)}
|
||||
DoubleCUDA KindDevice = KindDevice{Double, CudaBuilder(0)}
|
||||
Int64CUDA KindDevice = KindDevice{Int64, CudaBuilder(0)}
|
||||
)
|
|
@ -18,7 +18,7 @@ type C_tensor struct {
|
|||
private unsafe.Pointer
|
||||
}
|
||||
|
||||
func NewTensor() *C_tensor {
|
||||
func AtNewTensor() *C_tensor {
|
||||
t := C.at_new_tensor()
|
||||
return &C_tensor{private: unsafe.Pointer(t)}
|
||||
}
|
||||
|
@ -31,11 +31,6 @@ func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_byt
|
|||
c_elt_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&elt_size_in_bytes))
|
||||
c_kind := *(*C.int)(unsafe.Pointer(&kind))
|
||||
|
||||
// c_dims := (*C.long)(unsafe.Pointer(uintptr(dims)))
|
||||
// c_ndims := *(*C.size_t)(unsafe.Pointer(uintptr(ndims)))
|
||||
// c_elt_size_in_bytes := *(*C.size_t)(unsafe.Pointer(uintptr(elt_size_in_bytes)))
|
||||
// c_kind := *(*C.int)(unsafe.Pointer(uintptr(kind)))
|
||||
|
||||
// t is of type `unsafe.Pointer` in Go and `*void` in C
|
||||
t := C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind)
|
||||
fmt.Printf("t type: %v\n", reflect.TypeOf(t).Kind())
|
||||
|
|
147
tensor/kind.go
147
tensor/kind.go
|
@ -1,147 +0,0 @@
|
|||
package tensor
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
gotch "github.com/sugarme/gotch"
|
||||
)
|
||||
|
||||
// CInt is equal to C type int. Go type is int32
|
||||
type CInt = int32
|
||||
|
||||
// Kind is 'enum' like type. It represents different kind of elements
|
||||
// that a Tensor can hold.
|
||||
type Kind int
|
||||
|
||||
const (
|
||||
Uint8 Kind = iota // 0
|
||||
Int8 // 1
|
||||
Int16 // 2
|
||||
Int // 3
|
||||
Int64 // 4
|
||||
Half // 5
|
||||
Float // 6
|
||||
Double // 7
|
||||
ComplexHalf // 8
|
||||
ComplexFloat // 9
|
||||
ComplexDouble // 10
|
||||
Bool // 11
|
||||
)
|
||||
|
||||
// ToCInt converts Kind to CInt type value which is `C int`
|
||||
func (k Kind) ToCInt() CInt {
|
||||
return CInt(k)
|
||||
}
|
||||
|
||||
// OfCInt converts a value of type CInt to Kind type value
|
||||
func (k Kind) OfCInt(v CInt) Kind {
|
||||
switch v {
|
||||
case 0:
|
||||
return Uint8
|
||||
case 1:
|
||||
return Int8
|
||||
case 2:
|
||||
return Int16
|
||||
case 3:
|
||||
return Int
|
||||
case 4:
|
||||
return Int64
|
||||
case 5:
|
||||
return Half
|
||||
case 6:
|
||||
return Float
|
||||
case 7:
|
||||
return Double
|
||||
case 8:
|
||||
return ComplexHalf
|
||||
case 9:
|
||||
return ComplexFloat
|
||||
case 10:
|
||||
return ComplexDouble
|
||||
case 11:
|
||||
return Bool
|
||||
default:
|
||||
log.Fatalf("Unexpected kind %v\n", v)
|
||||
}
|
||||
return Kind(0)
|
||||
}
|
||||
|
||||
// EltSizeInBytes converts a Kind value to number of bytes
|
||||
// This is a ELement Size In Byte in Libtorch.
|
||||
// Has it been deprecated?
|
||||
func (k Kind) EltSizeInBytes() uint {
|
||||
switch {
|
||||
case k.ToCInt() == int32(Uint8):
|
||||
return 1
|
||||
case k.ToCInt() == int32(Int8):
|
||||
return 1
|
||||
case k.ToCInt() == int32(Int16):
|
||||
return 2
|
||||
case k.ToCInt() == int32(Int):
|
||||
return 4
|
||||
case k.ToCInt() == int32(Int64):
|
||||
return 8
|
||||
case k.ToCInt() == int32(Half):
|
||||
return 2
|
||||
case k.ToCInt() == int32(Float):
|
||||
return 4
|
||||
case k.ToCInt() == int32(Double):
|
||||
return 8
|
||||
case k.ToCInt() == int32(ComplexHalf):
|
||||
return 4
|
||||
case k.ToCInt() == int32(ComplexDouble):
|
||||
return 16
|
||||
case k.ToCInt() == int32(Bool):
|
||||
return 1
|
||||
default:
|
||||
log.Fatalf("Unreachable")
|
||||
}
|
||||
return uint(0)
|
||||
}
|
||||
|
||||
type KindDevice struct {
|
||||
Kind Kind
|
||||
Device gotch.Device
|
||||
}
|
||||
|
||||
var (
|
||||
FloatCPU KindDevice = KindDevice{Float, gotch.CPU}
|
||||
DoubleCPU KindDevice = KindDevice{Double, gotch.CPU}
|
||||
Int64CPU KindDevice = KindDevice{Int64, gotch.CPU}
|
||||
|
||||
FloatCUDA KindDevice = KindDevice{Float, gotch.CudaBuilder(0)}
|
||||
DoubleCUDA KindDevice = KindDevice{Double, gotch.CudaBuilder(0)}
|
||||
Int64CUDA KindDevice = KindDevice{Int64, gotch.CudaBuilder(0)}
|
||||
)
|
||||
|
||||
type KindTrait interface {
|
||||
GetKind() Kind
|
||||
}
|
||||
|
||||
type KindUint8 struct{}
|
||||
|
||||
func (k KindUint8) GetKind() Kind { return Uint8 }
|
||||
|
||||
type KindInt8 struct{}
|
||||
|
||||
func (k KindInt8) GetKind() Kind { return Int8 }
|
||||
|
||||
type KindInt16 struct{}
|
||||
|
||||
func (k KindInt16) GetKind() Kind { return Int16 }
|
||||
|
||||
type KindInt64 struct{}
|
||||
|
||||
func (k KindInt64) GetKind() Kind { return Int64 }
|
||||
|
||||
type KindFloat32 struct{}
|
||||
|
||||
func (k KindFloat32) GetKind() Kind { return Float }
|
||||
|
||||
type KindFloat64 struct{}
|
||||
|
||||
func (k KindFloat64) GetKind() Kind { return Double }
|
||||
|
||||
type KindBool struct{}
|
||||
|
||||
func (k KindBool) GetKind() Kind { return Bool }
|
68
wrapper/tensor.go
Normal file
68
wrapper/tensor.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package wrapper
|
||||
|
||||
//#include <stdlib.h>
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
|
||||
lib "github.com/sugarme/gotch/libtch"
|
||||
)
|
||||
|
||||
type Tensor struct {
|
||||
ctensor *t.C_tensor
|
||||
}
|
||||
|
||||
// NewTensor creates a new tensor
|
||||
func NewTensor() Tensor {
|
||||
ctensor := lib.AtNewTensor()
|
||||
return Tensor{ctensor}
|
||||
}
|
||||
|
||||
// FOfSlice creates tensor from a slice data
|
||||
func(ts Tensor) FOfSlice(data []inteface{}) (retVal Tensor, err error) {
|
||||
|
||||
data := []int{0, 0, 0, 0}
|
||||
shape := []int64{int64(len(data))}
|
||||
nflattened := numElements(shape)
|
||||
dtype := 3 // Kind.Int
|
||||
eltSizeInBytes := 4 // Element Size in Byte for Int dtype
|
||||
|
||||
nbytes := eltSizeInBytes * int(uintptr(nflattened))
|
||||
|
||||
// NOTE: dataPrt is type of `*void` in C or type of `unsafe.Pointer` in Go
|
||||
// data should be allocated to memory BY `C` side
|
||||
dataPtr := C.malloc(C.size_t(nbytes))
|
||||
|
||||
// Recall: 1 << 30 = 1 * 2 * 30
|
||||
// Ref. See more at https://stackoverflow.com/questions/48756732
|
||||
dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes]
|
||||
|
||||
buf := bytes.NewBuffer(dataSlice[:0:nbytes])
|
||||
|
||||
EncodeTensor(buf, reflect.ValueOf(data), shape)
|
||||
|
||||
c_tensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(dtype))
|
||||
|
||||
retVal = Tensor{c_tensor}
|
||||
|
||||
// Read back created tensor values by C libtorch
|
||||
readDataPtr := lib.AtDataPtr(retVal.c_tensor)
|
||||
readDataSlice := (*[1 << 30]byte)(readDataPtr)[:nbytes:nbytes]
|
||||
// typ := typeOf(dtype, shape)
|
||||
typ := reflect.TypeOf(int32(0)) // C. type `int` ~ Go type `int32`
|
||||
val := reflect.New(typ)
|
||||
if err := DecodeTensor(bytes.NewReader(readDataSlice), shape, typ, val); err != nil {
|
||||
panic(fmt.Sprintf("unable to decode Tensor of type %v and shape %v - %v", dtype, shape, err))
|
||||
}
|
||||
|
||||
tensorData := reflect.Indirect(val).Interface()
|
||||
|
||||
fmt.Println("%v", tensorData)
|
||||
|
||||
return retVal, nil
|
||||
}
|
|
@ -1,23 +1,16 @@
|
|||
package tensor
|
||||
|
||||
//#include <stdlib.h>
|
||||
import "C"
|
||||
package wrapper
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
// "runtime"
|
||||
"unsafe"
|
||||
|
||||
lib "github.com/sugarme/gotch/libtch"
|
||||
gotch "github.com/sugarme/gotch"
|
||||
)
|
||||
|
||||
type Tensor struct {
|
||||
ctensor *t.C_tensor
|
||||
}
|
||||
|
||||
var nativeEndian binary.ByteOrder
|
||||
|
||||
func init() {
|
||||
|
@ -34,58 +27,7 @@ func init() {
|
|||
}
|
||||
}
|
||||
|
||||
// FnOfSlice creates tensor from a slice data
|
||||
func FnOfSlice() (retVal Tensor, err error) {
|
||||
|
||||
data := []int{0, 0, 0, 0}
|
||||
shape := []int64{int64(len(data))}
|
||||
nflattened := numElements(shape)
|
||||
dtype := 3 // Kind.Int
|
||||
eltSizeInBytes := 4 // Element Size in Byte for Int dtype
|
||||
|
||||
nbytes := eltSizeInBytes * int(uintptr(nflattened))
|
||||
|
||||
// NOTE: dataPrt is type of `*void` in C or type of `unsafe.Pointer` in Go
|
||||
dataPtr := C.malloc(C.size_t(nbytes))
|
||||
|
||||
// Recall: 1 << 30 = 1 * 2 * 30
|
||||
// Ref. See more at https://stackoverflow.com/questions/48756732
|
||||
dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes]
|
||||
|
||||
buf := bytes.NewBuffer(dataSlice[:0:nbytes])
|
||||
|
||||
encodeTensor(buf, reflect.ValueOf(data), shape)
|
||||
|
||||
c_tensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(dtype))
|
||||
|
||||
retVal = Tensor{c_tensor}
|
||||
|
||||
// Read back created tensor values by C libtorch
|
||||
readDataPtr := lib.AtDataPtr(retVal.c_tensor)
|
||||
readDataSlice := (*[1 << 30]byte)(readDataPtr)[:nbytes:nbytes]
|
||||
// typ := typeOf(dtype, shape)
|
||||
typ := reflect.TypeOf(int32(0)) // C. type `int` ~ Go type `int32`
|
||||
val := reflect.New(typ)
|
||||
if err := decodeTensor(bytes.NewReader(readDataSlice), shape, typ, val); err != nil {
|
||||
panic(fmt.Sprintf("unable to decode Tensor of type %v and shape %v - %v", dtype, shape, err))
|
||||
}
|
||||
|
||||
tensorData := reflect.Indirect(val).Interface()
|
||||
|
||||
fmt.Println("%v", tensorData)
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func numElements(shape []int64) int64 {
|
||||
n := int64(1)
|
||||
for _, d := range shape {
|
||||
n *= d
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
||||
func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
||||
switch v.Kind() {
|
||||
case reflect.Bool:
|
||||
b := byte(0)
|
||||
|
@ -120,7 +62,7 @@ func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
|||
|
||||
subShape := shape[1:]
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
err := encodeTensor(w, v.Index(i), subShape)
|
||||
err := EncodeTensor(w, v.Index(i), subShape)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -132,9 +74,9 @@ func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// decodeTensor decodes the Tensor from the buffer to ptr using the format
|
||||
// DecodeTensor decodes the Tensor from the buffer to ptr using the format
|
||||
// specified in c_api.h. Use stringDecoder for String tensors.
|
||||
func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
|
||||
func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
|
||||
switch typ.Kind() {
|
||||
case reflect.Bool:
|
||||
b, err := r.ReadByte()
|
||||
|
@ -160,7 +102,7 @@ func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
|
|||
}
|
||||
|
||||
for i := 0; i < val.Len(); i++ {
|
||||
if err := decodeTensor(r, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil {
|
||||
if err := DecodeTensor(r, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -171,8 +113,34 @@ func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
|
|||
return nil
|
||||
}
|
||||
|
||||
// // typeOf converts from a DType and Shape to the equivalent Go type.
|
||||
// func typeOf(dt DType, shape []int64) reflect.Type {
|
||||
func numElements(shape []int64) int64 {
|
||||
n := int64(1)
|
||||
for _, d := range shape {
|
||||
n *= d
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// GetKind returns data type `Kind` (a element of tensor can hold)
|
||||
// v - a value of a data element
|
||||
func GetKind(v interface{}) (retVal gotch.Kind, err error) {
|
||||
|
||||
switch {
|
||||
case reflect.TypeOf(v) == int:
|
||||
retVal = gotch.Int
|
||||
case reflect.TypeOf(v) == uint8:
|
||||
retVal = gotch.Uint8
|
||||
|
||||
default:
|
||||
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// // TypeOf converts from a DType and Shape to the equivalent Go type.
|
||||
// func TypeOf(dt DType, shape []int64) reflect.Type {
|
||||
// var ret reflect.Type
|
||||
// for _, t := range types {
|
||||
// if dt == DType(t.dataType) {
|
Loading…
Reference in New Issue
Block a user