changed(tensor/jit): changed to receiver pointer

This commit is contained in:
sugarme 2021-01-02 12:35:57 +11:00
parent 9fe5cef70a
commit c7a8c4b348
4 changed files with 208 additions and 214 deletions

25
nn/jit.go Normal file
View File

@ -0,0 +1,25 @@
package nn
import (
ts "github.com/sugarme/gotch/tensor"
)
// TrainableCModule is a trainable version of JIT Pytorch module
//
// These modules can be created via TorchScript python API.
// See: https://pytorch.org/docs/stable/jit.html
type TrainableCModule struct {
Inner ts.CModule
}
// TrainableModuleLoad loads a PyTorch saved JIT module from a file and adds
// tensors (weights) to `varstore` so that module can be trained.
/*
* func TrainableModuleLoad(p *Path, file string) (*TrainableCModule, error) {
*
* inner, err := ts.ModuleLoadOnDevice(file, p.Device())
* if err != nil {
* return nil, err
* }
* }
* */

View File

@ -659,30 +659,3 @@ func (e *Entry) OrZerosNoTrain(dims []int64) *ts.Tensor {
z := ts.MustZeros(dims, gotch.Float, e.path.Device())
return e.path.getOrAddWithLock(e.name, z, true, *e.variables)
}
// TODO: can we implement `Div` operator in Go?
// NOTE: `Rhs` (right hand side) is a generic type parameter
// If not given, it will be default to `self` type
/*
* impl<'a, T> Div<T> for &'a mut Path<'a>
* where
* T: std::string::ToString,
* {
* type Output = Path<'a>;
*
* fn div(self, rhs: T) -> Self::Output {
* self.sub(rhs.to_string())
* }
* }
*
* impl<'a, T> Div<T> for &'a Path<'a>
* where
* T: std::string::ToString,
* {
* type Output = Path<'a>;
*
* fn div(self, rhs: T) -> Self::Output {
* self.sub(rhs.to_string())
* }
* }
* */

View File

@ -50,9 +50,9 @@ type IValue struct {
}
// NewIValue creates a new IValue from given value of various types.
func NewIValue(v interface{}) (retVal IValue) {
func NewIValue(v interface{}) *IValue {
retVal = IValue{value: v}
retVal := &IValue{value: v}
if v == nil {
retVal.kind = NoneVal
retVal.name = "None"
@ -132,7 +132,7 @@ func NewIValue(v interface{}) (retVal IValue) {
var ivals []IValue
for _, tensor := range tensors {
ival := NewIValue(tensor)
ivals = append(ivals, ival)
ivals = append(ivals, *ival)
}
retVal.value = ivals
@ -167,45 +167,45 @@ func NewIValue(v interface{}) (retVal IValue) {
// IValue methods:
// ===============
func (iv IValue) ToCIValue() (retVal CIValue, err error) {
func (iv *IValue) ToCIValue() (*CIValue, error) {
switch iv.name {
case "None":
cval := lib.AtiNone()
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CIValue{civalue: cval}, nil
return &CIValue{civalue: cval}, nil
case "Tensor":
cval := lib.AtiTensor(iv.value.(Tensor).ctensor)
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CIValue{civalue: cval}, nil
return &CIValue{civalue: cval}, nil
case "Int":
cval := lib.AtiInt(iv.value.(int64))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CIValue{civalue: cval}, nil
return &CIValue{civalue: cval}, nil
case "Double":
cval := lib.AtiDouble(iv.value.(float64))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CIValue{civalue: cval}, nil
return &CIValue{civalue: cval}, nil
case "Bool":
cval := lib.AtiBool(iv.value.(bool))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CIValue{civalue: cval}, nil
return &CIValue{civalue: cval}, nil
case "Tuple":
val := reflect.Indirect(reflect.ValueOf(iv.value))
@ -219,16 +219,16 @@ func (iv IValue) ToCIValue() (retVal CIValue, err error) {
cval, err := ival.ToCIValue()
if err != nil {
err = fmt.Errorf("ToCIValue method call err - Tuple case: %v\n", err)
return retVal, err
return nil, err
}
cvals = append(cvals, cval.civalue)
}
tuple := lib.AtiTuple(cvals, len(cvals))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CIValue{civalue: tuple}, nil
return &CIValue{civalue: tuple}, nil
// 2. Tuple is (IValue, IValue)
default:
@ -238,16 +238,16 @@ func (iv IValue) ToCIValue() (retVal CIValue, err error) {
cval, err := i.ToCIValue()
if err != nil {
err = fmt.Errorf("ToCIValue method call err - Tuple case: %v\n", err)
return retVal, err
return nil, err
}
cvals = append(cvals, cval.civalue)
}
tuple := lib.AtiTuple(cvals, len(cvals))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CIValue{civalue: tuple}, nil
return &CIValue{civalue: tuple}, nil
}
case "GenericList":
@ -264,7 +264,7 @@ func (iv IValue) ToCIValue() (retVal CIValue, err error) {
cval, err := ival.ToCIValue()
if err != nil {
err = fmt.Errorf("ToCIValue method call err - GenericList case: %v\n", err)
return retVal, err
return nil, err
}
cvals = append(cvals, cval.civalue)
}
@ -305,34 +305,34 @@ func (iv IValue) ToCIValue() (retVal CIValue, err error) {
}
list := lib.AtiGenericList(cvals, len(cvals))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CIValue{civalue: list}, nil
return &CIValue{civalue: list}, nil
case "IntList":
var vals []int64 = iv.value.([]int64)
cval := lib.AtiIntList(vals, len(vals))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CIValue{civalue: cval}, nil
return &CIValue{civalue: cval}, nil
case "DoubleList":
var vals []float64 = iv.value.([]float64)
cval := lib.AtiDoubleList(vals, len(vals))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CIValue{civalue: cval}, nil
return &CIValue{civalue: cval}, nil
case "BoolList":
var vals []bool = iv.value.([]bool)
cval := lib.AtiBoolList(vals, len(vals))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CIValue{civalue: cval}, nil
return &CIValue{civalue: cval}, nil
case "TensorList":
var vals []Tensor = iv.value.([]Tensor)
@ -341,17 +341,17 @@ func (iv IValue) ToCIValue() (retVal CIValue, err error) {
cvals = append(cvals, i.ctensor)
}
list := lib.AtiTensorList(cvals, len(cvals))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CIValue{civalue: list}, nil
return &CIValue{civalue: list}, nil
case "String":
cval := lib.AtiString(iv.value.(string))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CIValue{civalue: cval}, nil
return &CIValue{civalue: cval}, nil
case "GenericDict":
var cvals []lib.Civalue
@ -417,92 +417,93 @@ func (iv IValue) ToCIValue() (retVal CIValue, err error) {
// 2. Pairing key and value in a slice (cvals)
dict := lib.AtiGenericDict(cvals, len(cvals)/2)
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CIValue{civalue: dict}, nil
return &CIValue{civalue: dict}, nil
case "Generic":
log.Fatalf("ToCIValue method call - Generic case: unsupport type(%v)\n", reflect.TypeOf(iv.value).Kind().String())
err := fmt.Errorf("ToCIValue method call - Generic case: unsupport type(%v)\n", reflect.TypeOf(iv.value).Kind().String())
return nil, err
default:
log.Fatalf("ToCIValue method call - Generic case: unsupport type(%v)\n", reflect.TypeOf(iv.value).Kind().String())
err := fmt.Errorf("ToCIValue method call - Generic case: unsupport type(%v)\n", reflect.TypeOf(iv.value).Kind().String())
return nil, err
}
return retVal, nil
panic("Shouldn't reached here.")
}
// IValueFromC returns an IValue from given CIValue.
//
// It consumes the pointer and frees the associated memory.
func IValueFromC(cval CIValue) (retVal IValue, err error) {
func IValueFromC(cval *CIValue) (*IValue, error) {
// tag will be a value of int32
tag := lib.AtiTag(cval.civalue)
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
switch tag {
case 0:
retVal = IValue{
return &IValue{
value: nil,
kind: NoneVal,
name: "None",
}
}, nil
case 1:
tensor := lib.AtiToTensor(cval.civalue)
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
retVal = IValue{
return &IValue{
value: tensor,
kind: TensorVal,
name: "Tensor",
}
}, nil
case 2:
v := lib.AtiToDouble(cval.civalue)
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
retVal = IValue{
return &IValue{
value: v,
kind: DoubleVal,
name: "Double",
}
}, nil
case 3:
v := lib.AtiToInt(cval.civalue)
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
retVal = IValue{
return &IValue{
value: v,
kind: IntVal,
name: "Int",
}
}, nil
case 4:
v := lib.AtiToBool(cval.civalue)
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
retVal = IValue{
return &IValue{
value: v,
kind: BoolVal,
name: "Bool",
}
}, nil
case 5: // Tuple []IValue 2 elements
// 1. Determine tuple length
len := lib.AtiTupleLength(cval.civalue)
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
// 2. Call with first pointer and length
ptr1 := (*lib.Civalue)(unsafe.Pointer(C.malloc(0)))
lib.AtiToTuple(cval.civalue, ptr1, int(len))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
// 3. Get list of Civalue tuple elements
@ -518,31 +519,31 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
// 4. Get Ivalue from Civalue for each tuple element
var vals []interface{}
for _, civalue := range civalues {
v, err := IValueFromC(civalue)
v, err := IValueFromC(&civalue)
if err != nil {
return retVal, err
return nil, err
}
vals = append(vals, v)
}
retVal = IValue{
return &IValue{
value: vals,
kind: TupleVal,
name: "Tuple",
}
}, nil
case 6: // IntList
// 1. Len
len := lib.AtiLength(cval.civalue)
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
// 2. Call
ptr1 := unsafe.Pointer(C.malloc(0))
lib.AtiToIntList(cval.civalue, ptr1, int(len))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
// 3. Get int list
@ -555,24 +556,24 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
currPtr = nextPtr
}
retVal = IValue{
return &IValue{
value: intVals,
kind: IntListVal,
name: "IntList",
}
}, nil
case 7: // DoubleList
// 1. Len
len := lib.AtiLength(cval.civalue)
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
// 2. Call
ptr1 := unsafe.Pointer(C.malloc(0))
lib.AtiToDoubleList(cval.civalue, ptr1, int(len))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
// 3. Get int list
@ -585,24 +586,24 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
currPtr = nextPtr
}
retVal = IValue{
return &IValue{
value: floatVals,
kind: DoubleListVal,
name: "DoubleList",
}
}, nil
case 8: // BoolList
// 1. Len
len := lib.AtiLength(cval.civalue)
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
// 2. Call
ptr1 := unsafe.Pointer(C.malloc(0))
lib.AtiToBoolList(cval.civalue, ptr1, int(len))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
// 3. Get values
@ -624,35 +625,35 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
bvals = append(bvals, bval)
}
retVal = IValue{
return &IValue{
value: bvals,
kind: BoolListVal,
name: "BoolList",
}
}, nil
case 9: // String
v := lib.AtiToString(cval.civalue)
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
retVal = IValue{
return &IValue{
value: v,
kind: StringVal,
name: "String",
}
}, nil
case 10: // TensorList
// 1. Len
len := lib.AtiLength(cval.civalue)
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
// 2. Call
ptr1 := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtiToTensorList(cval.civalue, ptr1, int(len))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
// 3. Get values
@ -665,24 +666,24 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
currPtr = nextPtr
}
retVal = IValue{
return &IValue{
value: tensors,
kind: TensorListVal,
name: "TensorList",
}
}, nil
case 12: // GenericList []IValue
// NOTE: atm, all these cases are unsupported.
// 1. Len
len := lib.AtiLength(cval.civalue)
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
// 2. Call with first pointer and length
ptr1 := (*lib.Civalue)(unsafe.Pointer(C.malloc(0)))
lib.AtiToGenericList(cval.civalue, ptr1, int(len))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
// 3. Get values
@ -699,9 +700,9 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
var vals []interface{}
var itemTyp string
for _, civalue := range civalues {
v, err := IValueFromC(civalue)
v, err := IValueFromC(&civalue)
if err != nil {
return retVal, err
return nil, err
}
itemTyp = reflect.TypeOf(v.value).Kind().String()
vals = append(vals, v.value)
@ -713,42 +714,41 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
for _, v := range vals {
specVals = append(specVals, v.(string))
}
retVal = IValue{
return &IValue{
value: specVals,
kind: GenericListVal,
name: "GenericList",
}
}, nil
case "int":
var specVals []int
for _, v := range vals {
specVals = append(specVals, v.(int))
}
retVal = IValue{
return &IValue{
value: specVals,
kind: GenericListVal,
name: "GenericList",
}
}, nil
case "int32":
var specVals []int32
for _, v := range vals {
specVals = append(specVals, v.(int32))
}
retVal = IValue{
return &IValue{
value: vals,
kind: GenericListVal,
name: "GenericList",
}
}, nil
case "float32":
var specVals []float32
for _, v := range vals {
specVals = append(specVals, v.(float32))
}
retVal = IValue{
return &IValue{
value: vals,
kind: GenericListVal,
name: "GenericList",
}
return retVal, nil
}, nil
default:
log.Fatalf("IValueFromC method call - GenericList case: Unsupported item type (%v)\n", itemTyp)
@ -757,14 +757,14 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
case 13: // GenericDict map[IValue]IValue
// 1. Len
numVals := lib.AtiLength(cval.civalue)
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
// 2. Call with first pointer and length
ptr1 := (*lib.Civalue)(unsafe.Pointer(C.malloc(0)))
lib.AtiToGenericDict(cval.civalue, ptr1, int(numVals))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
// 3. Get values
@ -783,9 +783,9 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
var vals []interface{}
var itemTyp string
for _, civalue := range civalues {
v, err := IValueFromC(civalue)
v, err := IValueFromC(&civalue)
if err != nil {
return retVal, err
return nil, err
}
itemTyp = reflect.TypeOf(v.value).Kind().String()
vals = append(vals, v.value)
@ -797,85 +797,84 @@ func IValueFromC(cval CIValue) (retVal IValue, err error) {
for i := 0; i < len(vals); i += 2 {
specVals[vals[i].(string)] = vals[i+1].(string)
}
retVal = IValue{
return &IValue{
value: specVals,
kind: GenericDictVal,
name: "GenericDict",
}
}, nil
case "int":
var specVals map[int]int = make(map[int]int)
for i := 0; i < len(vals); i += 2 {
specVals[vals[i].(int)] = vals[i+1].(int)
}
retVal = IValue{
return &IValue{
value: specVals,
kind: GenericDictVal,
name: "GenericDict",
}
}, nil
case "int32":
var specVals map[int32]int32 = make(map[int32]int32)
for i := 0; i < len(vals); i += 2 {
specVals[vals[i].(int32)] = vals[i+1].(int32)
}
retVal = IValue{
return &IValue{
value: specVals,
kind: GenericDictVal,
name: "GenericDict",
}
}, nil
case "int64":
var specVals map[int64]int64 = make(map[int64]int64)
for i := 0; i < len(vals); i += 2 {
specVals[vals[i].(int64)] = vals[i+1].(int64)
}
retVal = IValue{
return &IValue{
value: specVals,
kind: GenericDictVal,
name: "GenericDict",
}
}, nil
case "float32":
var specVals map[float32]float32 = make(map[float32]float32)
for i := 0; i < len(vals); i += 2 {
specVals[vals[i].(float32)] = vals[i+1].(float32)
}
retVal = IValue{
return &IValue{
value: specVals,
kind: GenericDictVal,
name: "GenericDict",
}
return retVal, nil
}, nil
case "float64":
var specVals map[float64]float64 = make(map[float64]float64)
for i := 0; i < len(vals); i += 2 {
specVals[vals[i].(float64)] = vals[i+1].(float64)
}
retVal = IValue{
return &IValue{
value: specVals,
kind: GenericDictVal,
name: "GenericDict",
}
return retVal, nil
}, nil
}
default:
log.Fatalf("IValueFromC - Unsupported type (tag value: %v)\n", tag)
err := fmt.Errorf("IValueFromC - Unsupported type (tag value: %v)\n", tag)
return nil, err
}
return
panic("Shouldn't reach here.")
}
func (iv IValue) Value() (retVal interface{}) {
func (iv *IValue) Value() interface{} {
return iv.value
}
func (iv IValue) Name() (retVal string) {
func (iv *IValue) Name() string {
return iv.name
}
func (iv IValue) Kind() (retVal IValueKind) {
func (iv *IValue) Kind() IValueKind {
return iv.kind
}
// A jit PyTorch module.
// A JIT PyTorch module.
//
// These modules can be created via the
// [TorchScript python api](https://pytorch.org/docs/stable/jit.html).
@ -883,7 +882,7 @@ type CModule struct {
Cmodule lib.Cmodule
}
func (cm CModule) Drop() {
func (cm *CModule) Drop() {
lib.AtmFree(cm.Cmodule)
if err := TorchErr(); err != nil {
log.Fatalf("CModule Drop method err: %v\n", err)
@ -891,13 +890,13 @@ func (cm CModule) Drop() {
}
// Loads a PyTorch saved JIT model from a file.
func ModuleLoad(path string) (retVal CModule, err error) {
func ModuleLoad(path string) (*CModule, error) {
cmodule := lib.AtmLoad(path)
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CModule{cmodule}, nil
return &CModule{cmodule}, nil
}
@ -905,17 +904,17 @@ func ModuleLoad(path string) (retVal CModule, err error) {
//
// This function loads the model directly on the specified device,
// which means it also allows loading a GPU model on the CPU without having a CUDA enabled GPU.
func ModuleLoadOnDevice(path string, device gotch.Device) (retVal CModule, err error) {
func ModuleLoadOnDevice(path string, device gotch.Device) (*CModule, error) {
cmodule := lib.AtmLoadOnDevice(path, device.CInt())
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CModule{cmodule}, nil
return &CModule{cmodule}, nil
}
// Loads a PyTorch saved JIT model from a read instance.
func ModuleLoadData(stream io.Reader) (retVal CModule, err error) {
func ModuleLoadData(stream io.Reader) (*CModule, error) {
buf := new(bytes.Buffer)
buf.ReadFrom(stream)
@ -923,11 +922,11 @@ func ModuleLoadData(stream io.Reader) (retVal CModule, err error) {
bufString := buf.String()
cmodule := lib.AtmLoadStr(bufString, len(bufString))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CModule{cmodule}, nil
return &CModule{cmodule}, nil
}
@ -935,22 +934,22 @@ func ModuleLoadData(stream io.Reader) (retVal CModule, err error) {
//
// This function loads the model directly on the specified device,
// which means it also allows loading a GPU model on the CPU without having a CUDA enabled GPU.
func ModuleLoadDataOnDevice(stream io.Reader, device gotch.Device) (retVal CModule, err error) {
func ModuleLoadDataOnDevice(stream io.Reader, device gotch.Device) (*CModule, error) {
buf := new(bytes.Buffer)
buf.ReadFrom(stream)
bufString := buf.String()
cmodule := lib.AtmLoadStrOnDevice(bufString, len(bufString), device.CInt())
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return CModule{cmodule}, nil
return &CModule{cmodule}, nil
}
// Performs the forward pass for a model on some specified tensor inputs.
func (cm CModule) ForwardTs(tensors []Tensor) (retVal *Tensor, err error) {
// ForwardTs performs the forward pass for a model on some specified tensor inputs.
func (cm *CModule) ForwardTs(tensors []Tensor) (*Tensor, error) {
var ctensors []lib.Ctensor
for _, t := range tensors {
ctensors = append(ctensors, t.ctensor)
@ -990,21 +989,21 @@ func (cm CModule) ForwardTs(tensors []Tensor) (retVal *Tensor, err error) {
// - `nsize` is number of ctensor pointers encoded in binary data.
ctensorsPtr := (*lib.Ctensor)(dataPtr)
ctensor := lib.AtmForward(cm.Cmodule, ctensorsPtr, len(ctensors))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
return &Tensor{ctensor}, nil
}
// Performs the forward pass for a model on some specified ivalue input.
func (cm CModule) ForwardIs(ivalues []IValue) (retVal IValue, err error) {
// ForwardIs performs the forward pass for a model on some specified ivalue input.
func (cm *CModule) ForwardIs(ivalues []IValue) (*IValue, error) {
var civalues []lib.Civalue
for _, i := range ivalues {
civalue, err := i.ToCIValue()
if err != nil {
return retVal, err
return nil, err
}
civalues = append(civalues, civalue.civalue)
}
@ -1044,19 +1043,15 @@ func (cm CModule) ForwardIs(ivalues []IValue) (retVal IValue, err error) {
civaluesPtr := (*lib.Civalue)(dataPtr)
civ := lib.AtmForward_(cm.Cmodule, civaluesPtr, len(civalues))
if err = TorchErr(); err != nil {
return retVal, err
if err := TorchErr(); err != nil {
return nil, err
}
retVal, err = IValueFromC(CIValue{civ})
if err != nil {
return retVal, err
}
return retVal, nil
return IValueFromC(&CIValue{civ})
}
func (cm CModule) To(device gotch.Device, kind gotch.DType, nonBlocking bool) {
// To moves CModule to specified device.
func (cm *CModule) To(device gotch.Device, kind gotch.DType, nonBlocking bool) {
lib.AtmTo(cm.Cmodule, device.CInt(), kind.CInt(), nonBlocking)
if err := TorchErr(); err != nil {
log.Fatalf("CModule To method call err: %v\n", err)
@ -1066,7 +1061,8 @@ func (cm CModule) To(device gotch.Device, kind gotch.DType, nonBlocking bool) {
// Implement Module for CModule:
// =============================
func (cm CModule) Forward(tensor *Tensor) (retVal *Tensor, err error) {
// Forwad implements Module interface for CModule.
func (cm *CModule) Forward(tensor *Tensor) (*Tensor, error) {
var tensors []Tensor = []Tensor{*tensor}
return cm.ForwardTs(tensors)
@ -1076,7 +1072,7 @@ func (cm CModule) Forward(tensor *Tensor) (retVal *Tensor, err error) {
// ======================================
// Apply forwards tensor itself through a module.
func (ts *Tensor) ApplyCModule(m CModule) (retVal *Tensor) {
func (ts *Tensor) ApplyCModule(m *CModule) *Tensor {
retVal, err := m.Forward(ts)
if err != nil {
log.Fatal(err)

View File

@ -86,7 +86,7 @@ func TestModuleForwardIValue(t *testing.T) {
iv1 := ts.NewIValue(*ts1)
iv2 := ts.NewIValue(*ts2)
got, err := foo.ForwardIs([]ts.IValue{iv1, iv2})
got, err := foo.ForwardIs([]ts.IValue{*iv1, *iv2})
if err != nil {
t.Error(err)
}