1207 lines
29 KiB
Go
1207 lines
29 KiB
Go
package ts
|
|
|
|
// JIT interface for running models trained and saved with the PyTorch Python API.
|
|
|
|
// #include "stdlib.h"
|
|
import "C"
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"reflect"
|
|
"unsafe"
|
|
|
|
"git.andr3h3nriqu3s.com/andr3/gotch"
|
|
lib "git.andr3h3nriqu3s.com/andr3/gotch/libtch"
|
|
)
|
|
|
|
type CIValue struct {
|
|
civalue lib.Civalue
|
|
}
|
|
|
|
// IValueKind enumerates the categories of values an IValue can carry.
type IValueKind int

// The kinds are numbered consecutively from zero in declaration order.
// The trailing comment on each entry names the Go type associated with it.
const (
	NoneVal        IValueKind = iota
	TensorVal                 // *Tensor
	DoubleVal                 // float64
	IntVal                    // int64
	BoolVal                   // bool
	TupleVal                  // []*IValue
	IntListVal                // []int64
	DoubleListVal             // []float64
	BoolListVal               // []bool
	StringVal                 // string
	TensorListVal             // []*Tensor
	GenericListVal            // []*IValue
	GenericDictVal            // map[IValue]IValue - 2 elements
	GenericVal                // *IValue
)
|
|
|
|
type IValue struct {
|
|
value interface{}
|
|
kind IValueKind
|
|
name string
|
|
}
|
|
|
|
// NewIValue creates a new IValue from given value of various types.
|
|
func NewIValue(v interface{}) *IValue {
|
|
retVal := &IValue{value: v}
|
|
if v == nil {
|
|
retVal.kind = NoneVal
|
|
retVal.name = "None"
|
|
return retVal
|
|
}
|
|
|
|
inputTypeStr := reflect.TypeOf(v).Kind().String()
|
|
|
|
switch inputTypeStr {
|
|
case "*Tensor":
|
|
retVal.kind = TensorVal
|
|
retVal.name = "Tensor"
|
|
case "float64":
|
|
retVal.kind = DoubleVal
|
|
retVal.name = "Double"
|
|
case "float32":
|
|
retVal.kind = GenericVal
|
|
retVal.name = "Generic"
|
|
case "int64":
|
|
retVal.kind = IntVal
|
|
retVal.name = "Int"
|
|
case "int":
|
|
retVal.kind = GenericVal
|
|
retVal.name = "Generic"
|
|
case "int32":
|
|
retVal.kind = GenericVal
|
|
retVal.name = "Generic"
|
|
case "bool":
|
|
retVal.kind = BoolVal
|
|
retVal.name = "Bool"
|
|
case "string":
|
|
retVal.kind = StringVal
|
|
retVal.name = "String"
|
|
case "slice":
|
|
switch reflect.TypeOf(v).Elem().Kind().String() {
|
|
case "int64":
|
|
retVal.kind = IntListVal
|
|
retVal.name = "IntList"
|
|
case "float64":
|
|
retVal.kind = DoubleListVal
|
|
retVal.name = "DoubleList"
|
|
case "float32":
|
|
retVal.kind = GenericListVal
|
|
retVal.name = "GenericList"
|
|
case "int32":
|
|
retVal.kind = GenericListVal
|
|
retVal.name = "GenericList"
|
|
case "int":
|
|
retVal.kind = GenericListVal
|
|
retVal.name = "GenericList"
|
|
case "string":
|
|
retVal.kind = GenericListVal
|
|
retVal.name = "GenericList"
|
|
case "bool":
|
|
retVal.kind = BoolListVal
|
|
retVal.name = "BoolList"
|
|
case "ptr": // NOTE: only supported `*Tensor` type
|
|
val := reflect.Indirect(reflect.ValueOf(v))
|
|
switch {
|
|
// 1. Tuple (*Tensor, *Tensor)
|
|
case val.Type().String() == "[]*ts.Tensor" && val.Len() == 2:
|
|
retVal.kind = TensorListVal
|
|
retVal.name = "Tuple"
|
|
retVal.value = v.([]*Tensor)
|
|
|
|
// 2. List (*Tensor, *Tensor, ...)
|
|
case val.Type().String() == "[]*ts.Tensor" && val.Len() > 2:
|
|
retVal.kind = TensorListVal
|
|
retVal.name = "TensorList"
|
|
retVal.value = v.([]*Tensor)
|
|
case val.Type().String() == "[]*ts.IValue" && val.Len() == 2:
|
|
retVal.kind = TupleVal
|
|
retVal.name = "Tuple"
|
|
retVal.value = v.([]*IValue)
|
|
case val.Type().String() == "[]*ts.IValue" && val.Len() > 2, val.Type().String() == "[]*ts.IValue" && val.Len() == 1:
|
|
retVal.kind = GenericListVal
|
|
retVal.name = "GenericList"
|
|
default:
|
|
log.Fatalf("NewIValue method call - 'slice -> struct' case - Unsupported type (%v)\n", val.Type().String())
|
|
}
|
|
}
|
|
case "map":
|
|
// TODO: exclude map of type other than IValue type
|
|
retVal.kind = GenericDictVal
|
|
retVal.name = "GenericDict"
|
|
case "ptr":
|
|
val := reflect.Indirect(reflect.ValueOf(v))
|
|
fieldName := val.Type().Field(2).Name
|
|
switch fieldName {
|
|
case "ctensor":
|
|
retVal.kind = TensorVal
|
|
retVal.name = "Tensor"
|
|
default:
|
|
log.Fatalf("NewIValue method call - 'struct' case - Unsupported type (%v)\n", reflect.TypeOf(v).Kind().String())
|
|
}
|
|
default:
|
|
log.Fatalf("NewIValue method call - Unsupported type (%v)\n", reflect.TypeOf(v).Kind().String())
|
|
}
|
|
|
|
return retVal
|
|
}
|
|
|
|
// IValue methods:
|
|
// ===============
|
|
|
|
func (iv *IValue) ToCIValue() (*CIValue, error) {
|
|
switch iv.name {
|
|
case "None":
|
|
cval := lib.AtiNone()
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &CIValue{civalue: cval}, nil
|
|
|
|
case "Tensor":
|
|
cval := lib.AtiTensor(iv.value.(*Tensor).ctensor)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &CIValue{civalue: cval}, nil
|
|
|
|
case "Int":
|
|
cval := lib.AtiInt(iv.value.(int64))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &CIValue{civalue: cval}, nil
|
|
|
|
case "Double":
|
|
cval := lib.AtiDouble(iv.value.(float64))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &CIValue{civalue: cval}, nil
|
|
|
|
case "Bool":
|
|
cval := lib.AtiBool(iv.value.(bool))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &CIValue{civalue: cval}, nil
|
|
|
|
case "Tuple":
|
|
val := reflect.Indirect(reflect.ValueOf(iv.value))
|
|
switch {
|
|
// 1. Tuple is (*Tensor, *Tensor)
|
|
case val.Type() == reflect.TypeOf([]Tensor{}):
|
|
var v []Tensor = iv.value.([]Tensor)
|
|
var cvals []lib.Civalue
|
|
for _, tensor := range v {
|
|
ival := NewIValue(tensor)
|
|
cval, err := ival.ToCIValue()
|
|
if err != nil {
|
|
err = fmt.Errorf("ToCIValue method call err - Tuple case: %v\n", err)
|
|
return nil, err
|
|
}
|
|
cvals = append(cvals, cval.civalue)
|
|
}
|
|
|
|
tuple := lib.AtiTuple(cvals, len(cvals))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &CIValue{civalue: tuple}, nil
|
|
|
|
// 2. Tuple is (*IValue, *IValue)
|
|
default:
|
|
var v []*IValue = iv.value.([]*IValue)
|
|
var cvals []lib.Civalue
|
|
for _, i := range v {
|
|
cval, err := i.ToCIValue()
|
|
if err != nil {
|
|
err = fmt.Errorf("ToCIValue method call err - Tuple case: %v\n", err)
|
|
return nil, err
|
|
}
|
|
cvals = append(cvals, cval.civalue)
|
|
}
|
|
|
|
tuple := lib.AtiTuple(cvals, len(cvals))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &CIValue{civalue: tuple}, nil
|
|
}
|
|
|
|
case "GenericList":
|
|
// GenericList can be: string, int, int32, float32
|
|
// TODO: refactor to function
|
|
// NOTE: atm, `GenericList` are all unsupported cases
|
|
var cvals []lib.Civalue
|
|
vtyp := reflect.TypeOf(iv.value).Elem().Kind().String()
|
|
switch vtyp {
|
|
case "string":
|
|
var v []string = iv.value.([]string)
|
|
for _, i := range v {
|
|
ival := NewIValue(i)
|
|
cval, err := ival.ToCIValue()
|
|
if err != nil {
|
|
err = fmt.Errorf("ToCIValue method call err - GenericList case: %v\n", err)
|
|
return nil, err
|
|
}
|
|
cvals = append(cvals, cval.civalue)
|
|
}
|
|
|
|
case "int":
|
|
var v []int = iv.value.([]int)
|
|
for _, i := range v {
|
|
ival := NewIValue(i)
|
|
cval, err := ival.ToCIValue()
|
|
if err != nil {
|
|
log.Fatalf("ToCIValue method call err - int case: %v\n", err)
|
|
}
|
|
cvals = append(cvals, cval.civalue)
|
|
}
|
|
case "int32":
|
|
var v []int32 = iv.value.([]int32)
|
|
for _, i := range v {
|
|
ival := NewIValue(i)
|
|
cval, err := ival.ToCIValue()
|
|
if err != nil {
|
|
log.Fatalf("ToCIValue method call err - int32 case: %v\n", err)
|
|
}
|
|
cvals = append(cvals, cval.civalue)
|
|
}
|
|
case "float32":
|
|
var v []float32 = iv.value.([]float32)
|
|
for _, i := range v {
|
|
ival := NewIValue(i)
|
|
cval, err := ival.ToCIValue()
|
|
if err != nil {
|
|
log.Fatalf("ToCIValue method call err - float32 case: %v\n", err)
|
|
}
|
|
cvals = append(cvals, cval.civalue)
|
|
}
|
|
default:
|
|
log.Fatalf("ToCIValue method call err - Default case: Unsupport type (%v)\n", vtyp)
|
|
|
|
}
|
|
|
|
list := lib.AtiGenericList(cvals, len(cvals))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &CIValue{civalue: list}, nil
|
|
|
|
case "IntList":
|
|
var vals []int64 = iv.value.([]int64)
|
|
cval := lib.AtiIntList(vals, len(vals))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &CIValue{civalue: cval}, nil
|
|
|
|
case "DoubleList":
|
|
var vals []float64 = iv.value.([]float64)
|
|
cval := lib.AtiDoubleList(vals, len(vals))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &CIValue{civalue: cval}, nil
|
|
|
|
case "BoolList":
|
|
var vals []bool = iv.value.([]bool)
|
|
cval := lib.AtiBoolList(vals, len(vals))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &CIValue{civalue: cval}, nil
|
|
|
|
case "TensorList":
|
|
var vals []*Tensor = iv.value.([]*Tensor)
|
|
var cvals []lib.Ctensor
|
|
for _, i := range vals {
|
|
cvals = append(cvals, i.ctensor)
|
|
}
|
|
list := lib.AtiTensorList(cvals, len(cvals))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &CIValue{civalue: list}, nil
|
|
|
|
case "String":
|
|
cval := lib.AtiString(iv.value.(string))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &CIValue{civalue: cval}, nil
|
|
|
|
case "GenericDict":
|
|
var cvals []lib.Civalue
|
|
keyType := reflect.TypeOf(iv.value).Key().Kind().String()
|
|
valType := reflect.TypeOf(iv.value).Elem().Kind().String()
|
|
|
|
// 1. Create key and value lists seperately
|
|
switch {
|
|
case keyType == "int64" && valType == "int64":
|
|
var m map[int64]int64 = iv.value.(map[int64]int64)
|
|
var vals []int64
|
|
for k, v := range m {
|
|
vals = append(vals, k, v)
|
|
}
|
|
for _, v := range vals {
|
|
ival := NewIValue(v)
|
|
cval, err := ival.ToCIValue()
|
|
if err != nil {
|
|
log.Fatalf("ToCIValue method call err - GenericDict case: %v\n", err)
|
|
}
|
|
cvals = append(cvals, cval.civalue)
|
|
}
|
|
|
|
case keyType == "float64" && valType == "float64":
|
|
var m map[float64]float64 = iv.value.(map[float64]float64)
|
|
var vals []float64
|
|
for k, v := range m {
|
|
vals = append(vals, k, v)
|
|
}
|
|
for _, v := range vals {
|
|
ival := NewIValue(v)
|
|
cval, err := ival.ToCIValue()
|
|
if err != nil {
|
|
log.Fatalf("ToCIValue method call err - GenericDict case: %v\n", err)
|
|
}
|
|
cvals = append(cvals, cval.civalue)
|
|
}
|
|
|
|
case keyType == "float32" && valType == "float32":
|
|
var m map[float32]float32 = iv.value.(map[float32]float32)
|
|
var vals []float32
|
|
for k, v := range m {
|
|
vals = append(vals, k, v)
|
|
}
|
|
for _, v := range vals {
|
|
ival := NewIValue(v)
|
|
cval, err := ival.ToCIValue()
|
|
if err != nil {
|
|
log.Fatalf("ToCIValue method call err - GenericDict case: %v\n", err)
|
|
}
|
|
cvals = append(cvals, cval.civalue)
|
|
}
|
|
|
|
// TODO: map[int64]Tensor
|
|
// TODO: map[float64]Tensor
|
|
// TODO: map[string]Tensor
|
|
// TODO: map[bool]Tensor
|
|
// ...
|
|
|
|
default:
|
|
log.Fatalf("ToCIValue method call - GenericDict case: unsupported key type(%v) or value type(%v) \n", keyType, valType)
|
|
}
|
|
|
|
// 2. Pairing key and value in a slice (cvals)
|
|
dict := lib.AtiGenericDict(cvals, len(cvals)/2)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &CIValue{civalue: dict}, nil
|
|
|
|
case "Generic":
|
|
err := fmt.Errorf("ToCIValue method call - Generic case: unsupport type(%v)\n", reflect.TypeOf(iv.value).Kind().String())
|
|
return nil, err
|
|
|
|
default:
|
|
err := fmt.Errorf("ToCIValue method call - Generic case: unsupport type(%v)\n", reflect.TypeOf(iv.value).Kind().String())
|
|
return nil, err
|
|
}
|
|
|
|
panic("Shouldn't reached here.")
|
|
}
|
|
|
|
// IValueFromC returns an IValue from given CIValue.
|
|
//
|
|
// It consumes the pointer and frees the associated memory.
|
|
func IValueFromC(cval *CIValue) (*IValue, error) {
|
|
// tag will be a value of int32
|
|
tag := lib.AtiTag(cval.civalue)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch tag {
|
|
case 0:
|
|
return &IValue{
|
|
value: nil,
|
|
kind: NoneVal,
|
|
name: "None",
|
|
}, nil
|
|
case 1:
|
|
tensor := lib.AtiToTensor(cval.civalue)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &IValue{
|
|
value: newTensor(tensor),
|
|
kind: TensorVal,
|
|
name: "Tensor",
|
|
}, nil
|
|
case 2:
|
|
v := lib.AtiToDouble(cval.civalue)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &IValue{
|
|
value: v,
|
|
kind: DoubleVal,
|
|
name: "Double",
|
|
}, nil
|
|
case 3:
|
|
v := lib.AtiToInt(cval.civalue)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &IValue{
|
|
value: v,
|
|
kind: IntVal,
|
|
name: "Int",
|
|
}, nil
|
|
|
|
case 4:
|
|
v := lib.AtiToBool(cval.civalue)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &IValue{
|
|
value: v,
|
|
kind: BoolVal,
|
|
name: "Bool",
|
|
}, nil
|
|
|
|
case 5: // Tuple []IValue 2 elements
|
|
// 1. Determine tuple length
|
|
len := lib.AtiTupleLength(cval.civalue)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
// 2. Call with first pointer and length
|
|
ptr1 := (*lib.Civalue)(unsafe.Pointer(C.malloc(0)))
|
|
lib.AtiToTuple(cval.civalue, ptr1, int(len))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 3. Get list of Civalue tuple elements
|
|
var civalues []CIValue
|
|
civalues = append(civalues, CIValue{civalue: *ptr1})
|
|
currPtr := ptr1
|
|
for i := 1; i < int(len); i++ {
|
|
nextPtr := (*lib.Civalue)(unsafe.Pointer(uintptr(unsafe.Pointer(currPtr)) + unsafe.Sizeof(ptr1)))
|
|
civalues = append(civalues, CIValue{civalue: *nextPtr})
|
|
currPtr = nextPtr
|
|
}
|
|
|
|
// 4. Get Ivalue from Civalue for each tuple element
|
|
// Determine element kind
|
|
v, err := IValueFromC(&civalues[0])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
elemName := v.Name()
|
|
switch elemName {
|
|
case "Tensor":
|
|
var vals []*Tensor
|
|
for _, civalue := range civalues {
|
|
v, err := IValueFromC(&civalue)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
vals = append(vals, v.Value().(*Tensor))
|
|
}
|
|
if len == 2 {
|
|
return &IValue{
|
|
value: vals,
|
|
kind: TensorListVal,
|
|
name: "Tuple",
|
|
}, nil
|
|
} else {
|
|
return &IValue{
|
|
value: vals,
|
|
kind: TensorListVal,
|
|
name: "TensorList",
|
|
}, nil
|
|
}
|
|
case "IntList":
|
|
var vals []int64
|
|
for _, civalue := range civalues {
|
|
v, err := IValueFromC(&civalue)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
vals = append(vals, v.Value().(int64))
|
|
}
|
|
return &IValue{
|
|
value: vals,
|
|
kind: IntListVal,
|
|
name: "IntList",
|
|
}, nil
|
|
case "BoolList":
|
|
var vals []bool
|
|
for _, civalue := range civalues {
|
|
v, err := IValueFromC(&civalue)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
vals = append(vals, v.Value().(bool))
|
|
}
|
|
return &IValue{
|
|
value: vals,
|
|
kind: BoolListVal,
|
|
name: "BoolList",
|
|
}, nil
|
|
case "DoubleList":
|
|
var vals []float64
|
|
for _, civalue := range civalues {
|
|
v, err := IValueFromC(&civalue)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
vals = append(vals, v.Value().(float64))
|
|
}
|
|
return &IValue{
|
|
value: vals,
|
|
kind: DoubleListVal,
|
|
name: "DoubleList",
|
|
}, nil
|
|
|
|
default:
|
|
var vals []interface{}
|
|
for _, civalue := range civalues {
|
|
v, err := IValueFromC(&civalue)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
vals = append(vals, v)
|
|
}
|
|
|
|
return &IValue{
|
|
value: vals,
|
|
kind: TupleVal,
|
|
name: "Tuple",
|
|
}, nil
|
|
}
|
|
|
|
case 6: // IntList
|
|
// 1. Len
|
|
len := lib.AtiLength(cval.civalue)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 2. Call
|
|
ptr1 := unsafe.Pointer(C.malloc(0))
|
|
lib.AtiToIntList(cval.civalue, ptr1, int(len))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 3. Get int list
|
|
var intVals []int64
|
|
intVals = append(intVals, *(*int64)(unsafe.Pointer(ptr1)))
|
|
currPtr := ptr1
|
|
for i := 1; i < int(len); i++ {
|
|
nextPtr := unsafe.Pointer(uintptr(unsafe.Pointer(currPtr)) + unsafe.Sizeof(ptr1))
|
|
intVals = append(intVals, *(*int64)(unsafe.Pointer(nextPtr)))
|
|
currPtr = nextPtr
|
|
}
|
|
|
|
return &IValue{
|
|
value: intVals,
|
|
kind: IntListVal,
|
|
name: "IntList",
|
|
}, nil
|
|
|
|
case 7: // DoubleList
|
|
// 1. Len
|
|
len := lib.AtiLength(cval.civalue)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 2. Call
|
|
ptr1 := unsafe.Pointer(C.malloc(0))
|
|
lib.AtiToDoubleList(cval.civalue, ptr1, int(len))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 3. Get int list
|
|
var floatVals []float64
|
|
floatVals = append(floatVals, *(*float64)(unsafe.Pointer(ptr1)))
|
|
currPtr := ptr1
|
|
for i := 1; i < int(len); i++ {
|
|
nextPtr := unsafe.Pointer(uintptr(unsafe.Pointer(currPtr)) + unsafe.Sizeof(ptr1))
|
|
floatVals = append(floatVals, *(*float64)(unsafe.Pointer(nextPtr)))
|
|
currPtr = nextPtr
|
|
}
|
|
|
|
return &IValue{
|
|
value: floatVals,
|
|
kind: DoubleListVal,
|
|
name: "DoubleList",
|
|
}, nil
|
|
|
|
case 8: // BoolList
|
|
// 1. Len
|
|
len := lib.AtiLength(cval.civalue)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 2. Call
|
|
ptr1 := unsafe.Pointer(C.malloc(0))
|
|
lib.AtiToBoolList(cval.civalue, ptr1, int(len))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 3. Get values
|
|
var vals []int32
|
|
var bvals []bool
|
|
vals = append(vals, *(*int32)(unsafe.Pointer(ptr1)))
|
|
currPtr := ptr1
|
|
for i := 1; i < int(len); i++ {
|
|
nextPtr := unsafe.Pointer(uintptr(unsafe.Pointer(currPtr)) + unsafe.Sizeof(ptr1))
|
|
vals = append(vals, *(*int32)(unsafe.Pointer(nextPtr)))
|
|
currPtr = nextPtr
|
|
}
|
|
|
|
for _, i := range vals {
|
|
bval := false
|
|
if i == 1 {
|
|
bval = true
|
|
}
|
|
bvals = append(bvals, bval)
|
|
}
|
|
|
|
return &IValue{
|
|
value: bvals,
|
|
kind: BoolListVal,
|
|
name: "BoolList",
|
|
}, nil
|
|
|
|
case 9: // String
|
|
v := lib.AtiToString(cval.civalue)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &IValue{
|
|
value: v,
|
|
kind: StringVal,
|
|
name: "String",
|
|
}, nil
|
|
|
|
case 10: // TensorList
|
|
// 1. Len
|
|
len := lib.AtiLength(cval.civalue)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 2. Call
|
|
ptr1 := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
lib.AtiToTensorList(cval.civalue, ptr1, int(len))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 3. Get values
|
|
var tensors []*Tensor
|
|
tensors = append(tensors, newTensor(*ptr1))
|
|
currPtr := ptr1
|
|
for i := 1; i < int(len); i++ {
|
|
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currPtr)) + unsafe.Sizeof(ptr1)))
|
|
tensors = append(tensors, newTensor(*nextPtr))
|
|
currPtr = nextPtr
|
|
}
|
|
|
|
return &IValue{
|
|
value: tensors,
|
|
kind: TensorListVal,
|
|
name: "TensorList",
|
|
}, nil
|
|
|
|
case 12: // GenericList []IValue
|
|
// NOTE: atm, all these cases are unsupported.
|
|
// 1. Len
|
|
len := lib.AtiLength(cval.civalue)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
// 2. Call with first pointer and length
|
|
ptr1 := (*lib.Civalue)(unsafe.Pointer(C.malloc(0)))
|
|
lib.AtiToGenericList(cval.civalue, ptr1, int(len))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 3. Get values
|
|
var civalues []CIValue
|
|
civalues = append(civalues, CIValue{civalue: *ptr1})
|
|
currPtr := ptr1
|
|
for i := 1; i < int(len); i++ {
|
|
nextPtr := (*lib.Civalue)(unsafe.Pointer(uintptr(unsafe.Pointer(currPtr)) + unsafe.Sizeof(ptr1)))
|
|
civalues = append(civalues, CIValue{civalue: *nextPtr})
|
|
currPtr = nextPtr
|
|
}
|
|
|
|
// 4. Get Ivalue from Civalue for each tuple element
|
|
var vals []interface{}
|
|
var itemTyp string
|
|
for _, civalue := range civalues {
|
|
v, err := IValueFromC(&civalue)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
itemTyp = reflect.TypeOf(v.value).Kind().String()
|
|
vals = append(vals, v.value)
|
|
}
|
|
|
|
switch itemTyp {
|
|
case "string":
|
|
var specVals []string
|
|
for _, v := range vals {
|
|
specVals = append(specVals, v.(string))
|
|
}
|
|
return &IValue{
|
|
value: specVals,
|
|
kind: GenericListVal,
|
|
name: "GenericList",
|
|
}, nil
|
|
case "int":
|
|
var specVals []int
|
|
for _, v := range vals {
|
|
specVals = append(specVals, v.(int))
|
|
}
|
|
return &IValue{
|
|
value: specVals,
|
|
kind: GenericListVal,
|
|
name: "GenericList",
|
|
}, nil
|
|
case "int32":
|
|
var specVals []int32
|
|
for _, v := range vals {
|
|
specVals = append(specVals, v.(int32))
|
|
}
|
|
return &IValue{
|
|
value: vals,
|
|
kind: GenericListVal,
|
|
name: "GenericList",
|
|
}, nil
|
|
case "float32":
|
|
var specVals []float32
|
|
for _, v := range vals {
|
|
specVals = append(specVals, v.(float32))
|
|
}
|
|
return &IValue{
|
|
value: vals,
|
|
kind: GenericListVal,
|
|
name: "GenericList",
|
|
}, nil
|
|
|
|
default:
|
|
log.Fatalf("IValueFromC method call - GenericList case: Unsupported item type (%v)\n", itemTyp)
|
|
}
|
|
|
|
case 13: // GenericDict map[IValue]IValue
|
|
// 1. Len
|
|
numVals := lib.AtiLength(cval.civalue)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
// 2. Call with first pointer and length
|
|
ptr1 := (*lib.Civalue)(unsafe.Pointer(C.malloc(0)))
|
|
lib.AtiToGenericDict(cval.civalue, ptr1, int(numVals))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 3. Get values
|
|
|
|
// TODO: Need to drill down a specific type
|
|
var civalues []CIValue
|
|
civalues = append(civalues, CIValue{civalue: *ptr1})
|
|
currPtr := ptr1
|
|
for i := 1; i < int(numVals)*2; i++ {
|
|
nextPtr := (*lib.Civalue)(unsafe.Pointer(uintptr(unsafe.Pointer(currPtr)) + unsafe.Sizeof(ptr1)))
|
|
civalues = append(civalues, CIValue{civalue: *nextPtr})
|
|
currPtr = nextPtr
|
|
}
|
|
|
|
// 4. Get Ivalue from Civalue for each element
|
|
var vals []interface{}
|
|
var itemTyp string
|
|
for _, civalue := range civalues {
|
|
v, err := IValueFromC(&civalue)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
itemTyp = reflect.TypeOf(v.value).Kind().String()
|
|
vals = append(vals, v.value)
|
|
}
|
|
|
|
switch itemTyp {
|
|
case "string":
|
|
var specVals map[string]string = make(map[string]string)
|
|
for i := 0; i < len(vals); i += 2 {
|
|
specVals[vals[i].(string)] = vals[i+1].(string)
|
|
}
|
|
return &IValue{
|
|
value: specVals,
|
|
kind: GenericDictVal,
|
|
name: "GenericDict",
|
|
}, nil
|
|
case "int":
|
|
var specVals map[int]int = make(map[int]int)
|
|
for i := 0; i < len(vals); i += 2 {
|
|
specVals[vals[i].(int)] = vals[i+1].(int)
|
|
}
|
|
return &IValue{
|
|
value: specVals,
|
|
kind: GenericDictVal,
|
|
name: "GenericDict",
|
|
}, nil
|
|
case "int32":
|
|
var specVals map[int32]int32 = make(map[int32]int32)
|
|
for i := 0; i < len(vals); i += 2 {
|
|
specVals[vals[i].(int32)] = vals[i+1].(int32)
|
|
}
|
|
return &IValue{
|
|
value: specVals,
|
|
kind: GenericDictVal,
|
|
name: "GenericDict",
|
|
}, nil
|
|
case "int64":
|
|
var specVals map[int64]int64 = make(map[int64]int64)
|
|
for i := 0; i < len(vals); i += 2 {
|
|
specVals[vals[i].(int64)] = vals[i+1].(int64)
|
|
}
|
|
return &IValue{
|
|
value: specVals,
|
|
kind: GenericDictVal,
|
|
name: "GenericDict",
|
|
}, nil
|
|
case "float32":
|
|
var specVals map[float32]float32 = make(map[float32]float32)
|
|
for i := 0; i < len(vals); i += 2 {
|
|
specVals[vals[i].(float32)] = vals[i+1].(float32)
|
|
}
|
|
return &IValue{
|
|
value: specVals,
|
|
kind: GenericDictVal,
|
|
name: "GenericDict",
|
|
}, nil
|
|
case "float64":
|
|
var specVals map[float64]float64 = make(map[float64]float64)
|
|
for i := 0; i < len(vals); i += 2 {
|
|
specVals[vals[i].(float64)] = vals[i+1].(float64)
|
|
}
|
|
return &IValue{
|
|
value: specVals,
|
|
kind: GenericDictVal,
|
|
name: "GenericDict",
|
|
}, nil
|
|
}
|
|
|
|
default:
|
|
err := fmt.Errorf("IValueFromC - Unsupported type (tag value: %v)\n", tag)
|
|
return nil, err
|
|
}
|
|
|
|
panic("Shouldn't reach here.")
|
|
}
|
|
|
|
func (iv *IValue) Value() interface{} {
|
|
return iv.value
|
|
}
|
|
|
|
func (iv *IValue) Name() string {
|
|
return iv.name
|
|
}
|
|
|
|
func (iv *IValue) Kind() IValueKind {
|
|
return iv.kind
|
|
}
|
|
|
|
// A JIT PyTorch module.
|
|
//
|
|
// These modules can be created via the
|
|
// [TorchScript python api](https://pytorch.org/docs/stable/jit.html).
|
|
type CModule struct {
|
|
Cmodule lib.Cmodule
|
|
}
|
|
|
|
func (cm *CModule) Drop() {
|
|
lib.AtmFree(cm.Cmodule)
|
|
if err := TorchErr(); err != nil {
|
|
log.Fatalf("CModule Drop method err: %v\n", err)
|
|
}
|
|
}
|
|
|
|
// Loads a PyTorch saved JIT model from a file.
|
|
func ModuleLoad(path string) (*CModule, error) {
|
|
cmodule := lib.AtmLoad(path)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &CModule{cmodule}, nil
|
|
|
|
}
|
|
|
|
// Loads a PyTorch saved JIT model from a file onto the given device.
|
|
//
|
|
// This function loads the model directly on the specified device,
|
|
// which means it also allows loading a GPU model on the CPU without having a CUDA enabled GPU.
|
|
func ModuleLoadOnDevice(path string, device gotch.Device) (*CModule, error) {
|
|
cmodule := lib.AtmLoadOnDevice(path, device.CInt())
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &CModule{cmodule}, nil
|
|
}
|
|
|
|
// Loads a PyTorch saved JIT model from a read instance.
|
|
func ModuleLoadData(stream io.Reader) (*CModule, error) {
|
|
|
|
buf := new(bytes.Buffer)
|
|
buf.ReadFrom(stream)
|
|
|
|
bufString := buf.String()
|
|
|
|
cmodule := lib.AtmLoadStr(bufString, len(bufString))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &CModule{cmodule}, nil
|
|
|
|
}
|
|
|
|
// Loads a PyTorch saved JIT model from a read instance.
|
|
//
|
|
// This function loads the model directly on the specified device,
|
|
// which means it also allows loading a GPU model on the CPU without having a CUDA enabled GPU.
|
|
func ModuleLoadDataOnDevice(stream io.Reader, device gotch.Device) (*CModule, error) {
|
|
buf := new(bytes.Buffer)
|
|
buf.ReadFrom(stream)
|
|
|
|
bufString := buf.String()
|
|
|
|
cmodule := lib.AtmLoadStrOnDevice(bufString, len(bufString), device.CInt())
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &CModule{cmodule}, nil
|
|
}
|
|
|
|
// ForwardTs performs the forward pass for a model on some specified tensor inputs.
|
|
func (cm *CModule) ForwardTs(tensors []*Tensor) (*Tensor, error) {
|
|
var ctensors []lib.Ctensor
|
|
for _, t := range tensors {
|
|
ctensors = append(ctensors, t.ctensor)
|
|
}
|
|
|
|
// NOTE: Write a slice of ctensors to C memory and get the pointer
|
|
// 1. Calculate buffer size
|
|
cptrSize := int(unsafe.Sizeof(ctensors[0])) // 8 bytes
|
|
nbytes := cptrSize * len(ctensors)
|
|
dataPtr := C.malloc(C.size_t(nbytes))
|
|
defer C.free(dataPtr)
|
|
dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes]
|
|
|
|
// 2. Convert C pointers to []byte
|
|
var data []byte
|
|
for _, ctensor := range ctensors {
|
|
b := make([]byte, cptrSize)
|
|
u := uintptr(unsafe.Pointer(ctensor))
|
|
switch cptrSize {
|
|
case 4:
|
|
binary.LittleEndian.PutUint32(b, uint32(u))
|
|
case 8:
|
|
binary.LittleEndian.PutUint64(b, uint64(u))
|
|
default:
|
|
panic(fmt.Sprintf("unknown uintptr size: %v", cptrSize))
|
|
}
|
|
|
|
data = append(data, b...)
|
|
}
|
|
|
|
// 3. Copy data to buffer
|
|
copy(dataSlice[:], data)
|
|
|
|
// 4. Call C func with slice data pointer and number of ctensor pointers
|
|
// NOTE:
|
|
// - `dataPtr` is the pointer to slice of ctensor pointers
|
|
// - `nsize` is number of ctensor pointers encoded in binary data.
|
|
ctensorsPtr := (*lib.Ctensor)(dataPtr)
|
|
ctensor := lib.AtmForward(cm.Cmodule, ctensorsPtr, len(ctensors))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return newTensor(ctensor), nil
|
|
}
|
|
|
|
// ForwardIs performs the forward pass for a model on some specified ivalue input.
|
|
func (cm *CModule) ForwardIs(ivalues []*IValue) (*IValue, error) {
|
|
|
|
var civalues []lib.Civalue
|
|
for _, i := range ivalues {
|
|
civalue, err := i.ToCIValue()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
civalues = append(civalues, civalue.civalue)
|
|
}
|
|
|
|
// NOTE: Write a slice of civalues to C memory and get the pointer
|
|
// 1. Calculate buffer size
|
|
cptrSize := int(unsafe.Sizeof(civalues[0])) // 8 bytes
|
|
nbytes := cptrSize * len(civalues)
|
|
dataPtr := C.malloc(C.size_t(nbytes))
|
|
defer C.free(dataPtr)
|
|
dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes]
|
|
|
|
// 2. Convert C pointers to []byte
|
|
var data []byte
|
|
for _, civalue := range civalues {
|
|
b := make([]byte, cptrSize)
|
|
u := uintptr(unsafe.Pointer(civalue))
|
|
switch cptrSize {
|
|
case 4:
|
|
binary.LittleEndian.PutUint32(b, uint32(u))
|
|
case 8:
|
|
binary.LittleEndian.PutUint64(b, uint64(u))
|
|
default:
|
|
panic(fmt.Sprintf("unknown uintptr size: %v", cptrSize))
|
|
}
|
|
|
|
data = append(data, b...)
|
|
}
|
|
|
|
// 3. Copy data to buffer
|
|
copy(dataSlice[:], data)
|
|
|
|
// 4. Call C func with slice data pointer and number of civalue pointers
|
|
// NOTE:
|
|
// - `dataPtr` is the pointer to slice of civalue pointers
|
|
// - `nsize` is number of civalue pointers encoded in binary data.
|
|
civaluesPtr := (*lib.Civalue)(dataPtr)
|
|
|
|
civ := lib.AtmForward_(cm.Cmodule, civaluesPtr, len(civalues))
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return IValueFromC(&CIValue{civ})
|
|
}
|
|
|
|
// To moves CModule to specified device.
|
|
func (cm *CModule) To(device gotch.Device, kind gotch.DType, nonBlocking bool) {
|
|
lib.AtmTo(cm.Cmodule, device.CInt(), kind.CInt(), nonBlocking)
|
|
if err := TorchErr(); err != nil {
|
|
log.Fatalf("CModule To method call err: %v\n", err)
|
|
}
|
|
}
|
|
|
|
// Save save CModule to a specified path.
|
|
func (cm *CModule) Save(file string) error {
|
|
lib.AtmSave(cm.Cmodule, file)
|
|
return TorchErr()
|
|
}
|
|
|
|
// NamedParameters loads some named tensors from a module.
|
|
func (cm *CModule) NamedParameters() ([]NamedTensor, error) {
|
|
var data lib.LoadData
|
|
dataPtr := lib.PStore.Set(&data)
|
|
lib.AtmNamedParameters(cm.Cmodule, dataPtr)
|
|
if err := TorchErr(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var namedTensors []NamedTensor
|
|
for _, v := range data.NamedCtensors {
|
|
namedTensor := NamedTensor{
|
|
Name: v.Name,
|
|
Tensor: newTensor(v.Ctensor),
|
|
}
|
|
|
|
namedTensors = append(namedTensors, namedTensor)
|
|
}
|
|
|
|
return namedTensors, nil
|
|
}
|
|
|
|
// GetProfilingMode get CModule profiling mode
|
|
func (cm *CModule) GetProfilingMode() bool {
|
|
retVal := lib.AtmGetProfilingMode()
|
|
if err := TorchErr(); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return retVal
|
|
}
|
|
|
|
// SetProfilingMode set CModule profiling mode
|
|
func (cm *CModule) SetProfilingMode(b bool) {
|
|
lib.AtmSetProfilingMode(b)
|
|
if err := TorchErr(); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
// SetTrain set CModule to train mode
|
|
func (cm *CModule) SetTrain() {
|
|
lib.AtmTrain(cm.Cmodule)
|
|
if err := TorchErr(); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
// SetEval set CModule to inference mode
|
|
func (cm *CModule) SetEval() {
|
|
lib.AtmEval(cm.Cmodule)
|
|
if err := TorchErr(); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
// Implement Module for CModule:
|
|
// =============================
|
|
|
|
// Forwad implements Module interface for CModule.
|
|
func (cm *CModule) Forward(tensor *Tensor) (*Tensor, error) {
|
|
|
|
var tensors []*Tensor = []*Tensor{tensor}
|
|
return cm.ForwardTs(tensors)
|
|
}
|
|
|
|
// Tensor methods for CModule:
|
|
// ======================================
|
|
|
|
// Apply forwards tensor itself through a module.
|
|
func (ts *Tensor) ApplyCModule(m *CModule) *Tensor {
|
|
retVal, err := m.Forward(ts)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return retVal
|
|
}
|