2022-03-12 07:20:20 +00:00
|
|
|
package ts
|
2020-05-28 08:30:17 +01:00
|
|
|
|
2020-05-30 00:04:47 +01:00
|
|
|
// #include <stdlib.h>
|
|
|
|
import "C"
|
|
|
|
|
2020-05-28 08:30:17 +01:00
|
|
|
import (
|
|
|
|
"bytes"
|
|
|
|
"encoding/binary"
|
|
|
|
"fmt"
|
2022-02-13 11:46:50 +00:00
|
|
|
|
2020-06-11 02:57:56 +01:00
|
|
|
// "log"
|
2020-05-28 08:30:17 +01:00
|
|
|
"reflect"
|
|
|
|
"unsafe"
|
2020-05-30 06:39:56 +01:00
|
|
|
|
2024-04-21 15:15:00 +01:00
|
|
|
gotch "git.andr3h3nriqu3s.com/andr3/gotch"
|
2020-05-28 08:30:17 +01:00
|
|
|
)
|
|
|
|
|
2020-05-30 00:04:47 +01:00
|
|
|
// nativeEndian is the byte order of the machine this package is running on.
// It is assigned exactly once, by init, and is the order passed to the
// encoding/binary calls in this file.
//
// Ref. https://stackoverflow.com/a/53286786
// Ref. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/go/tensor.go#L488-L505
var nativeEndian binary.ByteOrder

// init detects the host byte order: it stores the 16-bit value 0xABCD into a
// two-byte array through an unsafe pointer and then looks at which byte ended
// up first. It panics if the resulting layout matches neither order.
func init() {
	var probe [2]byte
	*(*uint16)(unsafe.Pointer(&probe[0])) = uint16(0xABCD)

	if probe == [2]byte{0xCD, 0xAB} {
		// Low-order byte first.
		nativeEndian = binary.LittleEndian
		return
	}
	if probe == [2]byte{0xAB, 0xCD} {
		// High-order byte first.
		nativeEndian = binary.BigEndian
		return
	}
	panic("Could not determine native endianness.")
}
|
|
|
|
|
2020-05-30 00:04:47 +01:00
|
|
|
// CMalloc allocates a given number of bytes to C side memory.
|
|
|
|
// It returns
|
|
|
|
// - dataPtr: a C pointer type of `*void` (`unsafe.Pointer` in Go).
|
|
|
|
// - buf : a Go pointer points to a given bytes of buffer (empty) in C memory
|
|
|
|
// allocated by C waiting for writing data to.
|
|
|
|
//
|
|
|
|
// NOTE:
|
|
|
|
// 1. Go pointer is a pointer to Go memory. C pointer is a pointer to C memory.
|
|
|
|
// 2. General rule is Go code can use C pointers. Go code may pass Go pointer to C
|
|
|
|
// provided that the Go memory to which it points does NOT contain any Go
|
|
|
|
// pointers. BUT C code must not store any Go pointers in Go memory, even
|
|
|
|
// temporarily.
|
|
|
|
// 3. Some Go values contain Go pointers IMPLICITLY: strings, slices, maps,
|
|
|
|
// channels and function values. Thus, pointers to these values should not be
|
|
|
|
// passed to C side. Instead, data should be allocated to C memory and return a
|
|
|
|
// C pointer to it using `C.malloc`.
|
|
|
|
// Ref: https://github.com/golang/proposal/blob/master/design/12416-cgo-pointers.md
|
|
|
|
func CMalloc(nbytes int) (dataPtr unsafe.Pointer, buf *bytes.Buffer) {
|
|
|
|
|
|
|
|
dataPtr = C.malloc(C.size_t(nbytes))
|
2020-06-17 07:45:59 +01:00
|
|
|
// NOTE: uncomment this cause panic!
|
|
|
|
// defer C.free(unsafe.Pointer(dataPtr))
|
2020-05-30 00:04:47 +01:00
|
|
|
|
2020-12-09 22:21:03 +00:00
|
|
|
// Recall: 1 << 30 = 1 * 2 * 30 = 1073741824
|
|
|
|
dataSlice := (*[1 << 32]byte)(dataPtr)[:nbytes:nbytes] // 4294967296
|
2020-05-30 00:04:47 +01:00
|
|
|
buf = bytes.NewBuffer(dataSlice[:0:nbytes])
|
|
|
|
|
|
|
|
return dataPtr, buf
|
|
|
|
}
|
|
|
|
|
|
|
|
// EncodeTensor loads tensor data to C memory and returns a C pointer.
|
2020-05-28 17:58:23 +01:00
|
|
|
func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
2020-05-28 08:30:17 +01:00
|
|
|
switch v.Kind() {
|
|
|
|
case reflect.Bool:
|
|
|
|
b := byte(0)
|
|
|
|
if v.Bool() {
|
|
|
|
b = 1
|
|
|
|
}
|
|
|
|
if err := w.WriteByte(b); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2023-07-06 15:01:23 +01:00
|
|
|
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
2020-05-28 08:30:17 +01:00
|
|
|
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
case reflect.Array, reflect.Slice:
|
|
|
|
// If current dimension is a slice, verify that it has the expected size
|
|
|
|
// Go's type system makes that guarantee for arrays.
|
|
|
|
if v.Kind() == reflect.Slice {
|
|
|
|
expected := int(shape[0])
|
|
|
|
if v.Len() != expected {
|
2023-07-06 15:01:23 +01:00
|
|
|
return fmt.Errorf("EncodeTensor() failed: mismatched slice lengths: %d and %d", v.Len(), expected)
|
2020-05-28 08:30:17 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
|
|
|
|
if len(shape) == 1 && v.Len() > 0 {
|
|
|
|
switch v.Index(0).Kind() {
|
2023-07-06 15:01:23 +01:00
|
|
|
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
2020-05-28 08:30:17 +01:00
|
|
|
return binary.Write(w, nativeEndian, v.Interface())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
subShape := shape[1:]
|
|
|
|
for i := 0; i < v.Len(); i++ {
|
2020-05-28 17:58:23 +01:00
|
|
|
err := EncodeTensor(w, v.Index(i), subShape)
|
2020-05-28 08:30:17 +01:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
default:
|
2023-07-06 15:01:23 +01:00
|
|
|
return fmt.Errorf("EncodeTensor() failed: unsupported type %v", v.Type())
|
2020-05-28 08:30:17 +01:00
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2020-05-30 00:04:47 +01:00
|
|
|
// DecodeTensor decodes tensor value from a C memory buffer given
|
|
|
|
// C pointer, data type and shape and returns data value of type interface
|
2020-05-28 17:58:23 +01:00
|
|
|
func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
|
2020-05-28 08:30:17 +01:00
|
|
|
switch typ.Kind() {
|
|
|
|
case reflect.Bool:
|
|
|
|
b, err := r.ReadByte()
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
ptr.Elem().SetBool(b == 1)
|
2023-07-06 15:01:23 +01:00
|
|
|
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
2020-05-28 08:30:17 +01:00
|
|
|
if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
case reflect.Slice:
|
|
|
|
val := reflect.Indirect(ptr)
|
|
|
|
val.Set(reflect.MakeSlice(typ, int(shape[0]), int(shape[0])))
|
|
|
|
|
|
|
|
// Optimization: if only one dimension is left we can use binary.Read() directly for this slice
|
|
|
|
if len(shape) == 1 && val.Len() > 0 {
|
|
|
|
switch val.Index(0).Kind() {
|
2023-07-06 15:01:23 +01:00
|
|
|
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
2020-05-28 08:30:17 +01:00
|
|
|
return binary.Read(r, nativeEndian, val.Interface())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
for i := 0; i < val.Len(); i++ {
|
2020-05-28 17:58:23 +01:00
|
|
|
if err := DecodeTensor(r, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil {
|
2020-05-28 08:30:17 +01:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
default:
|
|
|
|
return fmt.Errorf("unsupported type %v", typ)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2020-05-30 00:04:47 +01:00
|
|
|
// ElementCount returns the number of elements in a tensor of the given shape,
// i.e. the product of all dimensions. An empty (or nil) shape yields 1.
func ElementCount(shape []int64) int {
	count := 1
	for i := 0; i < len(shape); i++ {
		count = count * int(shape[i])
	}
	return count
}
|
2020-05-30 06:39:56 +01:00
|
|
|
|
|
|
|
// DataDim returns number of elements in data
|
2020-06-01 06:45:25 +01:00
|
|
|
// NOTE: only support scalar and (nested) slice/array of scalar type
|
2020-05-30 06:39:56 +01:00
|
|
|
func DataDim(data interface{}) (retVal int, err error) {
|
2020-06-01 06:45:25 +01:00
|
|
|
_, count, err := dataCheck(reflect.ValueOf(data).Interface(), 0)
|
2020-06-08 05:33:19 +01:00
|
|
|
|
2020-06-01 06:45:25 +01:00
|
|
|
return count, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// DataCheck checks the input data for element Go type and number of elements.
|
2023-07-06 15:01:23 +01:00
|
|
|
// It will return errors if element dtype is not supported.
|
|
|
|
func DataCheck(data interface{}) (dtype gotch.DType, n int, err error) {
|
2020-06-01 06:45:25 +01:00
|
|
|
return dataCheck(reflect.ValueOf(data).Interface(), 0)
|
|
|
|
}
|
|
|
|
|
|
|
|
// NOTE: 0 is reflect.Kind() of Invalid
|
|
|
|
// See: https://golang.org/pkg/reflect/#Kind
|
2023-07-06 15:01:23 +01:00
|
|
|
func dataCheck(data interface{}, count int) (dtype gotch.DType, n int, err error) {
|
2020-05-30 06:39:56 +01:00
|
|
|
v := reflect.ValueOf(data)
|
2020-06-01 06:45:25 +01:00
|
|
|
var total int = count
|
|
|
|
var round = 0
|
2020-05-30 06:39:56 +01:00
|
|
|
|
2023-07-06 15:01:23 +01:00
|
|
|
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
|
2020-06-01 06:45:25 +01:00
|
|
|
if round == 0 {
|
|
|
|
round = v.Len()
|
2020-05-30 06:39:56 +01:00
|
|
|
}
|
2020-06-01 06:45:25 +01:00
|
|
|
for i := 0; i < v.Len(); i++ {
|
|
|
|
round--
|
2023-07-06 15:01:23 +01:00
|
|
|
dtype, total, err = dataCheck(v.Index(i).Interface(), total)
|
2020-05-30 06:39:56 +01:00
|
|
|
|
2020-06-01 06:45:25 +01:00
|
|
|
if err != nil {
|
2023-07-06 15:01:23 +01:00
|
|
|
return gotch.Invalid, 0, err
|
2020-06-01 06:45:25 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-07-06 15:01:23 +01:00
|
|
|
return dtype, total, nil
|
|
|
|
}
|
2020-06-01 06:45:25 +01:00
|
|
|
|
2023-07-06 15:01:23 +01:00
|
|
|
total += 1
|
|
|
|
dtype, err = gotch.GoKind2DType(v.Kind())
|
|
|
|
if err != nil {
|
|
|
|
err = fmt.Errorf("DataCheck() failed: unsupported data structure or type: %v\n", v.Kind())
|
|
|
|
return gotch.Invalid, 0, err
|
2020-05-30 06:39:56 +01:00
|
|
|
}
|
|
|
|
|
2023-07-06 15:01:23 +01:00
|
|
|
return dtype, total, nil
|
2020-05-30 06:39:56 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// DataAsPtr write to C memory and returns a C pointer.
|
|
|
|
//
|
|
|
|
// NOTE:
|
|
|
|
// Supported data types are scalar, slice/array of scalar type equivalent to
|
|
|
|
// DType.
|
|
|
|
func DataAsPtr(data interface{}) (dataPtr unsafe.Pointer, err error) {
|
|
|
|
// 1. Count number of elements in data
|
|
|
|
elementNum, err := DataDim(data)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2023-07-06 15:01:23 +01:00
|
|
|
// 2. Number of bytes
|
2020-05-30 06:39:56 +01:00
|
|
|
dtype, err := gotch.DTypeFromData(data)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2023-07-06 15:01:23 +01:00
|
|
|
nbytes := int(dtype.Size()) * int(elementNum)
|
2020-05-30 06:39:56 +01:00
|
|
|
|
|
|
|
// 3. Get C pointer and prepare C memory buffer for writing
|
|
|
|
dataPtr, buff := CMalloc(nbytes)
|
|
|
|
|
|
|
|
// 4. Write data to C memory
|
2020-06-01 06:45:25 +01:00
|
|
|
// NOTE: data should be **fixed size** values so that binary.Write can work
|
|
|
|
// A fixed-size value is either a fixed-size arithmetic type (bool, int8, uint8,
|
|
|
|
// int16, float32, complex64, ...) or an array or struct containing only fixed-size values.
|
|
|
|
// See more: https://golang.org/pkg/encoding/binary/
|
|
|
|
// Therefore, we will need to flatten data to `[]T`
|
|
|
|
fData, err := FlattenData(data)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
err = binary.Write(buff, nativeEndian, fData)
|
2020-05-30 06:39:56 +01:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return dataPtr, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// FlattenDim returns the total number of elements described by shape: the
// product of all dimensions, accumulated as int64 and converted to int at the
// end. An empty (or nil) shape yields 1.
func FlattenDim(shape []int64) int {
	var product int64 = 1
	for i := range shape {
		product *= shape[i]
	}
	return int(product)
}
|
2020-06-01 06:45:25 +01:00
|
|
|
|
|
|
|
// FlattenData flattens data into a one-dimensional slice ([]T).
//
// Behaviour:
//   - A slice whose element type is not itself a slice is considered flat
//     already and is returned unchanged.
//   - Anything else (a scalar, an array, or nested slices/arrays) is walked
//     with flattenData and its leaf values are collected, in order, into a
//     new slice whose element type is the type of the first leaf.
//
// It returns an error, rather than panicking, when data contains no supported
// leaf value at all (nil, an empty nested slice, or only unsupported kinds)
// and when the leaves are not all of one type.
func FlattenData(data interface{}) (fData interface{}, err error) {
	// If data is 1D already, just return it.
	if reflect.ValueOf(data).Kind() == reflect.Slice {
		if reflect.TypeOf(data).Elem().Kind() != reflect.Slice {
			return data, nil
		}
	}

	flat, err := flattenData(data, 0, nil)
	if err != nil {
		return nil, err
	}
	if len(flat) == 0 {
		return nil, fmt.Errorf("FlattenData() failed: no supported scalar element found in data of type %T", data)
	}

	// Build a []T via reflection, where T is the concrete type of the first
	// leaf, instead of hand-writing one conversion loop per supported type.
	elemType := reflect.TypeOf(flat[0])
	out := reflect.MakeSlice(reflect.SliceOf(elemType), len(flat), len(flat))
	for i, e := range flat {
		ev := reflect.ValueOf(e)
		if ev.Type() != elemType {
			return nil, fmt.Errorf("FlattenData() failed: mixed element types %v and %v", elemType, ev.Type())
		}
		out.Index(i).Set(ev)
	}

	return out.Interface(), nil
}

// flattenData appends the leaf values of data to flat, depth first, and
// returns the extended slice. Slices and arrays are descended into; values of
// a supported scalar kind are appended; values of any other kind are skipped
// silently.
//
// round is kept for signature compatibility only; it has no effect on the
// result.
func flattenData(data interface{}, round int, flat []interface{}) (f []interface{}, err error) {
	v := reflect.ValueOf(data)

	switch v.Kind() {
	case reflect.Slice, reflect.Array:
		for i := 0; i < v.Len(); i++ {
			flat, err = flattenData(v.Index(i).Interface(), round, flat)
			if err != nil {
				return nil, err
			}
		}

	case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
		// Uint16 is accepted here to match EncodeTensor and DecodeTensor.
		flat = append(flat, data)
	}

	return flat, nil
}
|
2020-06-11 02:57:56 +01:00
|
|
|
|
|
|
|
// InvokeFnWithArgs calls fn through reflection, passing each string in args
// as one argument. Any results of fn are discarded. reflect's Call panics if
// fn is not a function or if the arguments do not fit its signature.
func InvokeFnWithArgs(fn interface{}, args ...string) {
	callArgs := make([]reflect.Value, 0, len(args))
	for idx := range args {
		callArgs = append(callArgs, reflect.ValueOf(args[idx]))
	}
	reflect.ValueOf(fn).Call(callArgs)
}
|
|
|
|
|
|
|
|
// FuncInfo holds metadata about a function obtained through reflection. It is
// populated by getFuncInfo.
type FuncInfo struct {
	// Signature is the string form of the function's type.
	Signature string
	// InArgs has one entry per input parameter. As filled in by getFuncInfo,
	// each entry wraps the parameter's reflect.Type, not an argument value.
	InArgs []reflect.Value
	// OutArgs has one entry per result. As filled in by getFuncInfo, each
	// entry wraps the result's reflect.Type.
	OutArgs []reflect.Value
	// IsVariadic reports whether the function's final parameter is variadic.
	IsVariadic bool
}
|
|
|
|
|
|
|
|
// Func bundles a function with its reflected type, its reflected value and
// the metadata gathered about it. Values are created with NewFunc.
type Func struct {
	// typ is reflect.TypeOf the wrapped function.
	typ reflect.Type
	// val is reflect.ValueOf the wrapped function; Invoke calls through it.
	val reflect.Value
	// meta is the FuncInfo returned by getFuncInfo for the wrapped function.
	meta FuncInfo
}
|
|
|
|
|
|
|
|
func NewFunc(fn interface{}) (retVal Func, err error) {
|
|
|
|
meta, err := getFuncInfo(fn)
|
|
|
|
if err != nil {
|
|
|
|
return retVal, err
|
|
|
|
}
|
|
|
|
|
|
|
|
retVal = Func{
|
|
|
|
typ: reflect.TypeOf(fn),
|
|
|
|
val: reflect.ValueOf(fn),
|
|
|
|
meta: meta,
|
|
|
|
}
|
|
|
|
return retVal, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// getFuncInfo analyzes input of interface type and returns function information
|
|
|
|
// in FuncInfo struct. It returns error if input is not a function type under
|
|
|
|
// the hood.
|
|
|
|
func getFuncInfo(fn interface{}) (retVal FuncInfo, err error) {
|
|
|
|
fnVal := reflect.ValueOf(fn)
|
|
|
|
fnTyp := reflect.TypeOf(fn)
|
|
|
|
|
|
|
|
// First, check whether input is a function type
|
|
|
|
if fnVal.Kind() != reflect.Func {
|
|
|
|
err = fmt.Errorf("Input is not a function.")
|
|
|
|
return retVal, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// get number of input and output arguments of function
|
|
|
|
numIn := fnTyp.NumIn() // inbound parameters
|
|
|
|
numOut := fnTyp.NumOut() // outbound parameters
|
|
|
|
isVariadic := fnTyp.IsVariadic() // whether function is a variadic func
|
|
|
|
fnSig := fnTyp.String() // function signature
|
|
|
|
|
|
|
|
// get input and ouput arguments values (reflect.Value type)
|
|
|
|
var inArgs []reflect.Value
|
|
|
|
var outArgs []reflect.Value
|
|
|
|
|
|
|
|
for i := 0; i < numIn; i++ {
|
|
|
|
t := fnTyp.In(i) // reflect.Type
|
|
|
|
|
|
|
|
inArgs = append(inArgs, reflect.ValueOf(t))
|
|
|
|
}
|
|
|
|
|
|
|
|
for i := 0; i < numOut; i++ {
|
|
|
|
t := fnTyp.Out(i) // reflect.Type
|
|
|
|
outArgs = append(outArgs, reflect.ValueOf(t))
|
|
|
|
}
|
|
|
|
|
|
|
|
retVal = FuncInfo{
|
|
|
|
Signature: fnSig,
|
|
|
|
InArgs: inArgs,
|
|
|
|
OutArgs: outArgs,
|
|
|
|
IsVariadic: isVariadic,
|
|
|
|
}
|
|
|
|
|
|
|
|
return retVal, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Info returns the FuncInfo metadata stored in f. It performs no analysis of
// its own and returns no error.
func (f *Func) Info() (retVal FuncInfo) {
	return f.meta
}
|
|
|
|
|
|
|
|
// Invoke calls the wrapped function through reflection, passing f.meta.InArgs
// as the arguments, and returns the resulting []reflect.Value boxed in an
// interface{}.
//
// NOTE(review): getFuncInfo fills InArgs with values that wrap the parameter
// types rather than argument values, so this presumably only works for
// functions that take no parameters — confirm before relying on it.
func (f *Func) Invoke() interface{} {
	// call function with input parameters
	// TODO: return vals are []reflect.Value
	// How do we match them to output order of signature function
	return f.val.Call(f.meta.InArgs)
}
|
2020-06-14 13:46:36 +01:00
|
|
|
|
|
|
|
// Must is a helper to unwrap function it wraps. If having error,
|
|
|
|
// it will cause panic.
|
|
|
|
func Must(ts Tensor, err error) (retVal Tensor) {
|
|
|
|
if err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
return ts
|
|
|
|
}
|
2022-02-13 11:46:50 +00:00
|
|
|
|
|
|
|
// sliceIntToInt32 returns a new slice holding every element of input
// converted to int32. The result is always non-nil and has the same length as
// input.
func sliceIntToInt32(input []int) []int32 {
	converted := make([]int32, len(input))
	for idx, val := range input {
		converted[idx] = int32(val)
	}
	return converted
}
|
2022-02-13 12:50:45 +00:00
|
|
|
|
|
|
|
// sliceInt32ToInt returns a new slice holding every element of input
// converted to int. The result is always non-nil and has the same length as
// input.
func sliceInt32ToInt(input []int32) []int {
	converted := make([]int, len(input))
	for idx, val := range input {
		converted[idx] = int(val)
	}
	return converted
}
|