WIP: restructure and WIP: kind

This commit is contained in:
sugarme 2020-05-29 02:58:23 +10:00
parent bbf8bface1
commit 773f423fff
6 changed files with 241 additions and 222 deletions

View File

@ -1,4 +1,4 @@
package gorch
package gotch
import (
"log"

135
kind.go Normal file
View File

@ -0,0 +1,135 @@
package gotch
import (
"log"
"reflect"
)
// CInt is equal to C type int. Go type is int32
type CInt = int32
// Kind represents different kind of element that a tensor can hold.
// It has an embedded `reflect.Type` for type reflection.
type Kind struct {
reflect.Type
}
// TODO: double check these Torch DType to Go type
var (
Uint8 = Kind{reflect.TypeOf(uint8(1))} // 0
Int8 = Kind{reflect.TypeOf(int8(1))} // 1
Int16 = Kind{reflect.TypeOf(int16(1))} // 2
Int = Kind{reflect.TypeOf(int(1))} // 3
Int64 = Kind{reflect.TypeOf(int64(1))} // 4
Half = Kind{reflect.TypeOf(float32(1))} // 5
Float = Kind{reflect.TypeOf(float64(1))} // 6
Double = Kind{reflect.TypeOf(float64(1))} // 7
ComplexHalf = kind{reflect.TypeOf(complex(1))} // 8
ComplexFloat = Kind{reflect.TypeOf(complex64(1))} // 9
ComplexDouble = kind{reflect.TypeOf(complex128(1))} // 10
Bool = kind{reflect.TypeOf(true)} // 11
)
// ToCInt converts Kind to CInt type value which is `C int`
func (k Kind) ToCInt() CInt {
switch {
case k.Kind() == uint8:
return 0
case k.Kind() == int8:
return 1
case k.Kind() == int16:
return 2
case k.Kind() == int:
return 3
case k.Kind() == int64:
return 4
case k.Kind() == float32:
return 5
default:
log.Fatalf("Unsupported type.")
}
// unreachable
return CInt(-1)
}
// OfCInt converts a value of type CInt to Kind type value
func (k Kind) OfCInt(v CInt) Kind {
switch v {
case 0:
return Uint8
case 1:
return Int8
case 2:
return Int16
case 3:
return Int
case 4:
return Int64
case 5:
return Half
case 6:
return Float
case 7:
return Double
case 8:
return ComplexHalf
case 9:
return ComplexFloat
case 10:
return ComplexDouble
case 11:
return Bool
default:
log.Fatalf("Unexpected kind %v\n", v)
}
return Kind{reflect.TypeOf(false)}
}
// EltSizeInBytes converts a Kind value to number of bytes
// This is a ELement Size In Byte in Libtorch.
// Has it been deprecated?
func (k Kind) EltSizeInBytes() uint {
switch {
case k.ToCInt() == int32(Uint8):
return 1
case k.ToCInt() == int32(Int8):
return 1
case k.ToCInt() == int32(Int16):
return 2
case k.ToCInt() == int32(Int):
return 4
case k.ToCInt() == int32(Int64):
return 8
case k.ToCInt() == int32(Half):
return 2
case k.ToCInt() == int32(Float):
return 4
case k.ToCInt() == int32(Double):
return 8
case k.ToCInt() == int32(ComplexHalf):
return 4
case k.ToCInt() == int32(ComplexDouble):
return 16
case k.ToCInt() == int32(Bool):
return 1
default:
log.Fatalf("Unreachable")
}
return uint(0)
}
type KindDevice struct {
Kind Kind
Device Device
}
var (
FloatCPU KindDevice = KindDevice{Float, CPU}
DoubleCPU KindDevice = KindDevice{Double, CPU}
Int64CPU KindDevice = KindDevice{Int64, CPU}
FloatCUDA KindDevice = KindDevice{Float, CudaBuilder(0)}
DoubleCUDA KindDevice = KindDevice{Double, CudaBuilder(0)}
Int64CUDA KindDevice = KindDevice{Int64, CudaBuilder(0)}
)

View File

@ -18,7 +18,7 @@ type C_tensor struct {
private unsafe.Pointer
}
func NewTensor() *C_tensor {
func AtNewTensor() *C_tensor {
t := C.at_new_tensor()
return &C_tensor{private: unsafe.Pointer(t)}
}
@ -31,11 +31,6 @@ func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_byt
c_elt_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&elt_size_in_bytes))
c_kind := *(*C.int)(unsafe.Pointer(&kind))
// c_dims := (*C.long)(unsafe.Pointer(uintptr(dims)))
// c_ndims := *(*C.size_t)(unsafe.Pointer(uintptr(ndims)))
// c_elt_size_in_bytes := *(*C.size_t)(unsafe.Pointer(uintptr(elt_size_in_bytes)))
// c_kind := *(*C.int)(unsafe.Pointer(uintptr(kind)))
// t is of type `unsafe.Pointer` in Go and `*void` in C
t := C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind)
fmt.Printf("t type: %v\n", reflect.TypeOf(t).Kind())

View File

@ -1,147 +0,0 @@
package tensor
import (
"log"
gotch "github.com/sugarme/gotch"
)
// CInt is equal to C type int. Go type is int32
type CInt = int32
// Kind is 'enum' like type. It represents different kind of elements
// that a Tensor can hold.
type Kind int
const (
Uint8 Kind = iota // 0
Int8 // 1
Int16 // 2
Int // 3
Int64 // 4
Half // 5
Float // 6
Double // 7
ComplexHalf // 8
ComplexFloat // 9
ComplexDouble // 10
Bool // 11
)
// ToCInt converts Kind to CInt type value which is `C int`
func (k Kind) ToCInt() CInt {
return CInt(k)
}
// OfCInt converts a value of type CInt to Kind type value
func (k Kind) OfCInt(v CInt) Kind {
switch v {
case 0:
return Uint8
case 1:
return Int8
case 2:
return Int16
case 3:
return Int
case 4:
return Int64
case 5:
return Half
case 6:
return Float
case 7:
return Double
case 8:
return ComplexHalf
case 9:
return ComplexFloat
case 10:
return ComplexDouble
case 11:
return Bool
default:
log.Fatalf("Unexpected kind %v\n", v)
}
return Kind(0)
}
// EltSizeInBytes converts a Kind value to number of bytes
// This is a ELement Size In Byte in Libtorch.
// Has it been deprecated?
func (k Kind) EltSizeInBytes() uint {
switch {
case k.ToCInt() == int32(Uint8):
return 1
case k.ToCInt() == int32(Int8):
return 1
case k.ToCInt() == int32(Int16):
return 2
case k.ToCInt() == int32(Int):
return 4
case k.ToCInt() == int32(Int64):
return 8
case k.ToCInt() == int32(Half):
return 2
case k.ToCInt() == int32(Float):
return 4
case k.ToCInt() == int32(Double):
return 8
case k.ToCInt() == int32(ComplexHalf):
return 4
case k.ToCInt() == int32(ComplexDouble):
return 16
case k.ToCInt() == int32(Bool):
return 1
default:
log.Fatalf("Unreachable")
}
return uint(0)
}
type KindDevice struct {
Kind Kind
Device gotch.Device
}
var (
FloatCPU KindDevice = KindDevice{Float, gotch.CPU}
DoubleCPU KindDevice = KindDevice{Double, gotch.CPU}
Int64CPU KindDevice = KindDevice{Int64, gotch.CPU}
FloatCUDA KindDevice = KindDevice{Float, gotch.CudaBuilder(0)}
DoubleCUDA KindDevice = KindDevice{Double, gotch.CudaBuilder(0)}
Int64CUDA KindDevice = KindDevice{Int64, gotch.CudaBuilder(0)}
)
type KindTrait interface {
GetKind() Kind
}
type KindUint8 struct{}
func (k KindUint8) GetKind() Kind { return Uint8 }
type KindInt8 struct{}
func (k KindInt8) GetKind() Kind { return Int8 }
type KindInt16 struct{}
func (k KindInt16) GetKind() Kind { return Int16 }
type KindInt64 struct{}
func (k KindInt64) GetKind() Kind { return Int64 }
type KindFloat32 struct{}
func (k KindFloat32) GetKind() Kind { return Float }
type KindFloat64 struct{}
func (k KindFloat64) GetKind() Kind { return Double }
type KindBool struct{}
func (k KindBool) GetKind() Kind { return Bool }

68
wrapper/tensor.go Normal file
View File

@ -0,0 +1,68 @@
package wrapper
//#include <stdlib.h>
import "C"
import (
"bytes"
"encoding/binary"
"fmt"
"reflect"
"unsafe"
lib "github.com/sugarme/gotch/libtch"
)
type Tensor struct {
ctensor *t.C_tensor
}
// NewTensor creates a new tensor
func NewTensor() Tensor {
ctensor := lib.AtNewTensor()
return Tensor{ctensor}
}
// FOfSlice creates tensor from a slice data
func(ts Tensor) FOfSlice(data []inteface{}) (retVal Tensor, err error) {
data := []int{0, 0, 0, 0}
shape := []int64{int64(len(data))}
nflattened := numElements(shape)
dtype := 3 // Kind.Int
eltSizeInBytes := 4 // Element Size in Byte for Int dtype
nbytes := eltSizeInBytes * int(uintptr(nflattened))
// NOTE: dataPrt is type of `*void` in C or type of `unsafe.Pointer` in Go
// data should be allocated to memory BY `C` side
dataPtr := C.malloc(C.size_t(nbytes))
// Recall: 1 << 30 = 1 * 2 * 30
// Ref. See more at https://stackoverflow.com/questions/48756732
dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes]
buf := bytes.NewBuffer(dataSlice[:0:nbytes])
EncodeTensor(buf, reflect.ValueOf(data), shape)
c_tensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(dtype))
retVal = Tensor{c_tensor}
// Read back created tensor values by C libtorch
readDataPtr := lib.AtDataPtr(retVal.c_tensor)
readDataSlice := (*[1 << 30]byte)(readDataPtr)[:nbytes:nbytes]
// typ := typeOf(dtype, shape)
typ := reflect.TypeOf(int32(0)) // C. type `int` ~ Go type `int32`
val := reflect.New(typ)
if err := DecodeTensor(bytes.NewReader(readDataSlice), shape, typ, val); err != nil {
panic(fmt.Sprintf("unable to decode Tensor of type %v and shape %v - %v", dtype, shape, err))
}
tensorData := reflect.Indirect(val).Interface()
fmt.Println("%v", tensorData)
return retVal, nil
}

View File

@ -1,23 +1,16 @@
package tensor
//#include <stdlib.h>
import "C"
package wrapper
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"reflect"
// "runtime"
"unsafe"
lib "github.com/sugarme/gotch/libtch"
gotch "github.com/sugarme/gotch"
)
type Tensor struct {
ctensor *t.C_tensor
}
var nativeEndian binary.ByteOrder
func init() {
@ -34,58 +27,7 @@ func init() {
}
}
// FnOfSlice creates tensor from a slice data
func FnOfSlice() (retVal Tensor, err error) {
data := []int{0, 0, 0, 0}
shape := []int64{int64(len(data))}
nflattened := numElements(shape)
dtype := 3 // Kind.Int
eltSizeInBytes := 4 // Element Size in Byte for Int dtype
nbytes := eltSizeInBytes * int(uintptr(nflattened))
// NOTE: dataPrt is type of `*void` in C or type of `unsafe.Pointer` in Go
dataPtr := C.malloc(C.size_t(nbytes))
// Recall: 1 << 30 = 1 * 2 * 30
// Ref. See more at https://stackoverflow.com/questions/48756732
dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes]
buf := bytes.NewBuffer(dataSlice[:0:nbytes])
encodeTensor(buf, reflect.ValueOf(data), shape)
c_tensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(dtype))
retVal = Tensor{c_tensor}
// Read back created tensor values by C libtorch
readDataPtr := lib.AtDataPtr(retVal.c_tensor)
readDataSlice := (*[1 << 30]byte)(readDataPtr)[:nbytes:nbytes]
// typ := typeOf(dtype, shape)
typ := reflect.TypeOf(int32(0)) // C. type `int` ~ Go type `int32`
val := reflect.New(typ)
if err := decodeTensor(bytes.NewReader(readDataSlice), shape, typ, val); err != nil {
panic(fmt.Sprintf("unable to decode Tensor of type %v and shape %v - %v", dtype, shape, err))
}
tensorData := reflect.Indirect(val).Interface()
fmt.Println("%v", tensorData)
return retVal, nil
}
func numElements(shape []int64) int64 {
n := int64(1)
for _, d := range shape {
n *= d
}
return n
}
func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
switch v.Kind() {
case reflect.Bool:
b := byte(0)
@ -120,7 +62,7 @@ func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
subShape := shape[1:]
for i := 0; i < v.Len(); i++ {
err := encodeTensor(w, v.Index(i), subShape)
err := EncodeTensor(w, v.Index(i), subShape)
if err != nil {
return err
}
@ -132,9 +74,9 @@ func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
return nil
}
// decodeTensor decodes the Tensor from the buffer to ptr using the format
// DecodeTensor decodes the Tensor from the buffer to ptr using the format
// specified in c_api.h. Use stringDecoder for String tensors.
func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
switch typ.Kind() {
case reflect.Bool:
b, err := r.ReadByte()
@ -160,7 +102,7 @@ func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
}
for i := 0; i < val.Len(); i++ {
if err := decodeTensor(r, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil {
if err := DecodeTensor(r, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil {
return err
}
}
@ -171,8 +113,34 @@ func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
return nil
}
// // typeOf converts from a DType and Shape to the equivalent Go type.
// func typeOf(dt DType, shape []int64) reflect.Type {
func numElements(shape []int64) int64 {
n := int64(1)
for _, d := range shape {
n *= d
}
return n
}
// GetKind returns data type `Kind` (a element of tensor can hold)
// v - a value of a data element
func GetKind(v interface{}) (retVal gotch.Kind, err error) {
switch {
case reflect.TypeOf(v) == int:
retVal = gotch.Int
case reflect.TypeOf(v) == uint8:
retVal = gotch.Uint8
default:
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
return retVal, err
}
return retVal, nil
}
// // TypeOf converts from a DType and Shape to the equivalent Go type.
// func TypeOf(dt DType, shape []int64) reflect.Type {
// var ret reflect.Type
// for _, t := range types {
// if dt == DType(t.dataType) {