gotch/libtch/tensor.go

739 lines
21 KiB
Go

package libtch
//#include "stddef.h"
//#include "stdbool.h"
//#include "torch_api.h"
//#include "stdlib.h"
//void callback_fn(void *, char *, tensor);
//typedef void (*f)(void *, char *, tensor);
import "C"
import (
"strings"
"unsafe"
)
// NOTE: C.tensor is a C pointer to torch::Tensor
type Ctensor = C.tensor
type Cscalar = C.scalar
type Coptimizer = C.optimizer
type Civalue = C.ivalue
type Cmodule = C.module
type NamedCtensor struct {
Name string
Ctensor C.tensor
}
type LoadData struct {
NamedCtensors []NamedCtensor
}
var PStore = NewPointerStore()
func AtNewTensor() Ctensor {
return C.at_new_tensor()
}
// tensor at_new_tensor();
func NewTensor() Ctensor {
return C.at_new_tensor()
}
// int at_device(tensor);
func AtDevice(ts Ctensor) int {
cint := C.at_device(ts)
return *(*int)(unsafe.Pointer(&cint))
}
// tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims, size_t element_size_in_bytes, int type);
func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_bytes uint, kind int) Ctensor {
// just get pointer of the first element of shape
c_dims := (*C.int64_t)(unsafe.Pointer(&dims[0]))
c_ndims := *(*C.size_t)(unsafe.Pointer(&ndims))
c_elt_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&elt_size_in_bytes))
c_kind := *(*C.int)(unsafe.Pointer(&kind))
return C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind)
}
// void at_print(tensor);
func AtPrint(t Ctensor) {
C.at_print(t)
}
// void *at_data_ptr(tensor);
func AtDataPtr(t Ctensor) unsafe.Pointer {
return C.at_data_ptr(t)
}
// size_t at_dim(tensor);
func AtDim(t Ctensor) uint64 {
result := C.at_dim(t)
return *(*uint64)(unsafe.Pointer(&result))
}
// void at_shape(tensor, int64_t *);
func AtShape(t Ctensor, ptr unsafe.Pointer) {
c_ptr := (*C.long)(ptr)
C.at_shape(t, c_ptr)
}
// int at_scalar_type(tensor);
func AtScalarType(t Ctensor) int32 {
result := C.at_scalar_type(t)
return *(*int32)(unsafe.Pointer(&result))
}
func GetAndResetLastErr() *C.char {
return C.get_and_reset_last_err()
}
// int atc_cuda_device_count();
func AtcCudaDeviceCount() int {
result := C.atc_cuda_device_count()
return *(*int)(unsafe.Pointer(&result))
}
// int atc_cuda_is_available();
func AtcCudaIsAvailable() bool {
result := C.atc_cuda_is_available()
return *(*bool)(unsafe.Pointer(&result))
}
// int atc_cudnn_is_available();
func AtcCudnnIsAvailable() bool {
result := C.atc_cudnn_is_available()
return *(*bool)(unsafe.Pointer(&result))
}
// void atc_set_benchmark_cudnn(int b);
func AtcSetBenchmarkCudnn(b int) {
cb := *(*C.int)(unsafe.Pointer(&b))
C.atc_set_benchmark_cudnn(cb)
}
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
func AtDoubleValueAtIndexes(ts Ctensor, indexes unsafe.Pointer, indexesLen int) float64 {
ctensor := (C.tensor)(ts)
cindexes := (*C.long)(indexes)
cindexesLen := *(*C.int)(unsafe.Pointer(&indexesLen))
retVal := C.at_double_value_at_indexes(ctensor, cindexes, cindexesLen)
return *(*float64)(unsafe.Pointer(&retVal))
}
// int64_t at_int64_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
func AtInt64ValueAtIndexes(ts Ctensor, indexes unsafe.Pointer, indexesLen int) int64 {
ctensor := (C.tensor)(ts)
cindexes := (*C.long)(indexes)
cindexesLen := *(*C.int)(unsafe.Pointer(&indexesLen))
retVal := C.at_int64_value_at_indexes(ctensor, cindexes, cindexesLen)
return *(*int64)(unsafe.Pointer(&retVal))
}
// int at_requires_grad(tensor);
func AtRequiresGrad(ts Ctensor) bool {
retVal := C.at_requires_grad(ts)
return *(*bool)(unsafe.Pointer(&retVal))
}
// int at_defined(tensor);
func AtDefined(ts Ctensor) bool {
retVal := C.at_defined(ts)
return *(*bool)(unsafe.Pointer(&retVal))
}
// int at_is_sparse(tensor);
func AtIsSparse(ts Ctensor) bool {
retVal := C.at_is_sparse(ts)
return *(*bool)(unsafe.Pointer(&retVal))
}
// void at_backward(tensor, int, int);
func AtBackward(ts Ctensor, keepGraph int, createGraph int) {
ckeepGraph := *(*C.int)(unsafe.Pointer(&keepGraph))
ccreateGraph := *(*C.int)(unsafe.Pointer(&createGraph))
C.at_backward(ts, ckeepGraph, ccreateGraph)
}
/*
* void at_run_backward(tensor *tensors,
* int ntensors,
* tensor *inputs,
* int ninputs,
* tensor *outputs,
* int keep_graph,
* int create_graph);
* */
func AtRunBackward(tensorsPtr *Ctensor, ntensors int, inputsPtr *Ctensor, ninputs int, outputsPtr *Ctensor, keepGraph int, createGraph int) {
cntensors := *(*C.int)(unsafe.Pointer(&ntensors))
cninputs := *(*C.int)(unsafe.Pointer(&ninputs))
ckeepGraph := *(*C.int)(unsafe.Pointer(&keepGraph))
ccreateGraph := *(*C.int)(unsafe.Pointer(&createGraph))
C.at_run_backward(tensorsPtr, cntensors, inputsPtr, cninputs, outputsPtr, ckeepGraph, ccreateGraph)
}
// void at_copy_data(tensor tensor, void *vs, size_t numel, size_t element_size_in_bytes);
func AtCopyData(ts Ctensor, vs unsafe.Pointer, numel uint, element_size_in_bytes uint) {
cnumel := *(*C.size_t)(unsafe.Pointer(&numel))
celement_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&element_size_in_bytes))
C.at_copy_data(ts, vs, cnumel, celement_size_in_bytes)
}
// tensor at_shallow_clone(tensor);
func AtShallowClone(ts Ctensor) Ctensor {
return C.at_shallow_clone(ts)
}
// tensor at_get(tensor, int index);
func AtGet(ts Ctensor, index int) Ctensor {
cindex := *(*C.int)(unsafe.Pointer(&index))
return C.at_get(ts, cindex)
}
// void at_copy_(tensor dst, tensor src);
func AtCopy_(dst Ctensor, src Ctensor) {
C.at_copy_(dst, src)
}
// void at_save(tensor, char *filename);
func AtSave(ts Ctensor, path string) {
cstringPtr := C.CString(path)
defer C.free(unsafe.Pointer(cstringPtr))
C.at_save(ts, cstringPtr)
}
// tensor at_load(char *filename);
func AtLoad(path string) Ctensor {
cstringPtr := C.CString(path)
defer C.free(unsafe.Pointer(cstringPtr))
return C.at_load(cstringPtr)
}
// void at_save_multi(tensor *tensors, char **tensor_names, int ntensors, char *filename);
func AtSaveMulti(tensors []Ctensor, tensor_names []string, ntensors int, filename string) {
var ctensors []C.tensor
for i := 0; i < len(tensors); i++ {
ctensors = append(ctensors, (C.tensor)(tensors[i]))
}
cpointerSize := 4
cnamesPtr := (*[1 << 30]**C.char)(C.malloc(C.size_t(cpointerSize * len(tensor_names))))
// defer C.free(unsafe.Pointer(cnamesPtr))
for i := 0; i < len(tensor_names); i++ {
cname := C.CString(tensor_names[i])
cnamesPtr[i] = &cname
// defer C.free(unsafe.Pointer(cnamesPtr[i]))
}
cntensors := *(*C.int)(unsafe.Pointer(&ntensors))
cfilename := C.CString(filename)
C.at_save_multi(&ctensors[0], cnamesPtr[0], cntensors, cfilename)
}
// void at_save_multi(tensor *tensors, char **tensor_names, int ntensors, char *filename);
func AtSaveMultiNew(tensors []Ctensor, names []string, filename string) {
// NOTE. namedTensors is slice of tensors which wrap Ctensor pointer.
// However, they are not neccessary consecutive Ctensor pointer.
// The following code will create 2 arrays to contain ctensors and cnames
// 1. Calculate memory size for the array.
// 2. Copy C pointer values to the arrays.
ntensors := len(tensors)
// number of bytes for each array of pointers.
nbytes := C.size_t(ntensors) * C.size_t(unsafe.Sizeof(uintptr(0)))
cnamesPtr := (*[1 << 30]**C.char)(C.malloc(nbytes))
for i, name := range names {
cname := C.CString(name)
cnamesPtr[i] = &cname
}
ctensorsPtr := (*[1 << 30]C.tensor)(C.malloc(nbytes))
for i, ctensor := range tensors {
ctensorsPtr[i] = (C.tensor)(ctensor)
}
cntensors := *(*C.int)(unsafe.Pointer(&ntensors))
cfilename := C.CString(filename)
C.at_save_multi(&ctensorsPtr[0], cnamesPtr[0], cntensors, cfilename)
}
/* [at_load_multi] takes as input an array of nullptr for [tensors]. */
// void at_load_multi(tensor *tensors, char **tensor_names, int ntensors, char *filename);
func AtLoadMulti(tensors []Ctensor, tensor_names []string, ntensors int, filename string) {
// TODO: implement this
}
// void at_load_callback(char *filename, void *data, void (*f)(void *, char *, tensor));
/*
* void at_load_callback(char *filename, void *data, void (*f)(void *, char *, tensor)) {
* PROTECT(
* auto module = torch::jit::load(filename);
* for (const auto &p : module.named_parameters()) {
* auto v = p.value;
* f(data, (char*)p.name.c_str(), new torch::Tensor(v));
* }
* )
* }
* */
func AtLoadCallback(filename string, dataPtr unsafe.Pointer) {
cfilename := C.CString(filename)
defer C.free(unsafe.Pointer(cfilename))
C.at_load_callback(cfilename, dataPtr, C.f(C.callback_fn))
}
//TODO: move `callback_fn` to wrapper package???
//export callback_fn
func callback_fn(dataPtr unsafe.Pointer, name *C.char, ctensor C.tensor) {
tsName := C.GoString(name)
tsName = strings.ReplaceAll(tsName, "|", ".")
namedCtensor := NamedCtensor{
Name: tsName,
Ctensor: ctensor,
}
data := PStore.Get(dataPtr).(*LoadData)
data.NamedCtensors = append(data.NamedCtensors, namedCtensor)
}
/*
* void at_load_callback_with_device(char *filename, void *data, void (*f)(void *, char *, tensor), int device_id) {
* PROTECT(
* auto module = torch::jit::load(filename, device_of_int(device_id));
* for (const auto &p : module.named_parameters()) {
* auto v = p.value;
* f(data, (char*)p.name.c_str(), new torch::Tensor(v));
* }
* )
* }
* */
func AtLoadCallbackWithDevice(filename string, dataPtr unsafe.Pointer, device int32) {
cfilename := C.CString(filename)
defer C.free(unsafe.Pointer(cfilename))
cdevice := *(*C.int)(unsafe.Pointer(&device))
C.at_load_callback_with_device(cfilename, dataPtr, C.f(C.callback_fn), cdevice)
}
/*
* char *at_to_string(tensor t, int line_size) {
* PROTECT(
* std::ostringstream oss;
* torch::print(oss, *t, line_size);
* return strdup(oss.str().c_str());
* )
* return nullptr;
* }
* */
func AtToString(ts Ctensor, lineSize int64) string {
clineSize := *(*C.int)(unsafe.Pointer(&lineSize))
charPtr := C.at_to_string(ts, clineSize)
goString := C.GoString(charPtr)
return goString
}
// void at_free(tensor);
func AtFree(ts Ctensor) {
C.at_free(ts)
}
//int at_grad_set_enabled(int b);
func AtGradSetEnabled(b int) int {
cbool := *(*C.int)(unsafe.Pointer(&b))
cretVal := C.at_grad_set_enabled(cbool)
return *(*int)(unsafe.Pointer(&cretVal))
}
/*
* optimizer ato_adam(double learning_rate,
* double beta1,
* double beta2,
* double weight_decay);
* */
func AtoAdam(learningRate, beta1, beta2, weightDecay float64) Coptimizer {
clearningRate := *(*C.double)(unsafe.Pointer(&learningRate))
cbeta1 := *(*C.double)(unsafe.Pointer(&beta1))
cbeta2 := *(*C.double)(unsafe.Pointer(&beta2))
cweightDecay := *(*C.double)(unsafe.Pointer(&weightDecay))
return C.ato_adam(clearningRate, cbeta1, cbeta2, cweightDecay)
}
/*
* optimizer ato_rms_prop(double learning_rate,
* double alpha,
* double eps,
* double weight_decay,
* double momentum,
* int centered);
* */
func AtoRmsProp(learningRate, alpha, eps, weightDecay, momentum float64, centered int) Coptimizer {
clearningRate := *(*C.double)(unsafe.Pointer(&learningRate))
calpha := *(*C.double)(unsafe.Pointer(&alpha))
ceps := *(*C.double)(unsafe.Pointer(&eps))
cweightDecay := *(*C.double)(unsafe.Pointer(&weightDecay))
cmomentum := *(*C.double)(unsafe.Pointer(&momentum))
ccentered := *(*C.int)(unsafe.Pointer(&centered))
return C.ato_rms_prop(clearningRate, calpha, ceps, cweightDecay, cmomentum, ccentered)
}
/*
* optimizer ato_sgd(double learning_rate,
* double momentum,
* double dampening,
* double weight_decay,
* int nesterov);
* */
func AtoSgd(learningRate, momentum, dampening, weightDecay float64, nesterov int) Coptimizer {
clearningRate := *(*C.double)(unsafe.Pointer(&learningRate))
cmomentum := *(*C.double)(unsafe.Pointer(&momentum))
cdampening := *(*C.double)(unsafe.Pointer(&dampening))
cweightDecay := *(*C.double)(unsafe.Pointer(&weightDecay))
cnesterov := *(*C.int)(unsafe.Pointer(&nesterov))
return C.ato_sgd(clearningRate, cmomentum, cdampening, cweightDecay, cnesterov)
}
// void ato_add_parameters(optimizer, tensor *, int ntensors);
func AtoAddParameters(coptimizer Coptimizer, tensors []Ctensor, ntensors int) {
var ctensors []C.tensor
for i := 0; i < len(tensors); i++ {
ctensors = append(ctensors, (C.tensor)(tensors[i]))
}
cntensors := *(*C.int)(unsafe.Pointer(&ntensors))
// Just give pointer to the first element of ctensors slice
C.ato_add_parameters(coptimizer, &ctensors[0], cntensors)
}
// void ato_set_learning_rate(optimizer, double learning_rate);
func AtoSetLearningRate(coptimizer Coptimizer, learningRate float64) {
clearningRate := *(*C.double)(unsafe.Pointer(&learningRate))
C.ato_set_learning_rate(coptimizer, clearningRate)
}
// void ato_set_momentum(optimizer, double momentum);
func AtoSetMomentum(coptimizer Coptimizer, momentum float64) {
cmomentum := *(*C.double)(unsafe.Pointer(&momentum))
C.ato_set_momentum(coptimizer, cmomentum)
}
// void ato_zero_grad(optimizer);
func AtoZeroGrad(coptimizer Coptimizer) {
C.ato_zero_grad(coptimizer)
}
// void ato_step(optimizer);
func AtoStep(coptimizer Coptimizer) {
C.ato_step(coptimizer)
}
// void ato_free(optimizer);
func AtoFree(coptimizer Coptimizer) {
C.ato_free(coptimizer)
}
// tensor at_load_image(char *filename);
func AtLoadImage(path string) Ctensor {
cpath := C.CString(path)
defer C.free(unsafe.Pointer(cpath))
return C.at_load_image(cpath)
}
// int at_save_image(tensor, char *filename);
func AtSaveImage(ts Ctensor, path string) {
cpath := C.CString(path)
defer C.free(unsafe.Pointer(cpath))
// TODO: we don't take the return value
// as we handle error with `TochErr()` anyway
_ = C.at_save_image(ts, cpath)
}
// tensor at_resize_image(tensor, int w, int h);
func AtResizeImage(ts Ctensor, w, h int64) Ctensor {
cw := *(*C.int)(unsafe.Pointer(&w))
ch := *(*C.int)(unsafe.Pointer(&h))
return C.at_resize_image(ts, cw, ch)
}
// ivalue ati_none();
func AtiNone() Civalue {
return C.ati_none()
}
// ivalue ati_tensor(tensor);
func AtiTensor(ts Ctensor) Civalue {
return C.ati_tensor(ts)
}
// ivalue ati_int(int64_t);
func AtiInt(val int64) Civalue {
cval := *(*C.int64_t)(unsafe.Pointer(&val))
result := C.ati_int(cval)
return result
}
// ivalue ati_double(double);
func AtiDouble(val float64) Civalue {
cval := *(*C.double)(unsafe.Pointer(&val))
return C.ati_double(cval)
}
// ivalue ati_bool(int);
func AtiBool(val bool) Civalue {
ival := 0
if val {
ival = 1
}
cval := *(*C.int)(unsafe.Pointer(&ival))
return C.ati_bool(cval)
}
// ivalue ati_string(char *);
func AtiString(val string) Civalue {
cval := C.CString(val)
return C.ati_string(cval)
}
// ivalue ati_tuple(ivalue *, int);
func AtiTuple(tupleData []Civalue, tupleLen int) Civalue {
ctupleDataPtr := (*C.ivalue)(unsafe.Pointer(&tupleData[0]))
ctupleLen := *(*C.int)(unsafe.Pointer(&tupleLen))
return C.ati_tuple(ctupleDataPtr, ctupleLen)
}
// ivalue ati_generic_list(ivalue *, int);
func AtiGenericList(genericListData []Civalue, genericListLen int) Civalue {
cgenericListDataPtr := (*C.ivalue)(unsafe.Pointer(&genericListData[0]))
cgenericListLen := *(*C.int)(unsafe.Pointer(&genericListLen))
return C.ati_generic_list(cgenericListDataPtr, cgenericListLen)
}
// ivalue ati_generic_dict(ivalue *, int);
func AtiGenericDict(genericDictData []Civalue, genericDictLen int) Civalue {
cgenericDictDataPtr := (*C.ivalue)(unsafe.Pointer(&genericDictData[0]))
cgenericDictLen := *(*C.int)(unsafe.Pointer(&genericDictLen))
return C.ati_generic_dict(cgenericDictDataPtr, cgenericDictLen)
}
// ivalue ati_int_list(int64_t *, int);
func AtiIntList(intListData []int64, intListLen int) Civalue {
cintListDataPtr := (*C.int64_t)(unsafe.Pointer(&intListData[0]))
cintListLen := *(*C.int)(unsafe.Pointer(&intListLen))
return C.ati_int_list(cintListDataPtr, cintListLen)
}
// ivalue ati_double_list(double *, int);
func AtiDoubleList(doubleListData []float64, doubleListLen int) Civalue {
cdoubleListDataPtr := (*C.double)(unsafe.Pointer(&doubleListData[0]))
cdoubleListLen := *(*C.int)(unsafe.Pointer(&doubleListLen))
return C.ati_double_list(cdoubleListDataPtr, cdoubleListLen)
}
// ivalue ati_bool_list(char *, int);
func AtiBoolList(boolListData []bool, boolListLen int) Civalue {
var cboolListData []int
for _, v := range boolListData {
item := 0
if v {
item = 1
}
cboolListData = append(cboolListData, item)
}
cboolListDataPtr := (*C.char)(unsafe.Pointer(&cboolListData[0]))
cboolListLen := *(*C.int)(unsafe.Pointer(&boolListLen))
return C.ati_bool_list(cboolListDataPtr, cboolListLen)
}
// ivalue ati_tensor_list(tensor *, int);
func AtiTensorList(tensorListData []Ctensor, tensorListLen int) Civalue {
ctensorListDataPtr := (*C.tensor)(unsafe.Pointer(&tensorListData[0]))
ctensorListLen := *(*C.int)(unsafe.Pointer(&tensorListLen))
return C.ati_tensor_list(ctensorListDataPtr, ctensorListLen)
}
// tensor ati_to_tensor(ivalue);
func AtiToTensor(val Civalue) Ctensor {
return C.ati_to_tensor(val)
}
// int64_t ati_to_int(ivalue);
func AtiToInt(val Civalue) int64 {
cval := C.ati_to_int(val)
return *(*int64)(unsafe.Pointer(&cval))
}
// double ati_to_double(ivalue);
func AtiToDouble(val Civalue) float64 {
cval := C.ati_to_double(val)
return *(*float64)(unsafe.Pointer(&cval))
}
// char *ati_to_string(ivalue);
func AtiToString(val Civalue) string {
cval := C.ati_to_string(val)
return C.GoString(cval)
}
// int ati_to_bool(ivalue);
func AtiToBool(val Civalue) bool {
cval := C.ati_to_bool(val)
goval := *(*int32)(unsafe.Pointer(&cval))
return goval == 1
}
// int ati_length(ivalue);
func AtiLength(val Civalue) int32 {
cval := C.ati_length(val)
return *(*int32)(unsafe.Pointer(&cval))
}
// int ati_tuple_length(ivalue);
func AtiTupleLength(val Civalue) int32 {
cval := C.ati_tuple_length(val)
return *(*int32)(unsafe.Pointer(&cval))
}
// void ati_to_tuple(ivalue, ivalue *, int);
func AtiToTuple(val Civalue, ptr *Civalue, tupleLen int) {
ctupleLen := *(*C.int)(unsafe.Pointer(&tupleLen))
C.ati_to_tuple(val, ptr, ctupleLen)
}
// void ati_to_generic_list(ivalue, ivalue *, int);
func AtiToGenericList(val Civalue, ptr *Civalue, genericListLen int) {
cgenericListLen := *(*C.int)(unsafe.Pointer(&genericListLen))
C.ati_to_generic_list(val, ptr, cgenericListLen)
}
// void ati_to_generic_dict(ivalue, ivalue *, int);
func AtiToGenericDict(val Civalue, ptr *Civalue, genericDictLen int) {
cgenericDictLen := *(*C.int)(unsafe.Pointer(&genericDictLen))
C.ati_to_generic_dict(val, ptr, cgenericDictLen)
}
// void ati_to_int_list(ivalue, int64_t *, int);
func AtiToIntList(val Civalue, ptr unsafe.Pointer, intListLen int) {
cptr := (*C.int64_t)(ptr)
cintListLen := *(*C.int)(unsafe.Pointer(&intListLen))
C.ati_to_int_list(val, cptr, cintListLen)
}
// void ati_to_double_list(ivalue, double *, int);
func AtiToDoubleList(val Civalue, ptr unsafe.Pointer, doubleListLen int) {
cptr := (*C.double)(ptr)
cdoubleListLen := *(*C.int)(unsafe.Pointer(&doubleListLen))
C.ati_to_double_list(val, cptr, cdoubleListLen)
}
// void ati_to_bool_list(ivalue, char *, int);
func AtiToBoolList(val Civalue, ptr unsafe.Pointer, boolListLen int) {
cptr := (*C.char)(ptr)
cboolListLen := *(*C.int)(unsafe.Pointer(&boolListLen))
C.ati_to_bool_list(val, cptr, cboolListLen)
}
// void ati_to_tensor_list(ivalue, tensor *, int);
func AtiToTensorList(val Civalue, ptr *Ctensor, tensorListLen int) {
ctensorListLen := *(*C.int)(unsafe.Pointer(&tensorListLen))
C.ati_to_tensor_list(val, ptr, ctensorListLen)
}
// int ati_tag(ivalue);
func AtiTag(val Civalue) int32 {
ctag := C.ati_tag(val)
return *(*int32)(unsafe.Pointer(&ctag))
}
// void ati_free(ivalue);
func AtiFree(val Civalue) {
C.ati_free(val)
}
// module atm_load(char *);
func AtmLoad(path string) Cmodule {
ptr := C.CString(path)
return C.atm_load(ptr)
}
// module atm_load_on_device(char *, int device);
func AtmLoadOnDevice(path string, device int32) Cmodule {
ptr := C.CString(path)
cdevice := *(*C.int)(unsafe.Pointer(&device))
return C.atm_load_on_device(ptr, cdevice)
}
// module atm_load_str(char *, size_t sz);
func AtmLoadStr(val string, sz int) Cmodule {
ptr := C.CString(val)
csz := *(*C.size_t)(unsafe.Pointer(&sz))
return C.atm_load_str(ptr, csz)
}
// module atm_load_str_on_device(char *, size_t sz, int device);
func AtmLoadStrOnDevice(val string, sz int, device int32) Cmodule {
ptr := C.CString(val)
csz := *(*C.size_t)(unsafe.Pointer(&sz))
cdevice := *(*C.int)(unsafe.Pointer(&device))
return C.atm_load_str_on_device(ptr, csz, cdevice)
}
// tensor atm_forward(module, tensor *tensors, int ntensors);
func AtmForward(m Cmodule, tensors *Ctensor, ntensors int) Ctensor {
cntensors := *(*C.int)(unsafe.Pointer(&ntensors))
return C.atm_forward(m, tensors, cntensors)
}
// ivalue atm_forward_(module, ivalue *ivalues, int nivalues);
func AtmForward_(m Cmodule, ivalues *Civalue, nivalues int) Civalue {
cnivalues := *(*C.int)(unsafe.Pointer(&nivalues))
return C.atm_forward_(m, ivalues, cnivalues)
}
// void atm_free(module);
func AtmFree(m Cmodule) {
C.atm_free(m)
}
// void atm_to(module m, int device, int dtype, bool non_blocking);
func AtmTo(m Cmodule, device int32, dtype int32, nonBlocking bool) {
cdevice := *(*C.int)(unsafe.Pointer(&device))
cdtype := *(*C.int)(unsafe.Pointer(&dtype))
cnonBlocking := *(*C.bool)(unsafe.Pointer(&nonBlocking))
C.atm_to(m, cdevice, cdtype, cnonBlocking)
}