gotch/libtch/tensor.go
2020-06-11 17:23:38 +10:00

443 lines
13 KiB
Go

package libtch
//#include "stddef.h"
//#include "stdbool.h"
//#include "torch_api.h"
//#include "stdlib.h"
//void callback_fn(void *, char *, tensor);
//typedef void (*f)(void *, char *, tensor);
import "C"
import (
"unsafe"
)
// Exported aliases for the opaque handle types declared in torch_api.h.
// NOTE: C.tensor is a C pointer to torch::Tensor
type (
	Ctensor    = C.tensor
	Cscalar    = C.scalar
	Coptimizer = C.optimizer
)

// NamedCtensor pairs a tensor handle with the name it was registered under.
type NamedCtensor struct {
	Name    string
	Ctensor C.tensor
}

// LoadData collects the named tensors that callback_fn appends while a
// module is being loaded.
type LoadData struct {
	NamedCtensors []NamedCtensor
}
var (
	// PStore is the package-wide pointer store. callback_fn resolves the
	// opaque dataPtr it receives from C through PStore.Get.
	PStore = NewPointerStore()
)
// AtNewTensor returns the handle produced by the C call
// `tensor at_new_tensor();`.
func AtNewTensor() Ctensor {
	handle := C.at_new_tensor()
	return handle
}
// NewTensor returns the handle produced by the C call
// `tensor at_new_tensor();`. It is identical to AtNewTensor.
func NewTensor() Ctensor {
	handle := C.at_new_tensor()
	return handle
}
// AtTensorOfData builds a tensor from the raw buffer vs.
//
// C: tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims, size_t element_size_in_bytes, int type);
//
// dims holds the shape and ndims is the number of entries the C side reads
// from it, so ndims must not exceed len(dims). elt_size_in_bytes and kind are
// forwarded unchanged.
func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_bytes uint, kind int) Ctensor {
	// Pass a pointer to the first shape element. An empty shape is passed as
	// NULL instead of panicking on dims[0]; presumably this is only valid
	// together with ndims == 0 (scalar tensor). Confirm against torch_api.
	var cDims *C.int64_t
	if len(dims) > 0 {
		cDims = (*C.int64_t)(unsafe.Pointer(&dims[0]))
	}
	// Value conversions rather than *(*C.T)(unsafe.Pointer(&x)): reading the
	// first 4 bytes of an 8-byte Go int is only correct on little-endian.
	return C.at_tensor_of_data(vs, cDims, C.size_t(ndims), C.size_t(elt_size_in_bytes), C.int(kind))
}
// AtPrint hands t to the C function `void at_print(tensor);`.
func AtPrint(t Ctensor) { C.at_print(t) }
// AtDataPtr returns the void* produced by `void *at_data_ptr(tensor);`.
func AtDataPtr(t Ctensor) unsafe.Pointer {
	raw := C.at_data_ptr(t)
	return raw
}
// AtDim returns the value reported by `size_t at_dim(tensor);` for t
// (the tensor's dimension count), widened to uint64.
func AtDim(t Ctensor) uint64 {
	// A value conversion is portable. The previous *(*uint64) reinterpretation
	// read 8 bytes even where size_t is 4 bytes wide (32-bit targets).
	return uint64(C.at_dim(t))
}
// AtShape asks `void at_shape(tensor, int64_t *);` to write t's shape into
// the buffer at ptr.
//
// ptr must point to writable storage for at least AtDim(t) int64 values;
// presumably the caller sizes it from AtDim. Confirm at call sites.
func AtShape(t Ctensor, ptr unsafe.Pointer) {
	// The C parameter is int64_t*. Using C.int64_t (not C.long) keeps this
	// correct on platforms where int64_t is `long long`.
	C.at_shape(t, (*C.int64_t)(ptr))
}
// AtScalarType returns the integer code produced by
// `int at_scalar_type(tensor);` for t.
func AtScalarType(t Ctensor) int32 {
	return int32(C.at_scalar_type(t))
}
// GetAndResetLastErr returns the char* produced by get_and_reset_last_err.
//
// NOTE(review): ownership of the returned buffer is not visible here.
// Confirm whether callers are expected to free it.
func GetAndResetLastErr() *C.char {
	msg := C.get_and_reset_last_err()
	return msg
}
// AtcCudaDeviceCount returns the result of `int atc_cuda_device_count();`.
func AtcCudaDeviceCount() int {
	// C.int is 4 bytes; the previous *(*int) reinterpretation read 8 bytes,
	// pulling adjacent memory into the upper half of the result.
	return int(C.atc_cuda_device_count())
}
// AtcCudaIsAvailable reports whether `int atc_cuda_is_available();` returned
// a non-zero value.
func AtcCudaIsAvailable() bool {
	// Compare instead of reinterpreting the int's first byte as a Go bool.
	// That reinterpretation yields an invalid bool for values other than 0/1
	// and reads the wrong byte on big-endian targets.
	return C.atc_cuda_is_available() != 0
}
// AtcCudnnIsAvailable reports whether `int atc_cudnn_is_available();`
// returned a non-zero value.
func AtcCudnnIsAvailable() bool {
	// C truthiness: any non-zero int is true. Avoids forging a Go bool from
	// the int's first byte.
	return C.atc_cudnn_is_available() != 0
}
// AtcSetBenchmarkCudnn forwards b to `void atc_set_benchmark_cudnn(int b);`.
func AtcSetBenchmarkCudnn(b int) {
	// Value conversion; reinterpreting the first 4 bytes of an 8-byte Go int
	// gives 0 on big-endian targets.
	C.atc_set_benchmark_cudnn(C.int(b))
}
// AtDoubleValueAtIndexes returns the element of ts at the given indexes as a
// float64.
//
// C: double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
//
// indexes must point to indexesLen int64 values.
func AtDoubleValueAtIndexes(ts Ctensor, indexes unsafe.Pointer, indexesLen int) float64 {
	// int64_t* (not long*) keeps this portable where int64_t is long long.
	cindexes := (*C.int64_t)(indexes)
	retVal := C.at_double_value_at_indexes(ts, cindexes, C.int(indexesLen))
	return float64(retVal)
}
// AtInt64ValueAtIndexes returns the element of ts at the given indexes as an
// int64.
//
// C: int64_t at_int64_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
//
// indexes must point to indexesLen int64 values.
func AtInt64ValueAtIndexes(ts Ctensor, indexes unsafe.Pointer, indexesLen int) int64 {
	// int64_t* (not long*) keeps this portable where int64_t is long long.
	cindexes := (*C.int64_t)(indexes)
	retVal := C.at_int64_value_at_indexes(ts, cindexes, C.int(indexesLen))
	return int64(retVal)
}
// AtRequiresGrad reports whether `int at_requires_grad(tensor);` returned a
// non-zero value for ts.
func AtRequiresGrad(ts Ctensor) bool {
	// Compare rather than reinterpret the C int's first byte as a Go bool.
	return C.at_requires_grad(ts) != 0
}
// AtDefined reports whether `int at_defined(tensor);` returned a non-zero
// value for ts.
func AtDefined(ts Ctensor) bool {
	// Compare rather than reinterpret the C int's first byte as a Go bool.
	return C.at_defined(ts) != 0
}
// AtIsSparse reports whether `int at_is_sparse(tensor);` returned a non-zero
// value for ts.
func AtIsSparse(ts Ctensor) bool {
	// Compare rather than reinterpret the C int's first byte as a Go bool.
	return C.at_is_sparse(ts) != 0
}
// AtBackward calls `void at_backward(tensor, int, int);` on ts, forwarding
// keepGraph and createGraph as C ints.
func AtBackward(ts Ctensor, keepGraph int, createGraph int) {
	// Value conversions; byte reinterpretation of a Go int is
	// endianness-dependent.
	C.at_backward(ts, C.int(keepGraph), C.int(createGraph))
}
// AtRunBackward forwards its arguments to at_run_backward.
//
// C:
//
//	void at_run_backward(tensor *tensors,
//	                     int ntensors,
//	                     tensor *inputs,
//	                     int ninputs,
//	                     tensor *outputs,
//	                     int keep_graph,
//	                     int create_graph);
//
// tensorsPtr must point to ntensors handles and inputsPtr to ninputs
// handles. outputsPtr is passed through; presumably it must have room for
// ninputs results. Confirm against torch_api.
func AtRunBackward(tensorsPtr *Ctensor, ntensors int, inputsPtr *Ctensor, ninputs int, outputsPtr *Ctensor, keepGraph int, createGraph int) {
	// Value conversions instead of reinterpreting 8-byte Go ints as 4-byte
	// C ints (which is only correct on little-endian).
	C.at_run_backward(
		tensorsPtr, C.int(ntensors),
		inputsPtr, C.int(ninputs),
		outputsPtr,
		C.int(keepGraph), C.int(createGraph),
	)
}
// AtCopyData calls
// `void at_copy_data(tensor tensor, void *vs, size_t numel, size_t element_size_in_bytes);`
// to copy numel elements of element_size_in_bytes bytes each between tensor
// and the buffer vs.
func AtCopyData(tensor Ctensor, vs unsafe.Pointer, numel uint, element_size_in_bytes uint) {
	count, width := C.size_t(numel), C.size_t(element_size_in_bytes)
	C.at_copy_data(tensor, vs, count, width)
}
// AtShallowClone returns the handle produced by
// `tensor at_shallow_clone(tensor);` for ts.
func AtShallowClone(ts Ctensor) Ctensor {
	clone := C.at_shallow_clone(ts)
	return clone
}
// AtGet returns the handle produced by `tensor at_get(tensor, int index);`
// for ts at position index.
func AtGet(ts Ctensor, index int) Ctensor {
	// Value conversion; reinterpreting the Go int's first 4 bytes is wrong
	// on big-endian targets.
	return C.at_get(ts, C.int(index))
}
// AtCopy_ calls `void at_copy_(tensor dst, tensor src);`, copying src into
// dst in place.
func AtCopy_(dst Ctensor, src Ctensor) { C.at_copy_(dst, src) }
// AtSave calls `void at_save(tensor, char *filename);` with ts and path.
// The temporary C copy of path is released before returning.
func AtSave(ts Ctensor, path string) {
	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath))

	C.at_save(ts, cPath)
}
// AtLoad returns the handle produced by `tensor at_load(char *filename);`
// for path. The temporary C copy of path is released before returning.
func AtLoad(path string) Ctensor {
	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath))

	loaded := C.at_load(cPath)
	return loaded
}
// AtSaveMulti writes several named tensors to filename.
//
// C: void at_save_multi(tensor *tensors, char **tensor_names, int ntensors, char *filename);
//
// The C side reads ntensors entries from both arrays, so ntensors is clamped
// to the shorter of the two slices (and to zero) to avoid reading past
// either one.
func AtSaveMulti(tensors []Ctensor, tensor_names []string, ntensors int, filename string) {
	n := ntensors
	if n > len(tensors) {
		n = len(tensors)
	}
	if n > len(tensor_names) {
		n = len(tensor_names)
	}
	if n < 0 {
		n = 0
	}

	// Build a char* array. The Go slice holds only C pointers (from
	// C.CString), which cgo allows passing to C. Every string is freed once
	// the call returns.
	cnames := make([]*C.char, n)
	for i := range cnames {
		cnames[i] = C.CString(tensor_names[i])
	}
	defer func() {
		for _, p := range cnames {
			C.free(unsafe.Pointer(p))
		}
	}()

	cfilename := C.CString(filename)
	defer C.free(unsafe.Pointer(cfilename))

	// Ctensor aliases C.tensor, so the caller's slice can be passed directly.
	// Pass NULL for empty inputs instead of panicking on [0].
	var ctensorsPtr *C.tensor
	var cnamesPtr **C.char
	if n > 0 {
		ctensorsPtr = &tensors[0]
		cnamesPtr = &cnames[0]
	}
	C.at_save_multi(ctensorsPtr, cnamesPtr, C.int(n), cfilename)
}
// AtLoadMulti loads the tensors named in tensor_names from filename into
// tensors.
//
// C: void at_load_multi(tensor *tensors, char **tensor_names, int ntensors, char *filename);
//
// [at_load_multi] takes as input an array of nullptr for [tensors]: pass a
// slice of zero-valued handles; the C side writes the loaded handles into it
// (the caller's backing array is shared). ntensors is clamped to the shorter
// of the two slices (and to zero) so C never indexes past either one.
func AtLoadMulti(tensors []Ctensor, tensor_names []string, ntensors int, filename string) {
	n := ntensors
	if n > len(tensors) {
		n = len(tensors)
	}
	if n > len(tensor_names) {
		n = len(tensor_names)
	}
	if n < 0 {
		n = 0
	}

	// The Go slice holds only C pointers, which cgo allows passing to C.
	cnames := make([]*C.char, n)
	for i := range cnames {
		cnames[i] = C.CString(tensor_names[i])
	}
	defer func() {
		for _, p := range cnames {
			C.free(unsafe.Pointer(p))
		}
	}()

	cfilename := C.CString(filename)
	defer C.free(unsafe.Pointer(cfilename))

	var ctensorsPtr *C.tensor
	var cnamesPtr **C.char
	if n > 0 {
		ctensorsPtr = &tensors[0]
		cnamesPtr = &cnames[0]
	}
	C.at_load_multi(ctensorsPtr, cnamesPtr, C.int(n), cfilename)
}
// AtLoadCallback loads the module in filename through at_load_callback,
// using callback_fn as the per-parameter callback and passing dataPtr
// through to it untouched.
//
// C: void at_load_callback(char *filename, void *data, void (*f)(void *, char *, tensor));
//
// Per its C++ implementation, the module is loaded with torch::jit::load and
// f is invoked as f(data, name, new torch::Tensor(value)) for each of
// module.named_parameters().
func AtLoadCallback(filename string, dataPtr unsafe.Pointer) {
	cPath := C.CString(filename)
	defer C.free(unsafe.Pointer(cPath))

	onParam := C.f(C.callback_fn)
	C.at_load_callback(cPath, dataPtr, onParam)
}
// callback_fn is the C-callable callback passed to at_load_callback and
// at_load_callback_with_device. For each (name, tensor) it is handed, it
// resolves dataPtr through PStore, asserts the stored value is a *LoadData
// (the assertion panics otherwise), and appends the pair to NamedCtensors.
//
// TODO: move `callback_fn` to wrapper package???
//
//export callback_fn
func callback_fn(dataPtr unsafe.Pointer, name *C.char, ctensor C.tensor) {
	entry := NamedCtensor{Name: C.GoString(name), Ctensor: ctensor}
	sink := PStore.Get(dataPtr).(*LoadData)
	sink.NamedCtensors = append(sink.NamedCtensors, entry)
}
// AtLoadCallbackWithDevice is AtLoadCallback with a device id: it calls
// at_load_callback_with_device with callback_fn as the per-parameter
// callback, passing dataPtr through untouched.
//
// C: void at_load_callback_with_device(char *filename, void *data, void (*f)(void *, char *, tensor), int device_id);
//
// Per its C++ implementation, the module is loaded with
// torch::jit::load(filename, device_of_int(device_id)) and f is invoked once
// per named parameter.
func AtLoadCallbackWithDevice(filename string, dataPtr unsafe.Pointer, device int32) {
	cPath := C.CString(filename)
	defer C.free(unsafe.Pointer(cPath))

	onParam := C.f(C.callback_fn)
	C.at_load_callback_with_device(cPath, dataPtr, onParam, C.int(device))
}
// AtToString renders ts as text via at_to_string, wrapping at lineSize.
//
// C (implementation):
//
//	char *at_to_string(tensor t, int line_size) {
//	  PROTECT(
//	    std::ostringstream oss;
//	    torch::print(oss, *t, line_size);
//	    return strdup(oss.str().c_str());
//	  )
//	  return nullptr;
//	}
//
// lineSize is converted to a C int and is truncated if it does not fit.
// On the nullptr error path this returns "".
func AtToString(ts Ctensor, lineSize int64) string {
	charPtr := C.at_to_string(ts, C.int(lineSize))
	// The buffer comes from strdup, so it is owned by us and must be freed
	// after copying into Go memory; previously it leaked on every call.
	// free(NULL) is a no-op, so the nullptr error path is safe.
	defer C.free(unsafe.Pointer(charPtr))
	return C.GoString(charPtr)
}
// AtFree releases ts through `void at_free(tensor);`. ts must not be used
// afterwards.
func AtFree(ts Ctensor) { C.at_free(ts) }
// AtGradSetEnabled forwards b to `int at_grad_set_enabled(int b);` and
// returns the C function's int result.
func AtGradSetEnabled(b int) int {
	// Value conversions both ways. The previous *(*int) reinterpretation
	// read 8 bytes from a 4-byte C.int, producing garbage upper bits.
	return int(C.at_grad_set_enabled(C.int(b)))
}
// AtoAdam returns the optimizer handle produced by
//
//	optimizer ato_adam(double learning_rate,
//	                   double beta1,
//	                   double beta2,
//	                   double weight_decay);
func AtoAdam(learningRate, beta1, beta2, weightDecay float64) Coptimizer {
	return C.ato_adam(
		C.double(learningRate),
		C.double(beta1),
		C.double(beta2),
		C.double(weightDecay),
	)
}
// AtoRmsProp returns the optimizer handle produced by
//
//	optimizer ato_rms_prop(double learning_rate,
//	                       double alpha,
//	                       double eps,
//	                       double weight_decay,
//	                       double momentum,
//	                       int centered);
//
// centered is forwarded as a C int.
func AtoRmsProp(learningRate, alpha, eps, weightDecay, momentum float64, centered int) Coptimizer {
	// C.int(centered) converts by value; reinterpreting the first 4 bytes of
	// a Go int is only correct on little-endian targets.
	return C.ato_rms_prop(
		C.double(learningRate),
		C.double(alpha),
		C.double(eps),
		C.double(weightDecay),
		C.double(momentum),
		C.int(centered),
	)
}
// AtoSgd returns the optimizer handle produced by
//
//	optimizer ato_sgd(double learning_rate,
//	                  double momentum,
//	                  double dampening,
//	                  double weight_decay,
//	                  int nesterov);
//
// nesterov is forwarded as a C int.
func AtoSgd(learningRate, momentum, dampening, weightDecay float64, nesterov int) Coptimizer {
	// C.int(nesterov) converts by value; reinterpreting the first 4 bytes of
	// a Go int is only correct on little-endian targets.
	return C.ato_sgd(
		C.double(learningRate),
		C.double(momentum),
		C.double(dampening),
		C.double(weightDecay),
		C.int(nesterov),
	)
}
// AtoAddParameters registers tensors with coptimizer via
// `void ato_add_parameters(optimizer, tensor *, int ntensors);`.
//
// ntensors is clamped to [0, len(tensors)] so the C side never reads past
// the slice.
func AtoAddParameters(coptimizer Coptimizer, tensors []Ctensor, ntensors int) {
	n := ntensors
	if n > len(tensors) {
		n = len(tensors)
	}
	if n < 0 {
		n = 0
	}
	// Ctensor aliases C.tensor, so the caller's slice can be passed directly.
	// Use NULL for an empty slice instead of panicking on tensors[0].
	var ctensorsPtr *C.tensor
	if n > 0 {
		ctensorsPtr = &tensors[0]
	}
	C.ato_add_parameters(coptimizer, ctensorsPtr, C.int(n))
}
// AtoSetLearningRate calls
// `void ato_set_learning_rate(optimizer, double learning_rate);`.
func AtoSetLearningRate(coptimizer Coptimizer, learningRate float64) {
	lr := C.double(learningRate)
	C.ato_set_learning_rate(coptimizer, lr)
}
// AtoSetMomentum calls `void ato_set_momentum(optimizer, double momentum);`.
func AtoSetMomentum(coptimizer Coptimizer, momentum float64) {
	m := C.double(momentum)
	C.ato_set_momentum(coptimizer, m)
}
// AtoZeroGrad calls `void ato_zero_grad(optimizer);` on coptimizer.
func AtoZeroGrad(coptimizer Coptimizer) { C.ato_zero_grad(coptimizer) }
// AtoStep calls `void ato_step(optimizer);` on coptimizer.
func AtoStep(coptimizer Coptimizer) { C.ato_step(coptimizer) }
// AtoFree releases coptimizer through `void ato_free(optimizer);`.
// coptimizer must not be used afterwards.
func AtoFree(coptimizer Coptimizer) { C.ato_free(coptimizer) }
// AtLoadImage returns the handle produced by
// `tensor at_load_image(char *filename);` for path. The temporary C copy of
// path is released before returning.
func AtLoadImage(path string) Ctensor {
	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath))

	img := C.at_load_image(cPath)
	return img
}
// AtSaveImage calls `int at_save_image(tensor, char *filename);` with ts and
// path.
//
// TODO: the int status returned by C is discarded. Errors are presumably
// surfaced through the last-error mechanism (GetAndResetLastErr / TorchErr);
// confirm.
func AtSaveImage(ts Ctensor, path string) {
	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath))

	_ = C.at_save_image(ts, cPath)
}
// AtResizeImage returns the handle produced by
// `tensor at_resize_image(tensor, int w, int h);` for ts.
//
// w and h are converted to C ints and are truncated if they do not fit.
func AtResizeImage(ts Ctensor, w, h int64) Ctensor {
	// Value conversions; reinterpreting the first 4 bytes of an int64 yields
	// the high half (usually 0) on big-endian targets.
	return C.at_resize_image(ts, C.int(w), C.int(h))
}