Goncalves Henriques, Andre (UG - Computer Science) 9257404edd Move the name of the module
2024-04-21 15:15:00 +01:00

1557 lines
36 KiB
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package ts
//#include "stdlib.h"
//#include "stdbool.h"
import "C"
import (
	"bytes"
	"encoding/binary"
	"fmt"
	"log"
	"reflect"
	"runtime"
	"sync"
	"sync/atomic"
	"time"
	"unsafe"

	// NOTE(review): the import paths of the two project packages below were
	// lost (left as empty strings). Restore the module paths of the gotch root
	// package and of its libtorch binding package before building.
	gotch ""
	lib ""
)
// Package-level bookkeeping for tensors and scalars created by this package.
var (
	TensorCount  int64 // incremental counting created tensors
	ScalarCount  int64 // incremental counting created scalars
	AllocatedMem int64 // bytes - keeping track of memory created and still occupied by gotch/tensor (excluding mem allocated by libtorch at C side)

	ExistingTensors = make(map[string]struct{}) // keep track of existing tensors by name
	ExistingScalars = make(map[string]struct{}) // keep track of existing scalar by name

	// lock serializes access to the name registries above.
	lock sync.Mutex
)
// None is a tensor created by NewTensor with no data attached; the original
// note describes it as an undefined tensor.
// It can be passed for optional tensor parameters where a 'None' value is
// expected. Use `ts.MustDefined()` to check for this 'null' state.
var None = NewTensor()
// bigStruct is a 100,000-byte filler attached to every Tensor (see the
// Tensor.d field).
type bigStruct struct {
	lots [1e5]byte // 100k - always on host memory.
}
// Tensor is a Go wrapper of a "C tensor pointer" - 8 Bytes (64-bits OS)
// or 4 Bytes (32-bits OS).
// `ctensor` is just a "C pointer" to `torch::Tensor` (torch::Tensor *lib.Ctensor)
// NOTE.Tensor should be big enough to be in heap memory.
// (yes, we choose to place tensor consistently in heap memory so that
// it can be targeted by Go garbage collector).
// For heap allocation see.
type Tensor struct {
d *bigStruct
name string
ctensor lib.Ctensor
calledFrom string
func newTensor(ctensor lib.Ctensor, nameOpt ...string) *Tensor {
if len(nameOpt) == 0 {
nameOpt = []string{}
x := new(Tensor)
x.ctensor = ctensor
x.d = new(bigStruct)
atomic.AddInt64(&TensorCount, 1)
if gotch.Debug {
nbytes := x.nbytes()
atomic.AddInt64(&AllocatedMem, nbytes)
log.Printf("INFO: Added tensor %q - Allocated memory: %d bytes.\n",, nbytes)
name := newName(nameOpt...)
if _, ok := ExistingTensors[name]; ok {
name = fmt.Sprintf("%s_%09d", name, TensorCount)
ExistingTensors[name] = struct{}{}
lock.Unlock() = name
x.calledFrom = "newTensor()"
runtime.SetFinalizer(x, freeCTensor)
return x
// New creates new tensor from C tensor.
func New(ctensor lib.Ctensor, nameOpt ...string) *Tensor {
return newTensor(ctensor, nameOpt...)
func CheckCMemLeak() string {
tensors := []string{}
for n := range ExistingTensors {
tensors = append(tensors, n)
memUsed := AllocatedMem
var msg string
msg += fmt.Sprintf("============================= C MEMORY CHECK RESULT ==================================\n")
msg += fmt.Sprintf("C memory allocated not been released: %v bytes\n", memUsed)
msg += fmt.Sprintf("Tensors not been released: %q\n", tensors)
msg += fmt.Sprintf("======================================================================================\n")
return msg
// CleanUp calls runtime.GC() twice with a sleep in between.
//
// The optional first element of sleepTimeOpt is the sleep time in
// milliseconds; it defaults to 1000 (1 second).
func CleanUp(sleepTimeOpt ...int) {
	sleepTime := time.Duration(1000) // 1 second
	if len(sleepTimeOpt) > 0 {
		sleepTime = time.Duration(sleepTimeOpt[0])
	}

	runtime.GC()
	time.Sleep(time.Millisecond * sleepTime)
	runtime.GC()
}
// Ctensor return C pointer value.
func (ts *Tensor) Ctensor() unsafe.Pointer {
return unsafe.Pointer(ts.ctensor)
// free releases C allocated memory.
func freeCTensor(ts *Tensor) error {
if ts == nil || ts.ctensor == nil {
return nil
defer lock.Unlock()
if _, ok := ExistingTensors[]; !ok {
log.Printf("WARNING: Probably double free tensor %q. Called from %q. Just skipping...\n",, ts.calledFrom)
return nil
if gotch.Debug {
shape, err := ts.Size()
if err != nil {
err = fmt.Errorf("ERROR: failed to release tensor %q: %w\n",, err)
numel := uint(FlattenDim(shape))
dtype := ts.DType()
nbytes := int64(numel * dtype.Size())
atomic.AddInt64(&AllocatedMem, -nbytes)
log.Printf("INFO: Released tensor %q - C memory(%d bytes).\n",, nbytes)
if err := TorchErr(); err != nil {
err := fmt.Errorf("ERROR: failed to release tensor %q - %w",, err)
return err
// IMPORTANT. make it nil so won't double free.
ts.ctensor = nil
// Clear SetFinalizer on ts so no double free tensor.
// Ref.
runtime.SetFinalizer(ts, nil)
return nil
func newName(nameOpt ...string) string {
var name string
if len(nameOpt) > 0 {
name = nameOpt[0]
} else {
name = fmt.Sprintf("tensor_%09d", TensorCount)
return name
// NewTensor creates a new tensor
func NewTensor(nameOpt ...string) *Tensor {
ctensor := lib.AtNewTensor()
return newTensor(ctensor, nameOpt...)
func FromCtensor(ctensor unsafe.Pointer) *Tensor {
cts := (lib.Ctensor)(ctensor)
return newTensor(cts)
func (ts *Tensor) Name() string {
func (ts *Tensor) Dim() uint64 {
dim := lib.AtDim(ts.ctensor)
if err := TorchErr(); err != nil {
return dim
// Size return shape of the tensor
// NOTE: C++ libtorch calls at_shape() -> t.sizes()
// And returns a slice of sizes or shape using given pointer
// to that slice.
func (ts *Tensor) Size() ([]int64, error) {
dim := lib.AtDim(ts.ctensor)
if dim < 0 || dim > 100 {
err := fmt.Errorf("Invalid dim: %v\n", dim)
return nil, err
sz := make([]int64, dim)
szPtr, err := DataAsPtr(sz)
if err != nil {
return nil, err
lib.AtShape(ts.ctensor, szPtr)
if err = TorchErr(); err != nil {
return nil, err
shape := decodeSize(szPtr, dim)
return shape, nil
func (ts *Tensor) MustSize() []int64 {
shape, err := ts.Size()
if err != nil {
return shape
func (ts *Tensor) Stride() ([]int64, error) {
dim := lib.AtDim(ts.ctensor)
sz := make([]int64, dim)
szPtr, err := DataAsPtr(sz)
if err != nil {
return nil, err
lib.AtStride(ts.ctensor, szPtr)
if err = TorchErr(); err != nil {
return nil, err
strides := decodeSize(szPtr, dim)
return strides, nil
func (ts *Tensor) MustStride() []int64 {
strides, err := ts.Stride()
if err != nil {
return strides
// Size1 returns the tensor size for 1D tensors.
func (ts *Tensor) Size1() (int64, error) {
shape, err := ts.Size()
if err != nil {
return 0, err
if len(shape) != 1 {
err = fmt.Errorf("Expected one dim, got %v\n", len(shape))
return 0, err
return shape[0], nil
// Size2 returns the tensor size for 2D tensors.
func (ts *Tensor) Size2() ([]int64, error) {
shape, err := ts.Size()
if err != nil {
return nil, err
if len(shape) != 2 {
err = fmt.Errorf("Expected two dims, got %v\n", len(shape))
return nil, err
return shape, nil
// Size3 returns the tensor size for 3D tensors.
func (ts *Tensor) Size3() ([]int64, error) {
shape, err := ts.Size()
if err != nil {
return nil, err
if len(shape) != 3 {
err = fmt.Errorf("Expected three dims, got %v\n", len(shape))
return nil, err
return shape, nil
// Size4 returns the tensor size for 4D tensors.
func (ts *Tensor) Size4() ([]int64, error) {
shape, err := ts.Size()
if err != nil {
return nil, err
if len(shape) != 4 {
err = fmt.Errorf("Expected four dims, got %v\n", len(shape))
return nil, err
return shape, nil
// nbytes calculates tensor data size in bytes.
func (ts *Tensor) nbytes() int64 {
numel := ts.Numel()
if numel == 0 {
return 0 // ts.None
return int64(numel * ts.DType().Size())
func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
dtype := gotch.Int64 // tensor size dtype = int64
nbytes := int(nsize) * int(dtype.Size())
dataSlice := (*[1 << 30]byte)(ptr)[:nbytes:nbytes]
r := bytes.NewReader(dataSlice)
dataIn := make([]int64, nsize)
if err := binary.Read(r, nativeEndian, dataIn); err != nil {
return dataIn
// TensorOptions constructs options to build/rebuild tensor.
type TensorOptions struct {
Name string
DType gotch.DType
Quantized bool
// TODO. can expand as needed
// TensorOpt is a functional option that mutates a TensorOptions value.
type TensorOpt func(*TensorOptions)
func DefaultTensorOptions() *TensorOptions {
return &TensorOptions{
Name: "",
DType: gotch.Float,
Quantized: false,
func WithName(v string) TensorOpt {
return func(o *TensorOptions) {
o.Name = v
func WithDType(v gotch.DType) TensorOpt {
return func(o *TensorOptions) {
o.DType = v
func WithQuantized(v bool) TensorOpt {
return func(o *TensorOptions) {
o.Quantized = v
// OfSlice creates tensor from a slice data
func OfSlice(data interface{}, opts ...TensorOpt) (*Tensor, error) {
o := DefaultTensorOptions()
for _, opt := range opts {
// convert []int -> int32. `binary.Write()` can't write `[]int` because it's not fixed-size!
if reflect.TypeOf(data).String() == "[]int" {
data = sliceIntToInt32(data.([]int))
v := reflect.ValueOf(data)
kind := v.Kind().String()
if kind != "slice" && kind != "array" {
err := fmt.Errorf("Expected slice data. Got %q", kind)
return nil, err
elementKind := reflect.TypeOf(data).Elem().Kind()
dataLen := v.Len()
dtype, err := gotch.GoKind2DType(elementKind, gotch.HalfDTypePref(o.DType), gotch.WithQuantized(o.Quantized))
if err != nil {
return nil, err
shape := []int64{int64(dataLen)}
elementNum := ElementCount(shape)
nbytes := int(dtype.Size()) * elementNum
dataPtr, buff := CMalloc(nbytes)
if err = EncodeTensor(buff, reflect.ValueOf(data), shape); err != nil {
return nil, err
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(dtype.Size()), int(dtype.CKind()))
if err = TorchErr(); err != nil {
return nil, err
return newTensor(ctensor, o.Name), nil
// return newTensor(ctensor), nil
// OfDataSize creates Tensor from input byte data, shape and dtype.
func OfDataSize(data []byte, shape []int64, dtype gotch.DType, opts ...TensorOpt) (*Tensor, error) {
o := DefaultTensorOptions()
for _, opt := range opts {
elementNum := ElementCount(shape)
nbytes := elementNum * int(dtype.Size())
if nbytes != len(data) {
err := fmt.Errorf("data and shape mismatched for dtype (%v): byte data (%v) - shape (%v).\n", dtype, len(data), shape)
return nil, err
dataPtr, buff := CMalloc(nbytes)
if err := binary.Write(buff, nativeEndian, data); err != nil {
return nil, err
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), dtype.Size(), int(dtype.CKind()))
if err := TorchErr(); err != nil {
return nil, err
return newTensor(ctensor, o.Name), nil
// return newTensor(ctensor), nil
// MustOfDataSize create Tensor from input byte data and specified shape and dtype
// or panic if error
func MustOfDataSize(data []byte, size []int64, dtype gotch.DType, opts ...TensorOpt) *Tensor {
ts, err := OfDataSize(data, size, dtype, opts...)
if err != nil {
return ts
// MustOfSlice create a tensor from slice of data. It will be panic if error.
func MustOfSlice(data interface{}, opts ...TensorOpt) *Tensor {
ts, err := OfSlice(data, opts...)
if err != nil {
return ts
// TensorFrom create a tensor from slice of data. It will be panic if error.
func TensorFrom(data interface{}, opts ...TensorOpt) *Tensor {
ts, err := OfSlice(data, opts...)
if err != nil {
return ts
// Print prints tensor values to console.
// NOTE: it is printed from C and will print ALL elements of tensor
// with no truncation at all.
func (ts *Tensor) Print() {
if err := TorchErr(); err != nil {
// NewTensorFromData creates tensor from given data and shape
func NewTensorFromData(data interface{}, shape []int64, opts ...TensorOpt) (*Tensor, error) {
o := DefaultTensorOptions()
for _, opt := range opts {
// 1. Check whether data and shape match
elementNum, err := DataDim(data)
if err != nil {
return nil, err
nflattend := FlattenDim(shape)
if elementNum != nflattend {
err = fmt.Errorf("Number of data elements (%v) and flatten shape (%v) dimension mismatched.\n", elementNum, nflattend)
return nil, err
// 2. Write raw data to C memory and get C pointer
dataPtr, err := DataAsPtr(data)
if err != nil {
return nil, err
// 3. Create tensor with pointer and shape
dtype, err := gotch.DTypeFromData(data)
if err != nil {
return nil, err
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), dtype.Size(), int(dtype.CKind()))
if err = TorchErr(); err != nil {
return nil, err
return newTensor(ctensor, o.Name), nil
func (ts *Tensor) DType() gotch.DType {
cint := lib.AtScalarType(ts.ctensor)
return gotch.CKind2DType(cint)
func (ts *Tensor) Device() (gotch.Device, error) {
var (
retVal gotch.Device
err error
cInt := lib.AtDevice(ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
var device gotch.Device
return device.OfCInt(int32(cInt)), nil
func (ts *Tensor) MustDevice() gotch.Device {
device, err := ts.Device()
if err != nil {
return device
/*
* func (ts Tensor) Eq1(other Tensor, del bool) (retVal Tensor, err error) {
* // Get a C null pointer
* //
* ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
* if del {
* defer ts.MustDrop()
* }
* lib.AtgEq1(ptr, ts.ctensor, other.ctensor)
* if err = TorchErr(); err != nil {
* return retVal, err
* }
* return Tensor{ctensor: *ptr}, nil
* }
* func (ts Tensor) MustEq1(other Tensor, del bool) (retVal Tensor) {
* retVal, err := ts.Eq1(other, del)
* if err != nil {
* log.Fatal(err)
* }
* return retVal
* }
* */
// Float64Value returns a float value on tensors holding a single element.
// An error is returned otherwise.
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
func (ts *Tensor) Float64Value(idx []int64) (float64, error) {
idxPtr, err := DataAsPtr(idx)
if err != nil {
return 0, err
f64Val := lib.AtDoubleValueAtIndexes(ts.ctensor, idxPtr, len(idx))
if err = TorchErr(); err != nil {
return 0, err
return f64Val, err
func (ts *Tensor) MustFloat64Value(idx []int64) float64 {
f64Val, err := ts.Float64Value(idx)
if err != nil {
return f64Val
// Int64Value returns an int value on tensors holding a single element. An error is
// returned otherwise.
func (ts *Tensor) Int64Value(idx []int64) (int64, error) {
var (
retVal int64
err error
idxPtr, err := DataAsPtr(idx)
if err != nil {
return retVal, err
int64Val := lib.AtInt64ValueAtIndexes(ts.ctensor, idxPtr, len(idx))
if err = TorchErr(); err != nil {
return 0, err
return int64Val, err
func (ts *Tensor) MustInt64Value(idx []int64) int64 {
int64Val, err := ts.Int64Value(idx)
if err != nil {
return int64Val
// RequiresGrad returns true if gradient are currently tracked for this tensor.
func (ts *Tensor) RequiresGrad() (bool, error) {
state := lib.AtRequiresGrad(ts.ctensor)
if err := TorchErr(); err != nil {
return false, err
return state, nil
func (ts *Tensor) MustRequiresGrad() bool {
state, err := ts.RequiresGrad()
if err != nil {
return state
// DataPtr returns the address of the first element of this tensor.
func (ts *Tensor) DataPtr() (unsafe.Pointer, error) {
datPtr := lib.AtDataPtr(ts.ctensor)
if err := TorchErr(); err != nil {
return nil, err
return datPtr, nil
func (ts *Tensor) MustDataPtr() unsafe.Pointer {
p, err := ts.DataPtr()
if err != nil {
return p
// Defined returns true is the tensor is defined.
func (ts *Tensor) Defined() (bool, error) {
state := lib.AtDefined(ts.ctensor)
if err := TorchErr(); err != nil {
return false, err
return state, nil
func (ts *Tensor) MustDefined() bool {
state, err := ts.Defined()
if err != nil {
return state
// IsSparse returns true is the tensor is spare.
func (ts *Tensor) IsSparse() (bool, error) {
state := lib.AtIsSparse(ts.ctensor)
if err := TorchErr(); err != nil {
return false, err
return state, nil
func (ts *Tensor) MustIsSparse() bool {
state, err := ts.IsSparse()
if err != nil {
return state
// IsContiguous returns true is the tensor is contiguous.
func (ts *Tensor) IsContiguous() (bool, error) {
state := lib.AtIsContiguous(ts.ctensor)
if err := TorchErr(); err != nil {
return false, err
return state, nil
func (ts *Tensor) MustIsContiguous() bool {
state, err := ts.IsContiguous()
if err != nil {
return state
// IsMkldnn returns true is the tensor is mkldnn.
func (ts *Tensor) IsMkldnn() (bool, error) {
state := lib.AtIsMkldnn(ts.ctensor)
if err := TorchErr(); err != nil {
return false, err
return state, nil
func (ts *Tensor) MustIsMkldnn() bool {
state, err := ts.IsMkldnn()
if err != nil {
return state
// ZeroGrad zeroes the gradient tensor attached to this tensor if defined.
func (ts *Tensor) ZeroGrad() {
grad := ts.MustGrad(false)
if grad.MustDefined() {
// Backward runs the backward pass, populating the gradient tensors for tensors
// which gradients are tracked.
// Gradients tracking can be turned on via `SetRequiresGrad`.
func (ts *Tensor) Backward() error {
lib.AtBackward(ts.ctensor, 0, 0)
if err := TorchErr(); err != nil {
return err
return nil
func (ts *Tensor) MustBackward() {
if err := ts.Backward(); err != nil {
// RunBackward runs the backward ...
func RunBackward(tensors []*Tensor, inputs []*Tensor, keepGraphB bool, createGraphB bool) ([]*Tensor, error) {
// NOTE: outputs is a slice of tensors with length = len(inputs)
var outputsPtr []*lib.Ctensor
// Are they allocated contigously??? Definitely not.
// TODO. calculate C memory size = C pointer size x n pointers
// Then C.malloc such calculated amount
// NOTE. replace with the following code and test.
* ntensors := len(inputs)
* nbytes := C.size_t(ntensors) * C.size_t(unsafe.Sizeof(uintptr(0)))
* ctensorsPtr := (*[1 << 30]lib.Ctensor)(C.malloc(nbytes))
* for i :=0; i < ntensors; i++ {
* outputPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
* outputsPtr[i] = outputPtr
* }
* */
for i := 0; i < len(inputs); i++ {
outputPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
outputsPtr = append(outputsPtr, outputPtr)
// Get first element pointer
ctensor := tensors[0].ctensor
cinput := inputs[0].ctensor
tensorsPtr := (*lib.Ctensor)(unsafe.Pointer(&ctensor))
inputsPtr := (*lib.Ctensor)(unsafe.Pointer(&cinput))
var keepGraph int = 0
if keepGraphB {
keepGraph = 1
var createGraph int = 0
if createGraphB {
createGraph = 1
lib.AtRunBackward(tensorsPtr, len(tensors), inputsPtr, len(inputs), outputsPtr[0], keepGraph, createGraph)
if err := TorchErr(); err != nil {
return nil, err
var oTensors []*Tensor
for i := 0; i < len(inputs); i++ {
outputPtr := outputsPtr[i]
oTensors = append(oTensors, newTensor(*outputPtr))
return oTensors, nil
// CopyDataUint8 copies `numel` elements from `self` to `dst`.
// NOTE: `dst` located in Go memory. Should it be?
func (ts *Tensor) CopyDataUint8(dst []uint8, numel uint) error {
// NOTE: we must make sure that `dst` has same len as `numel`. Otherwise,
// there will be memory leak and or out of range error.
if len(dst) < int(numel) {
err := fmt.Errorf("CopyDataUint8 Error: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", len(dst), numel)
return err
vs := unsafe.Pointer(&dst[0])
dtype := gotch.Uint8
lib.AtCopyData(ts.ctensor, vs, numel, dtype.Size())
if err := TorchErr(); err != nil {
return err
return nil
func (ts *Tensor) MustCopyDataUint8(dst []uint8, numel uint) {
err := ts.CopyDataUint8(dst, numel)
if err != nil {
// CopyData copies `numel` elements from `self` to `dst`.
// `dst` should be a slice of Go type equivalent to tensor type.
// NOTE: `dst` located in Go memory. Should it be?
// We will render Go pointer of first element of `dst` slice
// and number of elements to C land. This may break in the future
// if Go policy changes.
func (ts *Tensor) CopyData(dst interface{}, numel uint) error {
dtype, dlen, err := DataCheck(dst)
if dlen < int(numel) {
err = fmt.Errorf("ts.CopyData() failed: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", dlen, numel)
return err
if ts.DType() != dtype {
err = fmt.Errorf("ts.CopyData() failed: Type mismatched: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
return err
// Get data pointer
dataPtr := reflect.ValueOf(dst).UnsafePointer()
lib.AtCopyData(ts.ctensor, dataPtr, numel, dtype.Size())
if err = TorchErr(); err != nil {
return err
return nil
// MustCopyData copies number of elements from tensor to a slice of data
// NOTE: `dst` is a slice with length = numel and Go type equavalent to tensor
// DType
func (ts *Tensor) MustCopyData(dst interface{}, numel uint) {
err := ts.CopyData(dst, numel)
if err != nil {
// Numel returns the total number of elements stored in a tensor.
func (ts *Tensor) Numel() uint {
if !ts.MustDefined() {
return 0 // ts.None case
shape := ts.MustSize()
return uint(FlattenDim(shape))
// ShallowClone returns a new tensor that share storage with the input tensor.
func (ts *Tensor) ShallowClone() (*Tensor, error) {
ctensor := lib.AtShallowClone(ts.ctensor)
if err := TorchErr(); err != nil {
return nil, err
name := fmt.Sprintf("%s_cloned",
return newTensor(ctensor, name), nil
// MustShallowClone returns a new tensor that share storage with the input
// tensor. It will panic if error occurred
func (ts *Tensor) MustShallowClone() *Tensor {
newTs, err := ts.ShallowClone()
if err != nil {
return newTs
// Get gets the sub-tensor at the given index.
func (ts *Tensor) Get(index int) (*Tensor, error) {
ctensor := lib.AtGet(ts.ctensor, index)
if err := TorchErr(); err != nil {
return nil, err
return newTensor(ctensor), nil
// MustGet gets the sub-tensor at the given index. It will panic if error
// occurred.
func (ts *Tensor) MustGet(index int) *Tensor {
subTs, err := ts.Get(index)
if err != nil {
return subTs
// Copy_ copies in-place values from the argument tensor to the input tensor.
func Copy_(self, src *Tensor) {
lib.AtCopy_(self.ctensor, src.ctensor)
if err := TorchErr(); err != nil {
// Copy_ copies in-place values from the argument tensor to existing tensor
func (ts *Tensor) Copy_(src *Tensor) {
lib.AtCopy_(ts.ctensor, src.ctensor)
if err := TorchErr(); err != nil {
// Save saves a tensor to a file.
func (ts *Tensor) Save(path string) error {
lib.AtSave(ts.ctensor, path)
if err := TorchErr(); err != nil {
return err
return nil
// MustSave saves a tensor to a file. It will panic if error
func (ts *Tensor) MustSave(path string) {
if err := ts.Save(path); err != nil {
// Load loads a tensor from a file.
func Load(path string, nameOpt ...string) (*Tensor, error) {
ctensor := lib.AtLoad(path)
if err := TorchErr(); err != nil {
return nil, err
return newTensor(ctensor, nameOpt...), nil
// MustLoad loads a tensor to a file. It will panic if error
func MustLoad(path string, nameOpt ...string) *Tensor {
ts, err := Load(path, nameOpt...)
if err != nil {
return ts
// NamedTensor wraps C tensor and its name
type NamedTensor struct {
Name string
Tensor *Tensor
// SaveMulti saves some named tensors to a file
// The file format is the same as the one used by the PyTorch C++ API.
// NOTE. This method is depreciated and will be replaced with `SaveMultiNew`
func SaveMulti(namedTensors []NamedTensor, path string) error {
var ctensors []lib.Ctensor
var names []string
for _, ts := range namedTensors {
ctensors = append(ctensors, ts.Tensor.ctensor)
names = append(names, ts.Name)
lib.AtSaveMulti(ctensors, names, len(namedTensors), path)
if err := TorchErr(); err != nil {
return err
return nil
// MustSaveMulti saves some named tensors to a file. It will panic if error
// NOTE. This method is depreciated and will be replaced with `MustSaveMultiNew`
func MustSaveMulti(namedTensors []NamedTensor, path string) {
err := SaveMulti(namedTensors, path)
if err != nil {
// LoadMulti loads some named tensors from a file
// The file format is the same as the one used by the PyTorch C++ API.
func LoadMulti(path string) ([]NamedTensor, error) {
var data lib.LoadData
dataPtr := lib.PStore.Set(&data)
lib.AtLoadCallback(path, dataPtr)
if err := TorchErr(); err != nil {
return nil, err
var namedTensors []NamedTensor
for _, v := range data.NamedCtensors {
namedTensor := NamedTensor{
Name: v.Name,
Tensor: newTensor(v.Ctensor, v.Name),
namedTensors = append(namedTensors, namedTensor)
return namedTensors, nil
// MustLoadMulti loads some named tensors from a file. It will panic if error
func MustLoadMulti(path string) []NamedTensor {
namedTensors, err := LoadMulti(path)
if err != nil {
return namedTensors
// LoadMultiWithDevice loads some named tensors from a file to a given device
// The file format is the same as the one used by the PyTorch C++ API.
func LoadMultiWithDevice(path string, device gotch.Device) ([]NamedTensor, error) {
var data lib.LoadData
dataPtr := lib.PStore.Set(&data)
lib.AtLoadCallbackWithDevice(path, dataPtr, device.CInt())
if err := TorchErr(); err != nil {
return nil, err
var namedTensors []NamedTensor
for _, v := range data.NamedCtensors {
namedTensor := NamedTensor{
Name: v.Name,
Tensor: newTensor(v.Ctensor, v.Name),
namedTensors = append(namedTensors, namedTensor)
return namedTensors, nil
// MustLoadMulti loads some named tensors from a file. It will panic if error
func MustLoadMultiWithDevice(path string, device gotch.Device) []NamedTensor {
namedTensors, err := LoadMultiWithDevice(path, device)
if err != nil {
return namedTensors
// ToString returns a string representation for the tensor.
// lw : line width (size)
// NOTE: The representation will contain all the tensor element hence may be huge for
// large tensors.
func (ts *Tensor) ToString(lw int64) (string, error) {
tensorStr := lib.AtToString(ts.ctensor, lw)
if err := TorchErr(); err != nil {
return "", err
return tensorStr, nil
// MustToString returns a string representation for the tensor. It will be panic
// if error.
// lw : line width (size)
func (ts *Tensor) MustToString(lw int64) string {
tensorStr, err := ts.ToString(lw)
if err != nil {
return tensorStr
// Drop drops (frees) the tensor
func (ts *Tensor) Drop() error {
if ts.ctensor == nil {
return nil
ts.calledFrom = "ts.Drop()"
return freeCTensor(ts)
// MustDrop drops the tensor. It will be panic if error
func (ts *Tensor) MustDrop() {
if err := ts.Drop(); err != nil {
// GradSetEnabled sets globally whether GradMode gradient accumulation is enable or not.
// It returns PREVIOUS state of Grad before setting.
func GradSetEnabled(b bool) (bool, error) {
var cbool, cretVal int
switch b {
case true:
cbool = 1
case false:
cbool = 0
var (
err error
state bool
cretVal = lib.AtGradSetEnabled(cbool)
if err = TorchErr(); err != nil {
return false, err
switch cretVal {
case 0:
state = false
case 1:
state = true
// case -1: // should be unreachable as error is captured above with TorchrErr()
// err = fmt.Errorf("Cannot set grad enable. \n")
// return retVal, err
// default: // should be unreachable as error is captured above with TorchrErr()
// err = fmt.Errorf("Cannot set grad enable. \n")
// return retVal, err
return state, nil
// MustGradSetEnabled sets globally whether GradMode gradient accumuation is enable or not.
// It returns PREVIOUS state of Grad before setting. It will be panic if error
func MustGradSetEnabled(b bool) bool {
state, err := GradSetEnabled(b)
if err != nil {
return state
// NoGrad runs a closure without keeping track of gradients.
func NoGrad(fn func()) {
// Switch off Grad
// Switch on Grad
func NoGrad1(fn func() interface{}) interface{} {
newTs := NewTensor()
// Switch off Grad
prev := MustGradSetEnabled(false)
retVal := fn()
// Switch on Grad
_ = MustGradSetEnabled(prev)
return retVal
// NoGradGuard is a RAII guard that prevents gradient tracking until deallocated.
// It actually sets a global flag that is checked by the backend whenever an op is done on a variable.
// The guard itself saved the current status and set it to false in the constructor.
// And restore the saved status in its destructor.
// That way it is similar to a with torch.no_grad(): block in python.
//
// TODO: should we implement Go `mutex` here???
type NoGradGuard struct {
	enabled bool // grad state recorded by the guard
}
// Init NoGradGuard and disables gradient tracking
func NewNoGradGuard() *NoGradGuard {
return noGradGuardInit()
// Disables gradient tracking, this will be enabled back when the
// returned value gets deallocated.
func noGradGuardInit() *NoGradGuard {
return &NoGradGuard{enabled: MustGradSetEnabled(false)}
// Drop drops the NoGradGuard state.
func (ngg *NoGradGuard) Drop() {
ngg.enabled = true
_ = MustGradSetEnabled(ngg.enabled)
func (ngg *NoGradGuard) Enable() {
ngg.enabled = false
_ = MustGradSetEnabled(ngg.enabled)
// Reduction modes, expressed as int64 codes.
const (
	// Do not reduce
	ReductionNone int64 = 0
	// Mean of losses
	ReductionMean int64 = 1
	// Sum of losses
	ReductionSum int64 = 2
	// Escape hatch in case new options become available
	ReductionOther int64 = 3
)
// func (r Reduction) ToInt() int {
// switch r {
// case ReductionNone:
// return 0
// case ReductionMean:
// return 1
// case ReductionSum:
// return 2
// case ReductionOther:
// return 3
// }
// // NOTE. should it be panic here instead of returning -1?
// return -1
// }
// Float64Values returns values of tensor in a slice of float64.
func (ts *Tensor) Float64Values(delOpt ...bool) []float64 {
del := false
if len(delOpt) > 0 {
del = delOpt[0]
numel := ts.Numel()
vec := make([]float64, numel)
float64Ts := ts.MustTotype(gotch.Double, false)
float64Ts.MustCopyData(vec, numel)
// float64Ts.MustDrop()
if del {
return vec
// Int64Values returns values of tensor in a slice of int64.
func (ts *Tensor) Int64Values(delOpt ...bool) []int64 {
del := false
if len(delOpt) > 0 {
del = delOpt[0]
numel := ts.Numel()
vec := make([]int64, numel)
int64Ts := ts.MustTotype(gotch.Int64, false)
int64Ts.MustCopyData(vec, numel)
if del {
return vec
// Vals returns tensor values in a slice
// NOTE: need a type insersion to get runtime type
// E.g. res := xs.Vals().([]int64)
func (ts *Tensor) Vals() interface{} {
dtype := ts.DType()
numel := int(ts.Numel())
typ, err := dtype.GoType()
if err != nil {
dataSlice := reflect.MakeSlice(reflect.SliceOf(typ), numel, numel).Interface()
ts.CopyData(dataSlice, uint(numel))
return dataSlice
// FlatView flattens a tensor.
// This returns a flattened version of the given tensor. The first dimension
// is preserved as it is assumed to be the mini-batch dimension.
func (ts *Tensor) FlatView() *Tensor {
batchSize := ts.MustSize()[0]
return ts.MustView([]int64{batchSize, -1}, false)
func (ts *Tensor) ZeroPad2d(left, right, top, bottom int64, del bool) (*Tensor, error) {
if ts.Dim() != 4 {
err := fmt.Errorf("Expected a 4 dimension tensor, got %v\n", ts.MustSize())
return nil, err
return ts.ConstantPadNd([]int64{left, right, top, bottom}, del)
func (ts *Tensor) MustZeroPad2d(left, right, top, bottom int64, del bool) *Tensor {
retVal, err := ts.ZeroPad2d(left, right, top, bottom, del)
if err != nil {
return retVal
// Onehot converts a tensor to a one-hot encoded version.
// If the input has a size [N1, N2, ..., Nk], the returned tensor has a size
// [N1, ..., Nk, labels]. The returned tensor uses float values.
// Elements of the input vector are expected to be between 0 and labels-1.
// NOTE: There's other `ts.OneHot` and `ts.MustOneHot` generated from Atg C++ API
func (ts *Tensor) Onehot(labels int64) *Tensor {
dims := ts.MustSize()
dims = append(dims, labels)
unsqueezeTs := ts.MustUnsqueeze(-1, false)
inputTs := unsqueezeTs.MustTotype(gotch.Int64, true)
zerosTs := MustZeros(dims, gotch.Float, gotch.CPU)
retVal := zerosTs.MustScatterValue(-1, inputTs, FloatScalar(1.0), true)
return retVal
func (ts *Tensor) Swish() *Tensor {
sig := ts.MustSigmoid(false)
mulTs := ts.MustMul(sig, false)
return mulTs
func (ts *Tensor) AvgPool2DDefault(ksize int64, del bool) *Tensor {
return ts.MustAvgPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, false, true, []int64{1}, del)
// SaveMultiNew saves a slice of named tensors to the given file path.
func SaveMultiNew(namedTensors []NamedTensor, path string) error {
var (
tensors []lib.Ctensor
names []string
for _, nts := range namedTensors {
tensors = append(tensors, nts.Tensor.ctensor)
names = append(names, nts.Name)
lib.AtSaveMultiNew(tensors, names, path)
if err := TorchErr(); err != nil {
return err
return nil
func (ts *Tensor) ConstantPadNdWithVal(pad []int64, value *Scalar, del bool) (retVal *Tensor, err error) {
if del {
defer ts.MustDrop()
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtoConstantPadNd(ptr, ts.ctensor, pad, len(pad), value.cscalar)
if err = TorchErr(); err != nil {
return retVal, err
retVal = newTensor(*ptr)
return retVal, err
func (ts *Tensor) MustConstantPadNdWithVal(pad []int64, value *Scalar, del bool) (retVal *Tensor) {
retVal, err := ts.ConstantPadNdWithVal(pad, value, del)
if err != nil {
return retVal
// TT. Added some torch.cuda APIs for handling CUDA qt
// CudaCurrentDevice get device index of current CUDA device.
func CudaCurrentDevice() (int, error) {
currentDeviceIndex := lib.AtcGetDevice()
if err := TorchErr(); err != nil {
err = fmt.Errorf("ts.CudaCurrentDevice() failed: %w\n", err)
return -99, err
return currentDeviceIndex, nil
// CudaSetDevice set new cuda device index and returns previous cuda index.
func CudaSetDevice(cudaDeviceIndex int) (int, error) {
currentDeviceIndex, err := CudaCurrentDevice()
if err != nil {
err = fmt.Errorf("ts.CudaSetDevice() failed: %w\n", err)
return -99, err
if err := TorchErr(); err != nil {
err = fmt.Errorf("ts.CudaSetDevice() failed: %w\n", err)
return -99, err
return currentDeviceIndex, nil
// CudaSynchronize waits for all kernels in all streams on a CUDA device to complete.
func CudaSynchronize(cudaDeviceIndexOpt error {
var cudaDeviceIndex int
var err error
if len(cudaDeviceIndexOpt) > 0 {
cudaDeviceIndex = cudaDeviceIndexOpt[0]
} else {
cudaDeviceIndex, err = CudaCurrentDevice()
if err != nil {
err := fmt.Errorf("ts.CudaSynchronize() failed: %w\n", err)
return err
return TorchErr()