changed(tensor/jit): changed to receiver pointer
This commit is contained in:
parent
9fe5cef70a
commit
c7a8c4b348
25
nn/jit.go
Normal file
25
nn/jit.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package nn
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// TrainableCModule is a trainable version of JIT Pytorch module
|
||||
//
|
||||
// These modules can be created via TorchScript python API.
|
||||
// See: https://pytorch.org/docs/stable/jit.html
|
||||
type TrainableCModule struct {
|
||||
Inner ts.CModule
|
||||
}
|
||||
|
||||
// TrainableModuleLoad loads a PyTorch saved JIT module from a file and adds
|
||||
// tensors (weights) to `varstore` so that module can be trained.
|
||||
/*
|
||||
* func TrainableModuleLoad(p *Path, file string) (*TrainableCModule, error) {
|
||||
*
|
||||
* inner, err := ts.ModuleLoadOnDevice(file, p.Device())
|
||||
* if err != nil {
|
||||
* return nil, err
|
||||
* }
|
||||
* }
|
||||
* */
|
|
@ -659,30 +659,3 @@ func (e *Entry) OrZerosNoTrain(dims []int64) *ts.Tensor {
|
|||
z := ts.MustZeros(dims, gotch.Float, e.path.Device())
|
||||
return e.path.getOrAddWithLock(e.name, z, true, *e.variables)
|
||||
}
|
||||
|
||||
// TODO: can we implement `Div` operator in Go?
|
||||
// NOTE: `Rhs` (right hand side) is a generic type parameter
|
||||
// If not given, it will be default to `self` type
|
||||
/*
|
||||
* impl<'a, T> Div<T> for &'a mut Path<'a>
|
||||
* where
|
||||
* T: std::string::ToString,
|
||||
* {
|
||||
* type Output = Path<'a>;
|
||||
*
|
||||
* fn div(self, rhs: T) -> Self::Output {
|
||||
* self.sub(rhs.to_string())
|
||||
* }
|
||||
* }
|
||||
*
|
||||
* impl<'a, T> Div<T> for &'a Path<'a>
|
||||
* where
|
||||
* T: std::string::ToString,
|
||||
* {
|
||||
* type Output = Path<'a>;
|
||||
*
|
||||
* fn div(self, rhs: T) -> Self::Output {
|
||||
* self.sub(rhs.to_string())
|
||||
* }
|
||||
* }
|
||||
* */
|
||||
|
|
368
tensor/jit.go
368
tensor/jit.go
|
@ -50,9 +50,9 @@ type IValue struct {
|
|||
}
|
||||
|
||||
// NewIValue creates a new IValue from given value of various types.
|
||||
func NewIValue(v interface{}) (retVal IValue) {
|
||||
func NewIValue(v interface{}) *IValue {
|
||||
|
||||
retVal = IValue{value: v}
|
||||
retVal := &IValue{value: v}
|
||||
if v == nil {
|
||||
retVal.kind = NoneVal
|
||||
retVal.name = "None"
|
||||
|
@ -132,7 +132,7 @@ func NewIValue(v interface{}) (retVal IValue) {
|
|||
var ivals []IValue
|
||||
for _, tensor := range tensors {
|
||||
ival := NewIValue(tensor)
|
||||
ivals = append(ivals, ival)
|
||||
ivals = append(ivals, *ival)
|
||||
}
|
||||
retVal.value = ivals
|
||||
|
||||
|
@ -167,45 +167,45 @@ func NewIValue(v interface{}) (retVal IValue) {
|
|||
// IValue methods:
|
||||
// ===============
|
||||
|
||||
func (iv IValue) ToCIValue() (retVal CIValue, err error) {
|
||||
func (iv *IValue) ToCIValue() (*CIValue, error) {
|
||||
|
||||
switch iv.name {
|
||||
case "None":
|
||||
cval := lib.AtiNone()
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return CIValue{civalue: cval}, nil
|
||||
return &CIValue{civalue: cval}, nil
|
||||
|
||||
case "Tensor":
|
||||
cval := lib.AtiTensor(iv.value.(Tensor).ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CIValue{civalue: cval}, nil
|
||||
return &CIValue{civalue: cval}, nil
|
||||
|
||||
case "Int":
|
||||
cval := lib.AtiInt(iv.value.(int64))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return CIValue{civalue: cval}, nil
|
||||
return &CIValue{civalue: cval}, nil
|
||||
|
||||
case "Double":
|
||||
cval := lib.AtiDouble(iv.value.(float64))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CIValue{civalue: cval}, nil
|
||||
return &CIValue{civalue: cval}, nil
|
||||
|
||||
case "Bool":
|
||||
cval := lib.AtiBool(iv.value.(bool))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CIValue{civalue: cval}, nil
|
||||
return &CIValue{civalue: cval}, nil
|
||||
|
||||
case "Tuple":
|
||||
val := reflect.Indirect(reflect.ValueOf(iv.value))
|
||||
|
@ -219,16 +219,16 @@ func (iv IValue) ToCIValue() (retVal CIValue, err error) {
|
|||
cval, err := ival.ToCIValue()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("ToCIValue method call err - Tuple case: %v\n", err)
|
||||
return retVal, err
|
||||
return nil, err
|
||||
}
|
||||
cvals = append(cvals, cval.civalue)
|
||||
}
|
||||
|
||||
tuple := lib.AtiTuple(cvals, len(cvals))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CIValue{civalue: tuple}, nil
|
||||
return &CIValue{civalue: tuple}, nil
|
||||
|
||||
// 2. Tuple is (IValue, IValue)
|
||||
default:
|
||||
|
@ -238,16 +238,16 @@ func (iv IValue) ToCIValue() (retVal CIValue, err error) {
|
|||
cval, err := i.ToCIValue()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("ToCIValue method call err - Tuple case: %v\n", err)
|
||||
return retVal, err
|
||||
return nil, err
|
||||
}
|
||||
cvals = append(cvals, cval.civalue)
|
||||
}
|
||||
|
||||
tuple := lib.AtiTuple(cvals, len(cvals))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CIValue{civalue: tuple}, nil
|
||||
return &CIValue{civalue: tuple}, nil
|
||||
}
|
||||
|
||||
case "GenericList":
|
||||
|
@ -264,7 +264,7 @@ func (iv IValue) ToCIValue() (retVal CIValue, err error) {
|
|||
cval, err := ival.ToCIValue()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("ToCIValue method call err - GenericList case: %v\n", err)
|
||||
return retVal, err
|
||||
return nil, err
|
||||
}
|
||||
cvals = append(cvals, cval.civalue)
|
||||
}
|
||||
|
@ -305,34 +305,34 @@ func (iv IValue) ToCIValue() (retVal CIValue, err error) {
|
|||
}
|
||||
|
||||
list := lib.AtiGenericList(cvals, len(cvals))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CIValue{civalue: list}, nil
|
||||
return &CIValue{civalue: list}, nil
|
||||
|
||||
case "IntList":
|
||||
var vals []int64 = iv.value.([]int64)
|
||||
cval := lib.AtiIntList(vals, len(vals))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CIValue{civalue: cval}, nil
|
||||
return &CIValue{civalue: cval}, nil
|
||||
|
||||
case "DoubleList":
|
||||
var vals []float64 = iv.value.([]float64)
|
||||
cval := lib.AtiDoubleList(vals, len(vals))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CIValue{civalue: cval}, nil
|
||||
return &CIValue{civalue: cval}, nil
|
||||
|
||||
case "BoolList":
|
||||
var vals []bool = iv.value.([]bool)
|
||||
cval := lib.AtiBoolList(vals, len(vals))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CIValue{civalue: cval}, nil
|
||||
return &CIValue{civalue: cval}, nil
|
||||
|
||||
case "TensorList":
|
||||
var vals []Tensor = iv.value.([]Tensor)
|
||||
|
@ -341,17 +341,17 @@ func (iv IValue) ToCIValue() (retVal CIValue, err error) {
|
|||
cvals = append(cvals, i.ctensor)
|
||||
}
|
||||
list := lib.AtiTensorList(cvals, len(cvals))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CIValue{civalue: list}, nil
|
||||
return &CIValue{civalue: list}, nil
|
||||
|
||||
case "String":
|
||||
cval := lib.AtiString(iv.value.(string))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CIValue{civalue: cval}, nil
|
||||
return &CIValue{civalue: cval}, nil
|
||||
|
||||
case "GenericDict":
|
||||
var cvals []lib.Civalue
|
||||
|
@ -417,92 +417,93 @@ func (iv IValue) ToCIValue() (retVal CIValue, err error) {
|
|||
|
||||
// 2. Pairing key and value in a slice (cvals)
|
||||
dict := lib.AtiGenericDict(cvals, len(cvals)/2)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CIValue{civalue: dict}, nil
|
||||
return &CIValue{civalue: dict}, nil
|
||||
|
||||
case "Generic":
|
||||
log.Fatalf("ToCIValue method call - Generic case: unsupport type(%v)\n", reflect.TypeOf(iv.value).Kind().String())
|
||||
err := fmt.Errorf("ToCIValue method call - Generic case: unsupport type(%v)\n", reflect.TypeOf(iv.value).Kind().String())
|
||||
return nil, err
|
||||
|
||||
default:
|
||||
log.Fatalf("ToCIValue method call - Generic case: unsupport type(%v)\n", reflect.TypeOf(iv.value).Kind().String())
|
||||
err := fmt.Errorf("ToCIValue method call - Generic case: unsupport type(%v)\n", reflect.TypeOf(iv.value).Kind().String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
panic("Shouldn't reached here.")
|
||||
}
|
||||
|
||||
// IValueFromC returns an IValue from given CIValue.
|
||||
//
|
||||
// It consumes the pointer and frees the associated memory.
|
||||
func IValueFromC(cval CIValue) (retVal IValue, err error) {
|
||||
|
||||
func IValueFromC(cval *CIValue) (*IValue, error) {
|
||||
// tag will be a value of int32
|
||||
tag := lib.AtiTag(cval.civalue)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch tag {
|
||||
case 0:
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: nil,
|
||||
kind: NoneVal,
|
||||
name: "None",
|
||||
}
|
||||
}, nil
|
||||
case 1:
|
||||
tensor := lib.AtiToTensor(cval.civalue)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: tensor,
|
||||
kind: TensorVal,
|
||||
name: "Tensor",
|
||||
}
|
||||
}, nil
|
||||
case 2:
|
||||
v := lib.AtiToDouble(cval.civalue)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: v,
|
||||
kind: DoubleVal,
|
||||
name: "Double",
|
||||
}
|
||||
}, nil
|
||||
case 3:
|
||||
v := lib.AtiToInt(cval.civalue)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: v,
|
||||
kind: IntVal,
|
||||
name: "Int",
|
||||
}
|
||||
}, nil
|
||||
|
||||
case 4:
|
||||
v := lib.AtiToBool(cval.civalue)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: v,
|
||||
kind: BoolVal,
|
||||
name: "Bool",
|
||||
}
|
||||
}, nil
|
||||
|
||||
case 5: // Tuple []IValue 2 elements
|
||||
// 1. Determine tuple length
|
||||
len := lib.AtiTupleLength(cval.civalue)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 2. Call with first pointer and length
|
||||
ptr1 := (*lib.Civalue)(unsafe.Pointer(C.malloc(0)))
|
||||
lib.AtiToTuple(cval.civalue, ptr1, int(len))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. Get list of Civalue tuple elements
|
||||
|
@ -518,31 +519,31 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
|
|||
// 4. Get Ivalue from Civalue for each tuple element
|
||||
var vals []interface{}
|
||||
for _, civalue := range civalues {
|
||||
v, err := IValueFromC(civalue)
|
||||
v, err := IValueFromC(&civalue)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
return nil, err
|
||||
}
|
||||
vals = append(vals, v)
|
||||
}
|
||||
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: vals,
|
||||
kind: TupleVal,
|
||||
name: "Tuple",
|
||||
}
|
||||
}, nil
|
||||
|
||||
case 6: // IntList
|
||||
// 1. Len
|
||||
len := lib.AtiLength(cval.civalue)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. Call
|
||||
ptr1 := unsafe.Pointer(C.malloc(0))
|
||||
lib.AtiToIntList(cval.civalue, ptr1, int(len))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. Get int list
|
||||
|
@ -555,24 +556,24 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
|
|||
currPtr = nextPtr
|
||||
}
|
||||
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: intVals,
|
||||
kind: IntListVal,
|
||||
name: "IntList",
|
||||
}
|
||||
}, nil
|
||||
|
||||
case 7: // DoubleList
|
||||
// 1. Len
|
||||
len := lib.AtiLength(cval.civalue)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. Call
|
||||
ptr1 := unsafe.Pointer(C.malloc(0))
|
||||
lib.AtiToDoubleList(cval.civalue, ptr1, int(len))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. Get int list
|
||||
|
@ -585,24 +586,24 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
|
|||
currPtr = nextPtr
|
||||
}
|
||||
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: floatVals,
|
||||
kind: DoubleListVal,
|
||||
name: "DoubleList",
|
||||
}
|
||||
}, nil
|
||||
|
||||
case 8: // BoolList
|
||||
// 1. Len
|
||||
len := lib.AtiLength(cval.civalue)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. Call
|
||||
ptr1 := unsafe.Pointer(C.malloc(0))
|
||||
lib.AtiToBoolList(cval.civalue, ptr1, int(len))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. Get values
|
||||
|
@ -624,35 +625,35 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
|
|||
bvals = append(bvals, bval)
|
||||
}
|
||||
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: bvals,
|
||||
kind: BoolListVal,
|
||||
name: "BoolList",
|
||||
}
|
||||
}, nil
|
||||
|
||||
case 9: // String
|
||||
v := lib.AtiToString(cval.civalue)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: v,
|
||||
kind: StringVal,
|
||||
name: "String",
|
||||
}
|
||||
}, nil
|
||||
|
||||
case 10: // TensorList
|
||||
// 1. Len
|
||||
len := lib.AtiLength(cval.civalue)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. Call
|
||||
ptr1 := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
lib.AtiToTensorList(cval.civalue, ptr1, int(len))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. Get values
|
||||
|
@ -665,24 +666,24 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
|
|||
currPtr = nextPtr
|
||||
}
|
||||
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: tensors,
|
||||
kind: TensorListVal,
|
||||
name: "TensorList",
|
||||
}
|
||||
}, nil
|
||||
|
||||
case 12: // GenericList []IValue
|
||||
// NOTE: atm, all these cases are unsupported.
|
||||
// 1. Len
|
||||
len := lib.AtiLength(cval.civalue)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 2. Call with first pointer and length
|
||||
ptr1 := (*lib.Civalue)(unsafe.Pointer(C.malloc(0)))
|
||||
lib.AtiToGenericList(cval.civalue, ptr1, int(len))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. Get values
|
||||
|
@ -699,9 +700,9 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
|
|||
var vals []interface{}
|
||||
var itemTyp string
|
||||
for _, civalue := range civalues {
|
||||
v, err := IValueFromC(civalue)
|
||||
v, err := IValueFromC(&civalue)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
return nil, err
|
||||
}
|
||||
itemTyp = reflect.TypeOf(v.value).Kind().String()
|
||||
vals = append(vals, v.value)
|
||||
|
@ -713,42 +714,41 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
|
|||
for _, v := range vals {
|
||||
specVals = append(specVals, v.(string))
|
||||
}
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: specVals,
|
||||
kind: GenericListVal,
|
||||
name: "GenericList",
|
||||
}
|
||||
}, nil
|
||||
case "int":
|
||||
var specVals []int
|
||||
for _, v := range vals {
|
||||
specVals = append(specVals, v.(int))
|
||||
}
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: specVals,
|
||||
kind: GenericListVal,
|
||||
name: "GenericList",
|
||||
}
|
||||
}, nil
|
||||
case "int32":
|
||||
var specVals []int32
|
||||
for _, v := range vals {
|
||||
specVals = append(specVals, v.(int32))
|
||||
}
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: vals,
|
||||
kind: GenericListVal,
|
||||
name: "GenericList",
|
||||
}
|
||||
}, nil
|
||||
case "float32":
|
||||
var specVals []float32
|
||||
for _, v := range vals {
|
||||
specVals = append(specVals, v.(float32))
|
||||
}
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: vals,
|
||||
kind: GenericListVal,
|
||||
name: "GenericList",
|
||||
}
|
||||
return retVal, nil
|
||||
}, nil
|
||||
|
||||
default:
|
||||
log.Fatalf("IValueFromC method call - GenericList case: Unsupported item type (%v)\n", itemTyp)
|
||||
|
@ -757,14 +757,14 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
|
|||
case 13: // GenericDict map[IValue]IValue
|
||||
// 1. Len
|
||||
numVals := lib.AtiLength(cval.civalue)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 2. Call with first pointer and length
|
||||
ptr1 := (*lib.Civalue)(unsafe.Pointer(C.malloc(0)))
|
||||
lib.AtiToGenericDict(cval.civalue, ptr1, int(numVals))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. Get values
|
||||
|
@ -783,9 +783,9 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
|
|||
var vals []interface{}
|
||||
var itemTyp string
|
||||
for _, civalue := range civalues {
|
||||
v, err := IValueFromC(civalue)
|
||||
v, err := IValueFromC(&civalue)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
return nil, err
|
||||
}
|
||||
itemTyp = reflect.TypeOf(v.value).Kind().String()
|
||||
vals = append(vals, v.value)
|
||||
|
@ -797,85 +797,84 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
|
|||
for i := 0; i < len(vals); i += 2 {
|
||||
specVals[vals[i].(string)] = vals[i+1].(string)
|
||||
}
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: specVals,
|
||||
kind: GenericDictVal,
|
||||
name: "GenericDict",
|
||||
}
|
||||
}, nil
|
||||
case "int":
|
||||
var specVals map[int]int = make(map[int]int)
|
||||
for i := 0; i < len(vals); i += 2 {
|
||||
specVals[vals[i].(int)] = vals[i+1].(int)
|
||||
}
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: specVals,
|
||||
kind: GenericDictVal,
|
||||
name: "GenericDict",
|
||||
}
|
||||
}, nil
|
||||
case "int32":
|
||||
var specVals map[int32]int32 = make(map[int32]int32)
|
||||
for i := 0; i < len(vals); i += 2 {
|
||||
specVals[vals[i].(int32)] = vals[i+1].(int32)
|
||||
}
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: specVals,
|
||||
kind: GenericDictVal,
|
||||
name: "GenericDict",
|
||||
}
|
||||
}, nil
|
||||
case "int64":
|
||||
var specVals map[int64]int64 = make(map[int64]int64)
|
||||
for i := 0; i < len(vals); i += 2 {
|
||||
specVals[vals[i].(int64)] = vals[i+1].(int64)
|
||||
}
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: specVals,
|
||||
kind: GenericDictVal,
|
||||
name: "GenericDict",
|
||||
}
|
||||
}, nil
|
||||
case "float32":
|
||||
var specVals map[float32]float32 = make(map[float32]float32)
|
||||
for i := 0; i < len(vals); i += 2 {
|
||||
specVals[vals[i].(float32)] = vals[i+1].(float32)
|
||||
}
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: specVals,
|
||||
kind: GenericDictVal,
|
||||
name: "GenericDict",
|
||||
}
|
||||
return retVal, nil
|
||||
}, nil
|
||||
case "float64":
|
||||
var specVals map[float64]float64 = make(map[float64]float64)
|
||||
for i := 0; i < len(vals); i += 2 {
|
||||
specVals[vals[i].(float64)] = vals[i+1].(float64)
|
||||
}
|
||||
retVal = IValue{
|
||||
return &IValue{
|
||||
value: specVals,
|
||||
kind: GenericDictVal,
|
||||
name: "GenericDict",
|
||||
}
|
||||
return retVal, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
default:
|
||||
log.Fatalf("IValueFromC - Unsupported type (tag value: %v)\n", tag)
|
||||
err := fmt.Errorf("IValueFromC - Unsupported type (tag value: %v)\n", tag)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return
|
||||
panic("Shouldn't reach here.")
|
||||
}
|
||||
|
||||
func (iv IValue) Value() (retVal interface{}) {
|
||||
func (iv *IValue) Value() interface{} {
|
||||
return iv.value
|
||||
}
|
||||
|
||||
func (iv IValue) Name() (retVal string) {
|
||||
func (iv *IValue) Name() string {
|
||||
return iv.name
|
||||
}
|
||||
|
||||
func (iv IValue) Kind() (retVal IValueKind) {
|
||||
func (iv *IValue) Kind() IValueKind {
|
||||
return iv.kind
|
||||
}
|
||||
|
||||
// A jit PyTorch module.
|
||||
// A JIT PyTorch module.
|
||||
//
|
||||
// These modules can be created via the
|
||||
// [TorchScript python api](https://pytorch.org/docs/stable/jit.html).
|
||||
|
@ -883,7 +882,7 @@ type CModule struct {
|
|||
Cmodule lib.Cmodule
|
||||
}
|
||||
|
||||
func (cm CModule) Drop() {
|
||||
func (cm *CModule) Drop() {
|
||||
lib.AtmFree(cm.Cmodule)
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatalf("CModule Drop method err: %v\n", err)
|
||||
|
@ -891,13 +890,13 @@ func (cm CModule) Drop() {
|
|||
}
|
||||
|
||||
// Loads a PyTorch saved JIT model from a file.
|
||||
func ModuleLoad(path string) (retVal CModule, err error) {
|
||||
func ModuleLoad(path string) (*CModule, error) {
|
||||
cmodule := lib.AtmLoad(path)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return CModule{cmodule}, nil
|
||||
return &CModule{cmodule}, nil
|
||||
|
||||
}
|
||||
|
||||
|
@ -905,17 +904,17 @@ func ModuleLoad(path string) (retVal CModule, err error) {
|
|||
//
|
||||
// This function loads the model directly on the specified device,
|
||||
// which means it also allows loading a GPU model on the CPU without having a CUDA enabled GPU.
|
||||
func ModuleLoadOnDevice(path string, device gotch.Device) (retVal CModule, err error) {
|
||||
func ModuleLoadOnDevice(path string, device gotch.Device) (*CModule, error) {
|
||||
cmodule := lib.AtmLoadOnDevice(path, device.CInt())
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return CModule{cmodule}, nil
|
||||
return &CModule{cmodule}, nil
|
||||
}
|
||||
|
||||
// Loads a PyTorch saved JIT model from a read instance.
|
||||
func ModuleLoadData(stream io.Reader) (retVal CModule, err error) {
|
||||
func ModuleLoadData(stream io.Reader) (*CModule, error) {
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
buf.ReadFrom(stream)
|
||||
|
@ -923,11 +922,11 @@ func ModuleLoadData(stream io.Reader) (retVal CModule, err error) {
|
|||
bufString := buf.String()
|
||||
|
||||
cmodule := lib.AtmLoadStr(bufString, len(bufString))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return CModule{cmodule}, nil
|
||||
return &CModule{cmodule}, nil
|
||||
|
||||
}
|
||||
|
||||
|
@ -935,22 +934,22 @@ func ModuleLoadData(stream io.Reader) (retVal CModule, err error) {
|
|||
//
|
||||
// This function loads the model directly on the specified device,
|
||||
// which means it also allows loading a GPU model on the CPU without having a CUDA enabled GPU.
|
||||
func ModuleLoadDataOnDevice(stream io.Reader, device gotch.Device) (retVal CModule, err error) {
|
||||
func ModuleLoadDataOnDevice(stream io.Reader, device gotch.Device) (*CModule, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
buf.ReadFrom(stream)
|
||||
|
||||
bufString := buf.String()
|
||||
|
||||
cmodule := lib.AtmLoadStrOnDevice(bufString, len(bufString), device.CInt())
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return CModule{cmodule}, nil
|
||||
return &CModule{cmodule}, nil
|
||||
}
|
||||
|
||||
// Performs the forward pass for a model on some specified tensor inputs.
|
||||
func (cm CModule) ForwardTs(tensors []Tensor) (retVal *Tensor, err error) {
|
||||
// ForwardTs performs the forward pass for a model on some specified tensor inputs.
|
||||
func (cm *CModule) ForwardTs(tensors []Tensor) (*Tensor, error) {
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
ctensors = append(ctensors, t.ctensor)
|
||||
|
@ -990,21 +989,21 @@ func (cm CModule) ForwardTs(tensors []Tensor) (retVal *Tensor, err error) {
|
|||
// - `nsize` is number of ctensor pointers encoded in binary data.
|
||||
ctensorsPtr := (*lib.Ctensor)(dataPtr)
|
||||
ctensor := lib.AtmForward(cm.Cmodule, ctensorsPtr, len(ctensors))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Tensor{ctensor}, nil
|
||||
}
|
||||
|
||||
// Performs the forward pass for a model on some specified ivalue input.
|
||||
func (cm CModule) ForwardIs(ivalues []IValue) (retVal IValue, err error) {
|
||||
// ForwardIs performs the forward pass for a model on some specified ivalue input.
|
||||
func (cm *CModule) ForwardIs(ivalues []IValue) (*IValue, error) {
|
||||
|
||||
var civalues []lib.Civalue
|
||||
for _, i := range ivalues {
|
||||
civalue, err := i.ToCIValue()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
return nil, err
|
||||
}
|
||||
civalues = append(civalues, civalue.civalue)
|
||||
}
|
||||
|
@ -1044,19 +1043,15 @@ func (cm CModule) ForwardIs(ivalues []IValue) (retVal IValue, err error) {
|
|||
civaluesPtr := (*lib.Civalue)(dataPtr)
|
||||
|
||||
civ := lib.AtmForward_(cm.Cmodule, civaluesPtr, len(civalues))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
retVal, err = IValueFromC(CIValue{civ})
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
return IValueFromC(&CIValue{civ})
|
||||
}
|
||||
|
||||
func (cm CModule) To(device gotch.Device, kind gotch.DType, nonBlocking bool) {
|
||||
// To moves CModule to specified device.
|
||||
func (cm *CModule) To(device gotch.Device, kind gotch.DType, nonBlocking bool) {
|
||||
lib.AtmTo(cm.Cmodule, device.CInt(), kind.CInt(), nonBlocking)
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatalf("CModule To method call err: %v\n", err)
|
||||
|
@ -1066,7 +1061,8 @@ func (cm CModule) To(device gotch.Device, kind gotch.DType, nonBlocking bool) {
|
|||
// Implement Module for CModule:
|
||||
// =============================
|
||||
|
||||
func (cm CModule) Forward(tensor *Tensor) (retVal *Tensor, err error) {
|
||||
// Forwad implements Module interface for CModule.
|
||||
func (cm *CModule) Forward(tensor *Tensor) (*Tensor, error) {
|
||||
|
||||
var tensors []Tensor = []Tensor{*tensor}
|
||||
return cm.ForwardTs(tensors)
|
||||
|
@ -1076,7 +1072,7 @@ func (cm CModule) Forward(tensor *Tensor) (retVal *Tensor, err error) {
|
|||
// ======================================
|
||||
|
||||
// Apply forwards tensor itself through a module.
|
||||
func (ts *Tensor) ApplyCModule(m CModule) (retVal *Tensor) {
|
||||
func (ts *Tensor) ApplyCModule(m *CModule) *Tensor {
|
||||
retVal, err := m.Forward(ts)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
|
|
@ -86,7 +86,7 @@ func TestModuleForwardIValue(t *testing.T) {
|
|||
iv1 := ts.NewIValue(*ts1)
|
||||
iv2 := ts.NewIValue(*ts2)
|
||||
|
||||
got, err := foo.ForwardIs([]ts.IValue{iv1, iv2})
|
||||
got, err := foo.ForwardIs([]ts.IValue{*iv1, *iv2})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user