package ts

//#include "stdlib.h"
//#include "stdbool.h"
//#include<stdio.h>
import "C"

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"log"
	"reflect"
	"runtime"
	"sync"
	"sync/atomic"
	"time"
	"unsafe"

	gotch "github.com/sugarme/gotch"
	lib "github.com/sugarme/gotch/libtch"
)
var (
	TensorCount  int64 // running count of created tensors
	ScalarCount  int64 // running count of created scalars
	AllocatedMem int64 // bytes - tracks memory created and still held by gotch tensors (excluding memory allocated by libtorch on the C side)

	ExistingTensors map[string]struct{} = make(map[string]struct{}) // keeps track of existing tensors by name
	ExistingScalars map[string]struct{} = make(map[string]struct{}) // keeps track of existing scalars by name

	lock sync.Mutex
)
// None is an undefined tensor.
// It can be used as an optional tensor parameter where a 'None' value is expected.
// Use `ts.MustDefined()` to check whether a tensor is defined.
var None = NewTensor()

type bigStruct struct {
	lots [1e5]byte // 100KB - forces the Tensor to be heap-allocated.
}
// Tensor is a Go wrapper of a "C tensor pointer" - 8 bytes on a 64-bit OS
// (4 bytes on a 32-bit OS).
// `ctensor` is just a C pointer to `torch::Tensor` (torch::Tensor *lib.Ctensor).
//
// NOTE. Tensor should be big enough to live in heap memory
// (yes, we deliberately place tensors on the heap so that they
// can be tracked by the Go garbage collector).
//
// For heap allocation see https://stackoverflow.com/questions/10866195
type Tensor struct {
	d          *bigStruct
	name       string
	ctensor    lib.Ctensor
	calledFrom string
}
func newTensor(ctensor lib.Ctensor, nameOpt ...string) *Tensor {
if len(nameOpt) == 0 {
nameOpt = []string{}
}
x := new(Tensor)
x.ctensor = ctensor
x.d = new(bigStruct)
atomic.AddInt64(&TensorCount, 1)
nbytes := x.nbytes()
atomic.AddInt64(&AllocatedMem, nbytes)
lock.Lock()
name := newName(nameOpt...)
if _, ok := ExistingTensors[name]; ok {
name = fmt.Sprintf("%s_%09d", name, TensorCount)
}
ExistingTensors[name] = struct{}{}
lock.Unlock()
x.name = name
if gotch.Debug {
log.Printf("INFO: Added tensor %q - Allocated memory: %d bytes.\n", x.name, nbytes)
}
x.calledFrom = "newTensor()"
runtime.SetFinalizer(x, freeCTensor)
return x
}
// New creates a new Tensor from an existing C tensor pointer.
func New(ctensor lib.Ctensor, nameOpt ...string) *Tensor {
return newTensor(ctensor, nameOpt...)
}
// CheckCMemLeak reports the tensors that have not yet been released and the
// amount of C memory still allocated by gotch.
func CheckCMemLeak() string {
	tensors := []string{}
	lock.Lock()
	for n := range ExistingTensors {
		tensors = append(tensors, n)
	}
	memUsed := AllocatedMem
	lock.Unlock()

	var msg string
	msg += fmt.Sprintf("============================= C MEMORY CHECK RESULT ==================================\n")
	msg += fmt.Sprintf("C memory allocated but not yet released: %v bytes\n", memUsed)
	msg += fmt.Sprintf("Tensors not yet released: %q\n", tensors)
	msg += fmt.Sprintf("======================================================================================\n")

	return msg
}
// CleanUp runs the garbage collector twice with a sleep in between
// (sleepTimeOpt in milliseconds, default 1000) so that tensor finalizers
// get a chance to run.
func CleanUp(sleepTimeOpt ...int) {
	sleepTime := time.Duration(1000) // milliseconds
	if len(sleepTimeOpt) > 0 {
		sleepTime = time.Duration(sleepTimeOpt[0])
	}

	runtime.GC()
	time.Sleep(time.Millisecond * sleepTime)
	runtime.GC()
}
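// A minimal usage sketch (illustrative, not part of the original source):
// after dropping tensors, force finalizers to run and report what is still held.
//
//	ts.CleanUp()                    // two GC passes, 1s apart by default
//	ts.CleanUp(200)                 // or with a 200ms pause
//	fmt.Println(ts.CheckCMemLeak()) // lists tensors not yet released
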
// Ctensor returns the underlying C tensor pointer as an unsafe.Pointer.
func (ts *Tensor) Ctensor() unsafe.Pointer {
return unsafe.Pointer(ts.ctensor)
}
// freeCTensor releases the memory allocated on the C side for this tensor.
func freeCTensor(ts *Tensor) error {
if ts == nil || ts.ctensor == nil {
return nil
}
lock.Lock()
defer lock.Unlock()
if _, ok := ExistingTensors[ts.name]; !ok {
log.Printf("WARNING: Probably double free tensor %q. Called from %q. Just skipping...\n", ts.name, ts.calledFrom)
return nil
}
	if gotch.Debug {
		shape, err := ts.Size()
		if err != nil {
			err = fmt.Errorf("ERROR: failed to release tensor %q: %w\n", ts.name, err)
			log.Printf(err.Error())
			return err
		}
		numel := uint(FlattenDim(shape))
		dtype := ts.DType()
		nbytes := int64(numel * dtype.Size())
		atomic.AddInt64(&AllocatedMem, -nbytes)
		log.Printf("INFO: Released tensor %q - C memory(%d bytes).\n", ts.name, nbytes)
	}
lib.AtFree(ts.ctensor)
if err := TorchErr(); err != nil {
err := fmt.Errorf("ERROR: failed to release tensor %q - %w", ts.name, err)
return err
}
delete(ExistingTensors, ts.name)
	// IMPORTANT. Set ctensor to nil so this tensor will not be double-freed.
	ts.ctensor = nil
return nil
}
func newName(nameOpt ...string) string {
var name string
if len(nameOpt) > 0 {
name = nameOpt[0]
} else {
name = fmt.Sprintf("tensor_%09d", TensorCount)
}
return name
}
// NewTensor creates a new, undefined tensor.
func NewTensor(nameOpt ...string) *Tensor {
	ctensor := lib.AtNewTensor()
	return newTensor(ctensor, nameOpt...)
}
// FromCtensor creates a Tensor from a raw C tensor pointer.
func FromCtensor(ctensor unsafe.Pointer) *Tensor {
	cts := (lib.Ctensor)(ctensor)
	return newTensor(cts)
}
// Name returns the name of the tensor.
func (ts *Tensor) Name() string {
	return ts.name
}
// Dim returns the number of dimensions of the tensor.
func (ts *Tensor) Dim() uint64 {
	dim := lib.AtDim(ts.ctensor)
	if err := TorchErr(); err != nil {
		log.Fatal(err)
	}
	return dim
}
// Size returns the shape of the tensor.
//
// NOTE: C++ libtorch calls at_shape() -> t.sizes()
// and writes the sizes (shape) into the slice behind the given pointer.
func (ts *Tensor) Size() ([]int64, error) {
dim := lib.AtDim(ts.ctensor)
if dim < 0 || dim > 100 {
err := fmt.Errorf("Invalid dim: %v\n", dim)
return nil, err
}
sz := make([]int64, dim)
szPtr, err := DataAsPtr(sz)
if err != nil {
return nil, err
}
defer C.free(unsafe.Pointer(szPtr))
lib.AtShape(ts.ctensor, szPtr)
if err = TorchErr(); err != nil {
return nil, err
}
shape := decodeSize(szPtr, dim)
return shape, nil
}
// MustSize returns the shape of the tensor. It panics on error.
func (ts *Tensor) MustSize() []int64 {
	shape, err := ts.Size()
	if err != nil {
		log.Fatal(err)
	}
	return shape
}
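// A small usage sketch (illustrative only): shape and dimension queries on a
// tensor built with MustOfSlice and reshaped with the generated MustView API.
//
//	x := ts.MustOfSlice([]int64{1, 2, 3, 4, 5, 6}).MustView([]int64{2, 3}, true)
//	fmt.Println(x.Dim())      // 2
//	fmt.Println(x.MustSize()) // [2 3]
//	x.MustDrop()
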
// Stride returns the stride of the tensor: the number of elements to skip in
// the underlying storage to move to the next element along each dimension.
func (ts *Tensor) Stride() ([]int64, error) {
dim := lib.AtDim(ts.ctensor)
sz := make([]int64, dim)
szPtr, err := DataAsPtr(sz)
if err != nil {
return nil, err
}
defer C.free(unsafe.Pointer(szPtr))
lib.AtStride(ts.ctensor, szPtr)
if err = TorchErr(); err != nil {
return nil, err
}
strides := decodeSize(szPtr, dim)
return strides, nil
}
func (ts *Tensor) MustStride() []int64 {
strides, err := ts.Stride()
if err != nil {
log.Fatal(err)
}
return strides
}
// Size1 returns the tensor size for 1D tensors.
func (ts *Tensor) Size1() (int64, error) {
shape, err := ts.Size()
if err != nil {
return 0, err
}
if len(shape) != 1 {
err = fmt.Errorf("Expected one dim, got %v\n", len(shape))
return 0, err
}
return shape[0], nil
}
// Size2 returns the tensor size for 2D tensors.
func (ts *Tensor) Size2() ([]int64, error) {
shape, err := ts.Size()
if err != nil {
return nil, err
}
if len(shape) != 2 {
err = fmt.Errorf("Expected two dims, got %v\n", len(shape))
return nil, err
}
return shape, nil
}
// Size3 returns the tensor size for 3D tensors.
func (ts *Tensor) Size3() ([]int64, error) {
shape, err := ts.Size()
if err != nil {
return nil, err
}
if len(shape) != 3 {
err = fmt.Errorf("Expected three dims, got %v\n", len(shape))
return nil, err
}
return shape, nil
}
// Size4 returns the tensor size for 4D tensors.
func (ts *Tensor) Size4() ([]int64, error) {
shape, err := ts.Size()
if err != nil {
return nil, err
}
if len(shape) != 4 {
err = fmt.Errorf("Expected four dims, got %v\n", len(shape))
return nil, err
}
return shape, nil
}
// nbytes calculates tensor data size in bytes.
func (ts *Tensor) nbytes() int64 {
numel := ts.Numel()
if numel == 0 {
return 0 // ts.None
}
return int64(numel * ts.DType().Size())
}
func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
dtype := gotch.Int64 // tensor size dtype = int64
nbytes := int(nsize) * int(dtype.Size())
dataSlice := (*[1 << 30]byte)(ptr)[:nbytes:nbytes]
r := bytes.NewReader(dataSlice)
dataIn := make([]int64, nsize)
if err := binary.Read(r, nativeEndian, dataIn); err != nil {
log.Fatal(err)
}
return dataIn
}
// TensorOptions constructs options to build/rebuild tensor.
type TensorOptions struct {
Name string
DType gotch.DType
Quantized bool
// TODO. can expand as needed
}
type TensorOpt func(*TensorOptions)
func DefaultTensorOptions() *TensorOptions {
return &TensorOptions{
Name: "",
DType: gotch.Float,
Quantized: false,
}
}
func WithName(v string) TensorOpt {
return func(o *TensorOptions) {
o.Name = v
}
}
func WithDType(v gotch.DType) TensorOpt {
return func(o *TensorOptions) {
o.DType = v
}
}
func WithQuantized(v bool) TensorOpt {
return func(o *TensorOptions) {
o.Quantized = v
}
}
// OfSlice creates a tensor from slice data.
func OfSlice(data interface{}, opts ...TensorOpt) (*Tensor, error) {
o := DefaultTensorOptions()
for _, opt := range opts {
opt(o)
}
// convert []int -> int32. `binary.Write()` can't write `[]int` because it's not fixed-size!
if reflect.TypeOf(data).String() == "[]int" {
data = sliceIntToInt32(data.([]int))
}
v := reflect.ValueOf(data)
kind := v.Kind().String()
if kind != "slice" && kind != "array" {
err := fmt.Errorf("Expected slice data. Got %q", kind)
return nil, err
}
elementKind := reflect.TypeOf(data).Elem().Kind()
dataLen := v.Len()
dtype, err := gotch.GoKind2DType(elementKind, gotch.HalfDTypePref(o.DType), gotch.WithQuantized(o.Quantized))
if err != nil {
return nil, err
}
shape := []int64{int64(dataLen)}
elementNum := ElementCount(shape)
nbytes := int(dtype.Size()) * elementNum
dataPtr, buff := CMalloc(nbytes)
defer C.free(unsafe.Pointer(dataPtr))
if err = EncodeTensor(buff, reflect.ValueOf(data), shape); err != nil {
return nil, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(dtype.Size()), int(dtype.CKind()))
if err = TorchErr(); err != nil {
return nil, err
}
return newTensor(ctensor, o.Name), nil
}
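// A usage sketch (illustrative only): building tensors from Go slices, with and
// without options. The element dtype is inferred from the slice element type.
//
//	x, err := ts.OfSlice([]float64{1, 2, 3, 4}) // 1-D tensor, dtype Double
//	if err != nil {
//		log.Fatal(err)
//	}
//	named := ts.MustOfSlice([]int64{1, 2, 3}, ts.WithName("labels"))
//	fmt.Println(named.Name()) // "labels" (possibly suffixed if the name already exists)
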
// OfDataSize creates Tensor from input byte data, shape and dtype.
func OfDataSize(data []byte, shape []int64, dtype gotch.DType, opts ...TensorOpt) (*Tensor, error) {
o := DefaultTensorOptions()
for _, opt := range opts {
opt(o)
}
elementNum := ElementCount(shape)
nbytes := elementNum * int(dtype.Size())
if nbytes != len(data) {
err := fmt.Errorf("data and shape mismatched for dtype (%v): byte data (%v) - shape (%v).\n", dtype, len(data), shape)
return nil, err
}
dataPtr, buff := CMalloc(nbytes)
defer C.free(unsafe.Pointer(dataPtr))
if err := binary.Write(buff, nativeEndian, data); err != nil {
return nil, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), dtype.Size(), int(dtype.CKind()))
if err := TorchErr(); err != nil {
return nil, err
}
return newTensor(ctensor, o.Name), nil
}
// MustOfDataSize creates a Tensor from input byte data with the specified
// shape and dtype. It panics on error.
func MustOfDataSize(data []byte, size []int64, dtype gotch.DType, opts ...TensorOpt) *Tensor {
ts, err := OfDataSize(data, size, dtype, opts...)
if err != nil {
log.Fatal(err)
}
return ts
}
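// A usage sketch (illustrative only): rebuilding a tensor from raw bytes. The
// byte length must equal ElementCount(shape) * dtype.Size().
//
//	raw := make([]byte, 2*3*4) // 6 float32 values = 24 bytes
//	x := ts.MustOfDataSize(raw, []int64{2, 3}, gotch.Float)
//	fmt.Println(x.MustSize()) // [2 3]
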
// MustOfSlice creates a tensor from slice data. It panics on error.
func MustOfSlice(data interface{}, opts ...TensorOpt) *Tensor {
ts, err := OfSlice(data, opts...)
if err != nil {
log.Fatal(err)
}
return ts
}
// TensorFrom creates a tensor from slice data. It panics on error.
func TensorFrom(data interface{}, opts ...TensorOpt) *Tensor {
ts, err := OfSlice(data, opts...)
if err != nil {
log.Fatal(err)
}
return ts
}
// Print prints tensor values to console.
//
// NOTE: it is printed from C and will print ALL elements of tensor
// with no truncation at all.
func (ts *Tensor) Print() {
lib.AtPrint(ts.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
}
// NewTensorFromData creates tensor from given data and shape
func NewTensorFromData(data interface{}, shape []int64, opts ...TensorOpt) (*Tensor, error) {
o := DefaultTensorOptions()
for _, opt := range opts {
opt(o)
}
// 1. Check whether data and shape match
elementNum, err := DataDim(data)
if err != nil {
return nil, err
}
nflattend := FlattenDim(shape)
if elementNum != nflattend {
err = fmt.Errorf("Number of data elements (%v) and flatten shape (%v) dimension mismatched.\n", elementNum, nflattend)
return nil, err
}
// 2. Write raw data to C memory and get C pointer
dataPtr, err := DataAsPtr(data)
defer C.free(unsafe.Pointer(dataPtr))
if err != nil {
return nil, err
}
// 3. Create tensor with pointer and shape
dtype, err := gotch.DTypeFromData(data)
if err != nil {
return nil, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), dtype.Size(), int(dtype.CKind()))
if err = TorchErr(); err != nil {
return nil, err
}
return newTensor(ctensor, o.Name), nil
}
// DType returns the element data type of the tensor.
func (ts *Tensor) DType() gotch.DType {
	cint := lib.AtScalarType(ts.ctensor)
	return gotch.CKind2DType(cint)
}
func (ts *Tensor) Device() (gotch.Device, error) {
var (
retVal gotch.Device
err error
)
cInt := lib.AtDevice(ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
var device gotch.Device
return device.OfCInt(int32(cInt)), nil
}
func (ts *Tensor) MustDevice() gotch.Device {
device, err := ts.Device()
if err != nil {
log.Fatal(err)
}
return device
}
/*
* func (ts Tensor) Eq1(other Tensor, del bool) (retVal Tensor, err error) {
*
* // Get a C null pointer
* // https://stackoverflow.com/a/2022369
* ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
* if del {
* defer ts.MustDrop()
* }
*
* lib.AtgEq1(ptr, ts.ctensor, other.ctensor)
* if err = TorchErr(); err != nil {
* return retVal, err
* }
*
* return Tensor{ctensor: *ptr}, nil
*
* }
*
* func (ts Tensor) MustEq1(other Tensor, del bool) (retVal Tensor) {
* retVal, err := ts.Eq1(other, del)
* if err != nil {
* log.Fatal(err)
* }
*
* return retVal
* }
* */
// Float64Value returns a float value on tensors holding a single element.
// An error is returned otherwise.
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
func (ts *Tensor) Float64Value(idx []int64) (float64, error) {
idxPtr, err := DataAsPtr(idx)
if err != nil {
return 0, err
}
defer C.free(unsafe.Pointer(idxPtr))
f64Val := lib.AtDoubleValueAtIndexes(ts.ctensor, idxPtr, len(idx))
if err = TorchErr(); err != nil {
return 0, err
}
return f64Val, err
}
func (ts *Tensor) MustFloat64Value(idx []int64) float64 {
f64Val, err := ts.Float64Value(idx)
if err != nil {
log.Fatal(err)
}
return f64Val
}
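// A usage sketch (illustrative only): reading a single element at an index.
//
//	x := ts.MustOfSlice([]float64{1.5, 2.5, 3.5})
//	v := x.MustFloat64Value([]int64{2})
//	fmt.Println(v) // 3.5
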
// Int64Value returns an int value on tensors holding a single element. An error is
// returned otherwise.
func (ts *Tensor) Int64Value(idx []int64) (int64, error) {
var (
retVal int64
err error
)
idxPtr, err := DataAsPtr(idx)
if err != nil {
return retVal, err
}
defer C.free(unsafe.Pointer(idxPtr))
int64Val := lib.AtInt64ValueAtIndexes(ts.ctensor, idxPtr, len(idx))
if err = TorchErr(); err != nil {
return 0, err
}
return int64Val, err
}
func (ts *Tensor) MustInt64Value(idx []int64) int64 {
int64Val, err := ts.Int64Value(idx)
if err != nil {
log.Fatal(err)
}
return int64Val
}
// RequiresGrad returns true if gradients are currently tracked for this tensor.
func (ts *Tensor) RequiresGrad() (bool, error) {
state := lib.AtRequiresGrad(ts.ctensor)
if err := TorchErr(); err != nil {
return false, err
}
return state, nil
}
func (ts *Tensor) MustRequiresGrad() bool {
state, err := ts.RequiresGrad()
if err != nil {
log.Fatal(err)
}
return state
}
// DataPtr returns the address of the first element of this tensor.
func (ts *Tensor) DataPtr() (unsafe.Pointer, error) {
datPtr := lib.AtDataPtr(ts.ctensor)
if err := TorchErr(); err != nil {
return nil, err
}
return datPtr, nil
}
func (ts *Tensor) MustDataPtr() unsafe.Pointer {
p, err := ts.DataPtr()
if err != nil {
panic(err)
}
return p
}
// Defined returns true if the tensor is defined.
func (ts *Tensor) Defined() (bool, error) {
state := lib.AtDefined(ts.ctensor)
if err := TorchErr(); err != nil {
return false, err
}
return state, nil
}
func (ts *Tensor) MustDefined() bool {
state, err := ts.Defined()
if err != nil {
log.Fatal(err)
}
return state
}
// IsSparse returns true if the tensor is sparse.
func (ts *Tensor) IsSparse() (bool, error) {
state := lib.AtIsSparse(ts.ctensor)
if err := TorchErr(); err != nil {
return false, err
}
return state, nil
}
func (ts *Tensor) MustIsSparse() bool {
state, err := ts.IsSparse()
if err != nil {
log.Fatal(err)
}
return state
}
// IsContiguous returns true if the tensor is contiguous.
func (ts *Tensor) IsContiguous() (bool, error) {
state := lib.AtIsContiguous(ts.ctensor)
if err := TorchErr(); err != nil {
return false, err
}
return state, nil
}
func (ts *Tensor) MustIsContiguous() bool {
state, err := ts.IsContiguous()
if err != nil {
log.Fatal(err)
}
return state
}
// IsMkldnn returns true if the tensor uses the MKL-DNN (mkldnn) layout.
func (ts *Tensor) IsMkldnn() (bool, error) {
state := lib.AtIsMkldnn(ts.ctensor)
if err := TorchErr(); err != nil {
return false, err
}
return state, nil
}
func (ts *Tensor) MustIsMkldnn() bool {
state, err := ts.IsMkldnn()
if err != nil {
log.Fatal(err)
}
return state
}
// ZeroGrad zeroes the gradient tensor attached to this tensor if defined.
func (ts *Tensor) ZeroGrad() {
grad := ts.MustGrad(false)
if grad.MustDefined() {
grad.Detach_()
grad.Zero_()
}
}
// Backward runs the backward pass, populating the gradient tensors for tensors
// whose gradients are tracked.
//
// Gradient tracking can be turned on via `SetRequiresGrad`.
func (ts *Tensor) Backward() error {
lib.AtBackward(ts.ctensor, 0, 0)
if err := TorchErr(); err != nil {
return err
}
return nil
}
func (ts *Tensor) MustBackward() {
if err := ts.Backward(); err != nil {
log.Fatal(err)
}
}
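// A gradient sketch (illustrative only; assumes x was created with gradient
// tracking enabled, e.g. via the generated SetRequiresGrad API, and that the
// generated MustMul/MustSum/MustGrad bindings are available):
//
//	y := x.MustMul(x, false).MustSum(gotch.Double, true) // scalar loss
//	y.MustBackward()
//	grad := x.MustGrad(false) // dy/dx
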
// RunBackward runs the backward pass over the given tensors with respect to
// the given inputs and returns the computed gradients.
func RunBackward(tensors []*Tensor, inputs []*Tensor, keepGraphB bool, createGraphB bool) ([]*Tensor, error) {
// NOTE: outputs is a slice of tensors with length = len(inputs)
var outputsPtr []*lib.Ctensor
	// Are they allocated contiguously??? Definitely not.
// TODO. calculate C memory size = C pointer size x n pointers
// Then C.malloc such calculated amount
// NOTE. replace with the following code and test.
/*
* ntensors := len(inputs)
* nbytes := C.size_t(ntensors) * C.size_t(unsafe.Sizeof(uintptr(0)))
* ctensorsPtr := (*[1 << 30]lib.Ctensor)(C.malloc(nbytes))
* for i :=0; i < ntensors; i++ {
* outputPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
* outputsPtr[i] = outputPtr
* }
* */
for i := 0; i < len(inputs); i++ {
outputPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(outputPtr))
outputsPtr = append(outputsPtr, outputPtr)
}
// Get first element pointer
ctensor := tensors[0].ctensor
cinput := inputs[0].ctensor
tensorsPtr := (*lib.Ctensor)(unsafe.Pointer(&ctensor))
inputsPtr := (*lib.Ctensor)(unsafe.Pointer(&cinput))
var keepGraph int = 0
if keepGraphB {
keepGraph = 1
}
var createGraph int = 0
if createGraphB {
createGraph = 1
}
lib.AtRunBackward(tensorsPtr, len(tensors), inputsPtr, len(inputs), outputsPtr[0], keepGraph, createGraph)
if err := TorchErr(); err != nil {
return nil, err
}
var oTensors []*Tensor
for i := 0; i < len(inputs); i++ {
outputPtr := outputsPtr[i]
oTensors = append(oTensors, newTensor(*outputPtr))
}
return oTensors, nil
}
// CopyDataUint8 copies `numel` elements from `self` to `dst`.
//
// NOTE: `dst` located in Go memory. Should it be?
func (ts *Tensor) CopyDataUint8(dst []uint8, numel uint) error {
// NOTE: we must make sure that `dst` has same len as `numel`. Otherwise,
// there will be memory leak and or out of range error.
if len(dst) < int(numel) {
err := fmt.Errorf("CopyDataUint8 Error: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", len(dst), numel)
return err
}
vs := unsafe.Pointer(&dst[0])
dtype := gotch.Uint8
lib.AtCopyData(ts.ctensor, vs, numel, dtype.Size())
if err := TorchErr(); err != nil {
return err
}
return nil
}
func (ts *Tensor) MustCopyDataUint8(dst []uint8, numel uint) {
err := ts.CopyDataUint8(dst, numel)
if err != nil {
log.Fatal(err)
}
}
// CopyData copies `numel` elements from `self` to `dst`.
// `dst` should be a slice of Go type equivalent to tensor type.
//
// NOTE: `dst` located in Go memory. Should it be?
// We pass the Go pointer of the first element of the `dst` slice and the
// number of elements to the C side. This may break in the future if Go's
// pointer-passing rules change.
func (ts *Tensor) CopyData(dst interface{}, numel uint) error {
	dtype, dlen, err := DataCheck(dst)
	if err != nil {
		return err
	}

	if dlen < int(numel) {
		err = fmt.Errorf("ts.CopyData() failed: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", dlen, numel)
		return err
	}

	if ts.DType() != dtype {
		err = fmt.Errorf("ts.CopyData() failed: Type mismatched: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
		return err
	}

	// Get data pointer
	dataPtr := reflect.ValueOf(dst).UnsafePointer()

	lib.AtCopyData(ts.ctensor, dataPtr, numel, dtype.Size())
	if err = TorchErr(); err != nil {
		return err
	}

	return nil
}
// MustCopyData copies `numel` elements from the tensor to a slice of data.
//
// NOTE: `dst` is a slice with length = numel and Go type equivalent to the
// tensor DType. It panics on error.
func (ts *Tensor) MustCopyData(dst interface{}, numel uint) {
err := ts.CopyData(dst, numel)
if err != nil {
log.Fatal(err)
}
}
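// A usage sketch (illustrative only): copying tensor values into a Go slice
// whose element type matches the tensor DType.
//
//	x := ts.MustOfSlice([]float64{1, 2, 3})
//	dst := make([]float64, x.Numel())
//	x.MustCopyData(dst, x.Numel())
//	fmt.Println(dst) // [1 2 3]
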
// Numel returns the total number of elements stored in a tensor.
func (ts *Tensor) Numel() uint {
if !ts.MustDefined() {
return 0 // ts.None case
}
shape := ts.MustSize()
return uint(FlattenDim(shape))
}
// ShallowClone returns a new tensor that shares storage with the input tensor.
func (ts *Tensor) ShallowClone() (*Tensor, error) {
ctensor := lib.AtShallowClone(ts.ctensor)
if err := TorchErr(); err != nil {
return nil, err
}
name := fmt.Sprintf("%s_cloned", ts.name)
return newTensor(ctensor, name), nil
}
// MustShallowClone returns a new tensor that shares storage with the input
// tensor. It panics on error.
func (ts *Tensor) MustShallowClone() *Tensor {
newTs, err := ts.ShallowClone()
if err != nil {
log.Fatal(err)
}
return newTs
}
// Get gets the sub-tensor at the given index.
func (ts *Tensor) Get(index int) (*Tensor, error) {
ctensor := lib.AtGet(ts.ctensor, index)
if err := TorchErr(); err != nil {
return nil, err
}
return newTensor(ctensor), nil
}
// MustGet gets the sub-tensor at the given index. It will panic if error
// occurred.
func (ts *Tensor) MustGet(index int) *Tensor {
subTs, err := ts.Get(index)
if err != nil {
log.Fatal(err)
}
return subTs
}
// Copy_ copies in-place values from the argument tensor to the input tensor.
func Copy_(self, src *Tensor) {
lib.AtCopy_(self.ctensor, src.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
}
// Copy_ copies in-place values from the argument tensor to existing tensor
func (ts *Tensor) Copy_(src *Tensor) {
lib.AtCopy_(ts.ctensor, src.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
}
// Save saves a tensor to a file.
func (ts *Tensor) Save(path string) error {
lib.AtSave(ts.ctensor, path)
if err := TorchErr(); err != nil {
return err
}
return nil
}
// MustSave saves a tensor to a file. It will panic if error
func (ts *Tensor) MustSave(path string) {
if err := ts.Save(path); err != nil {
log.Fatal(err)
}
}
// Load loads a tensor from a file.
func Load(path string, nameOpt ...string) (*Tensor, error) {
	ctensor := lib.AtLoad(path)
	if err := TorchErr(); err != nil {
		return nil, err
	}

	return newTensor(ctensor, nameOpt...), nil
}
// MustLoad loads a tensor from a file. It panics on error.
func MustLoad(path string, nameOpt ...string) *Tensor {
ts, err := Load(path, nameOpt...)
if err != nil {
log.Fatal(err)
}
return ts
}
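// A save/load round-trip sketch (illustrative only; "x.pt" is an example path):
//
//	x := ts.MustOfSlice([]float64{1, 2, 3})
//	x.MustSave("x.pt")
//	y := ts.MustLoad("x.pt")
//	fmt.Println(y.MustSize()) // [3]
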
// NamedTensor wraps C tensor and its name
type NamedTensor struct {
Name string
Tensor *Tensor
}
// SaveMulti saves some named tensors to a file
//
// The file format is the same as the one used by the PyTorch C++ API.
// NOTE. This method is deprecated and will be replaced with `SaveMultiNew`
func SaveMulti(namedTensors []NamedTensor, path string) error {
var ctensors []lib.Ctensor
var names []string
for _, ts := range namedTensors {
ctensors = append(ctensors, ts.Tensor.ctensor)
names = append(names, ts.Name)
}
lib.AtSaveMulti(ctensors, names, len(namedTensors), path)
if err := TorchErr(); err != nil {
return err
}
return nil
}
// MustSaveMulti saves some named tensors to a file. It will panic if error
//
// NOTE. This method is deprecated and will be replaced with `MustSaveMultiNew`
func MustSaveMulti(namedTensors []NamedTensor, path string) {
err := SaveMulti(namedTensors, path)
if err != nil {
log.Fatal(err)
}
}
// LoadMulti loads some named tensors from a file
//
// The file format is the same as the one used by the PyTorch C++ API.
func LoadMulti(path string) ([]NamedTensor, error) {
var data lib.LoadData
dataPtr := lib.PStore.Set(&data)
lib.AtLoadCallback(path, dataPtr)
if err := TorchErr(); err != nil {
return nil, err
}
var namedTensors []NamedTensor
for _, v := range data.NamedCtensors {
namedTensor := NamedTensor{
Name: v.Name,
Tensor: newTensor(v.Ctensor, v.Name),
}
namedTensors = append(namedTensors, namedTensor)
}
return namedTensors, nil
}
// MustLoadMulti loads some named tensors from a file. It will panic if error
func MustLoadMulti(path string) []NamedTensor {
namedTensors, err := LoadMulti(path)
if err != nil {
log.Fatal(err)
}
return namedTensors
}
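// A multi-tensor save/load sketch (illustrative only; "weights.pt" is an
// example path):
//
//	w := ts.MustOfSlice([]float64{0.1, 0.2})
//	b := ts.MustOfSlice([]float64{0.0})
//	ts.MustSaveMulti([]ts.NamedTensor{{Name: "w", Tensor: w}, {Name: "b", Tensor: b}}, "weights.pt")
//	for _, nt := range ts.MustLoadMulti("weights.pt") {
//		fmt.Println(nt.Name, nt.Tensor.MustSize())
//	}
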
// LoadMultiWithDevice loads some named tensors from a file to a given device
//
// The file format is the same as the one used by the PyTorch C++ API.
func LoadMultiWithDevice(path string, device gotch.Device) ([]NamedTensor, error) {
var data lib.LoadData
dataPtr := lib.PStore.Set(&data)
lib.AtLoadCallbackWithDevice(path, dataPtr, device.CInt())
if err := TorchErr(); err != nil {
return nil, err
}
var namedTensors []NamedTensor
for _, v := range data.NamedCtensors {
namedTensor := NamedTensor{
Name: v.Name,
Tensor: newTensor(v.Ctensor, v.Name),
}
namedTensors = append(namedTensors, namedTensor)
}
return namedTensors, nil
}
// MustLoadMultiWithDevice loads some named tensors from a file to a given device. It panics on error.
func MustLoadMultiWithDevice(path string, device gotch.Device) []NamedTensor {
namedTensors, err := LoadMultiWithDevice(path, device)
if err != nil {
log.Fatal(err)
}
return namedTensors
}
// ToString returns a string representation for the tensor.
//
// lw : line width (size)
// NOTE: The representation contains all tensor elements, hence it may be huge
// for large tensors.
func (ts *Tensor) ToString(lw int64) (string, error) {
tensorStr := lib.AtToString(ts.ctensor, lw)
if err := TorchErr(); err != nil {
return "", err
}
return tensorStr, nil
}
// MustToString returns a string representation for the tensor. It panics on
// error.
// lw : line width (size)
func (ts *Tensor) MustToString(lw int64) string {
tensorStr, err := ts.ToString(lw)
if err != nil {
log.Fatal(err)
}
return tensorStr
}
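// A usage sketch (illustrative only): getting the tensor as a string with a
// line width of 80 characters instead of printing from C with ts.Print().
//
//	x := ts.MustOfSlice([]float64{1, 2, 3})
//	fmt.Println(x.MustToString(80))
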
// Drop drops (frees) the tensor
func (ts *Tensor) Drop() error {
if ts.ctensor == nil {
return nil
}
// Clear SetFinalizer on ts so no double free tensor.
// Ref. https://pkg.go.dev/runtime#SetFinalizer
runtime.SetFinalizer(ts, nil)
ts.calledFrom = "ts.Drop()"
return freeCTensor(ts)
}
// MustDrop drops the tensor. It panics on error.
func (ts *Tensor) MustDrop() {
if err := ts.Drop(); err != nil {
panic(err)
}
}
// GradSetEnabled globally sets whether GradMode gradient accumulation is enabled.
// It returns the PREVIOUS state of GradMode before setting.
func GradSetEnabled(b bool) (bool, error) {
var cbool, cretVal int
switch b {
case true:
cbool = 1
case false:
cbool = 0
}
var (
err error
state bool
)
cretVal = lib.AtGradSetEnabled(cbool)
if err = TorchErr(); err != nil {
return false, err
}
switch cretVal {
case 0:
state = false
break
case 1:
state = true
break
// case -1: // should be unreachable as error is captured above with TorchrErr()
// err = fmt.Errorf("Cannot set grad enable. \n")
// return retVal, err
// default: // should be unreachable as error is captured above with TorchrErr()
// err = fmt.Errorf("Cannot set grad enable. \n")
// return retVal, err
}
return state, nil
}
// MustGradSetEnabled globally sets whether GradMode gradient accumulation is enabled.
// It returns the PREVIOUS state of GradMode before setting. It panics on error.
func MustGradSetEnabled(b bool) bool {
state, err := GradSetEnabled(b)
if err != nil {
log.Fatal(err)
}
return state
}
// NoGrad runs a closure without keeping track of gradients.
func NoGrad(fn func()) {
	// Switch off Grad
	MustGradSetEnabled(false)

	fn()

	// Switch on Grad
	MustGradSetEnabled(true)
}
func NoGrad1(fn func() interface{}) interface{} {
newTs := NewTensor()
newTs.Drop()
// Switch off Grad
prev := MustGradSetEnabled(false)
retVal := fn()
// Switch on Grad
_ = MustGradSetEnabled(prev)
return retVal
}
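// A usage sketch (illustrative only; assumes the generated MustMul binding):
// run ops without gradient tracking.
//
//	var y *ts.Tensor
//	ts.NoGrad(func() {
//		y = x.MustMul(x, false) // not tracked by autograd
//	})
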
// NoGradGuard is a RAII guard that prevents gradient tracking until deallocated.
// It actually sets a global flag that is checked by the backend whenever an op is done on a variable.
// The guard saves the current status and sets it to false in the constructor,
// and restores the saved status in its destructor.
// That way it is similar to a `with torch.no_grad():` block in Python.
// Ref. https://discuss.pytorch.org/t/how-does-nogradguard-works-in-cpp/34960/2
//
// TODO: should we implement Go `mutex` here???
type NoGradGuard struct {
enabled bool
}
// NewNoGradGuard creates a NoGradGuard and disables gradient tracking.
func NewNoGradGuard() *NoGradGuard {
return noGradGuardInit()
}
// noGradGuardInit disables gradient tracking; tracking is enabled back when
// the returned guard gets dropped.
func noGradGuardInit() *NoGradGuard {
return &NoGradGuard{enabled: MustGradSetEnabled(false)}
}
// Drop drops the NoGradGuard state.
func (ngg *NoGradGuard) Drop() {
ngg.enabled = true
_ = MustGradSetEnabled(ngg.enabled)
}
func (ngg *NoGradGuard) Enable() {
ngg.enabled = false
_ = MustGradSetEnabled(ngg.enabled)
}
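// A usage sketch (illustrative only): guard-style no-grad region.
//
//	guard := ts.NewNoGradGuard()
//	// ... ops here are not tracked by autograd ...
//	guard.Drop() // re-enables gradient tracking
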
const (
	// Do not reduce
	ReductionNone int64 = 0
	// Mean of losses
	ReductionMean int64 = 1
	// Sum of losses
	ReductionSum int64 = 2
	// Escape hatch in case new options become available
	ReductionOther int64 = 3
)
// func (r Reduction) ToInt() int {
// switch r {
// case ReductionNone:
// return 0
// case ReductionMean:
// return 1
// case ReductionSum:
// return 2
// case ReductionOther:
// return 3
// }
//
// // NOTE. should it be panic here instead of returning -1?
// return -1
// }
// Float64Values returns values of tensor in a slice of float64.
func (ts *Tensor) Float64Values(delOpt ...bool) []float64 {
del := false
if len(delOpt) > 0 {
del = delOpt[0]
}
numel := ts.Numel()
vec := make([]float64, numel)
float64Ts := ts.MustTotype(gotch.Double, false)
float64Ts.MustCopyData(vec, numel)
// float64Ts.MustDrop()
if del {
ts.MustDrop()
}
return vec
}
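// A usage sketch (illustrative only): pulling all values out as []float64.
//
//	x := ts.MustOfSlice([]int64{1, 2, 3})
//	vals := x.Float64Values(true) // values converted to float64; x is dropped
//	fmt.Println(vals)             // [1 2 3]
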
// Int64Values returns values of tensor in a slice of int64.
func (ts *Tensor) Int64Values(delOpt ...bool) []int64 {
del := false
if len(delOpt) > 0 {
del = delOpt[0]
}
numel := ts.Numel()
vec := make([]int64, numel)
int64Ts := ts.MustTotype(gotch.Int64, false)
int64Ts.MustCopyData(vec, numel)
int64Ts.MustDrop()
if del {
ts.MustDrop()
}
return vec
}
// Vals returns tensor values in a slice.
// NOTE: a type assertion is needed to get the runtime type,
// e.g. res := xs.Vals().([]int64)
func (ts *Tensor) Vals() interface{} {
dtype := ts.DType()
numel := int(ts.Numel())
typ, err := dtype.GoType()
if err != nil {
log.Fatal(err)
}
dataSlice := reflect.MakeSlice(reflect.SliceOf(typ), numel, numel).Interface()
ts.CopyData(dataSlice, uint(numel))
return dataSlice
}
// FlatView flattens a tensor.
//
// This returns a flattened version of the given tensor. The first dimension
// is preserved as it is assumed to be the mini-batch dimension.
func (ts *Tensor) FlatView() *Tensor {
batchSize := ts.MustSize()[0]
return ts.MustView([]int64{batchSize, -1}, false)
}
// ZeroPad2d zero-pads a 4-D tensor on the left, right, top and bottom of its
// last two dimensions.
func (ts *Tensor) ZeroPad2d(left, right, top, bottom int64, del bool) (*Tensor, error) {
if ts.Dim() != 4 {
		err := fmt.Errorf("Expected a 4-dimensional tensor, got %v\n", ts.MustSize())
return nil, err
}
return ts.ConstantPadNd([]int64{left, right, top, bottom}, del)
}
func (ts *Tensor) MustZeroPad2d(left, right, top, bottom int64, del bool) *Tensor {
retVal, err := ts.ZeroPad2d(left, right, top, bottom, del)
if err != nil {
log.Fatal(err)
}
return retVal
}
// Onehot converts a tensor to a one-hot encoded version.
//
// If the input has a size [N1, N2, ..., Nk], the returned tensor has a size
// [N1, ..., Nk, labels]. The returned tensor uses float values.
// Elements of the input vector are expected to be between 0 and labels-1.
//
// NOTE: There's other `ts.OneHot` and `ts.MustOneHot` generated from Atg C++ API
func (ts *Tensor) Onehot(labels int64) *Tensor {
dims := ts.MustSize()
dims = append(dims, labels)
unsqueezeTs := ts.MustUnsqueeze(-1, false)
inputTs := unsqueezeTs.MustTotype(gotch.Int64, true)
zerosTs := MustZeros(dims, gotch.Float, gotch.CPU)
retVal := zerosTs.MustScatterValue(-1, inputTs, FloatScalar(1.0), true)
inputTs.MustDrop()
return retVal
}
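// A usage sketch (illustrative only): one-hot encoding class indices.
//
//	idx := ts.MustOfSlice([]int64{0, 2, 1})
//	oh := idx.Onehot(3)        // float values
//	fmt.Println(oh.MustSize()) // [3 3]
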
// Swish computes x * sigmoid(x) (the swish / SiLU activation).
func (ts *Tensor) Swish() *Tensor {
sig := ts.MustSigmoid(false)
mulTs := ts.MustMul(sig, false)
sig.MustDrop()
return mulTs
}
func (ts *Tensor) AvgPool2DDefault(ksize int64, del bool) *Tensor {
return ts.MustAvgPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, false, true, []int64{1}, del)
}
// SaveMultiNew saves a slice of named tensors to the given file path.
func SaveMultiNew(namedTensors []NamedTensor, path string) error {
var (
tensors []lib.Ctensor
names []string
)
for _, nts := range namedTensors {
tensors = append(tensors, nts.Tensor.ctensor)
names = append(names, nts.Name)
}
lib.AtSaveMultiNew(tensors, names, path)
if err := TorchErr(); err != nil {
return err
}
return nil
}
func (ts *Tensor) ConstantPadNdWithVal(pad []int64, value *Scalar, del bool) (retVal *Tensor, err error) {
if del {
defer ts.MustDrop()
}
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtoConstantPadNd(ptr, ts.ctensor, pad, len(pad), value.cscalar)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = newTensor(*ptr)
return retVal, err
}
func (ts *Tensor) MustConstantPadNdWithVal(pad []int64, value *Scalar, del bool) (retVal *Tensor) {
retVal, err := ts.ConstantPadNdWithVal(pad, value, del)
if err != nil {
log.Fatal(err)
}
return retVal
}
// TT. Added some torch.cuda APIs for handling CUDA devices.

// CudaCurrentDevice gets the device index of the current CUDA device.
func CudaCurrentDevice() (int, error) {
currentDeviceIndex := lib.AtcGetDevice()
if err := TorchErr(); err != nil {
err = fmt.Errorf("ts.CudaCurrentDevice() failed: %w\n", err)
return -99, err
}
return currentDeviceIndex, nil
}
// CudaSetDevice set new cuda device index and returns previous cuda index.
func CudaSetDevice(cudaDeviceIndex int) (int, error) {
currentDeviceIndex, err := CudaCurrentDevice()
if err != nil {
err = fmt.Errorf("ts.CudaSetDevice() failed: %w\n", err)
return -99, err
}
lib.AtcSetDevice(cudaDeviceIndex)
if err := TorchErr(); err != nil {
err = fmt.Errorf("ts.CudaSetDevice() failed: %w\n", err)
return -99, err
}
return currentDeviceIndex, nil
}
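// A usage sketch (illustrative only; assumes a CUDA-enabled build of libtorch
// with at least two visible devices):
//
//	prev, err := ts.CudaSetDevice(1) // switch to device 1, remember previous
//	if err != nil {
//		log.Fatal(err)
//	}
//	// ... run work on device 1 ...
//	_ = ts.CudaSynchronize()      // wait for kernels on the current device
//	_, _ = ts.CudaSetDevice(prev) // restore the previous device
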
// CudaSynchronize waits for all kernels in all streams on a CUDA device to complete.
func CudaSynchronize(cudaDeviceIndexOpt ...int) error {
var cudaDeviceIndex int
var err error
if len(cudaDeviceIndexOpt) > 0 {
cudaDeviceIndex = cudaDeviceIndexOpt[0]
} else {
cudaDeviceIndex, err = CudaCurrentDevice()
if err != nil {
err := fmt.Errorf("ts.CudaSynchronize() failed: %w\n", err)
return err
}
}
lib.AtcSynchronize(int64(cudaDeviceIndex))
return TorchErr()
}