// gotch/ts/tensor.go

package ts

//#include "stdlib.h"
//#include "stdbool.h"
//#include <stdio.h>
import "C"
import (
"bytes"
"encoding/binary"
"fmt"
"log"
"reflect"
"runtime"
"sync"
"sync/atomic"
"time"
"unsafe"

gotch "github.com/sugarme/gotch"
lib "github.com/sugarme/gotch/libtch"
)

var (
TensorCount int64 // running count of created tensors
ScalarCount int64 // running count of created scalars
AllocatedMem int64 // bytes - tracks memory created and still occupied by gotch/tensor (excluding memory allocated by libtorch on the C side)
ExistingTensors map[string]struct{} = make(map[string]struct{}) // tracks existing tensors by name
ExistingScalars map[string]struct{} = make(map[string]struct{}) // tracks existing scalars by name
lock sync.Mutex
)

// None is an undefined tensor.
// It can be passed wherever an optional tensor parameter accepts a 'None' value.
// Use `ts.MustDefined()` to check whether a tensor is defined.
var None = NewTensor()
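
// Example (a minimal sketch): None can be passed where an optional tensor is
// expected, and MustDefined distinguishes it from a defined tensor:
//
//	x := ts.None
//	if !x.MustDefined() {
//	    fmt.Println("x is an undefined (None) tensor")
//	}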

type bigStruct struct {
lots [1e5]byte // 100k - always on host memory.
}

// Tensor is a Go wrapper of a C tensor pointer - 8 bytes on a 64-bit OS or 4 bytes on a 32-bit OS.
// ctensor is just a C pointer to `torch::Tensor`.
//
// NOTE. Tensor should be big enough to be allocated on the heap, so that
// runtime.SetFinalizer can track it.
// See https://stackoverflow.com/questions/10866195
type Tensor struct {
d *bigStruct
name string
ctensor lib.Ctensor
}

func newTensor(ctensor lib.Ctensor, nameOpt ...string) *Tensor {
name := newName(nameOpt...)
x := new(Tensor)
x.ctensor = ctensor
x.name = name
x.d = new(bigStruct)
atomic.AddInt64(&TensorCount, 1)
nbytes := x.nbytes()
atomic.AddInt64(&AllocatedMem, nbytes)
lock.Lock()
ExistingTensors[name] = struct{}{}
lock.Unlock()
if gotch.Debug {
log.Printf("INFO: Added tensor %q - Allocated memory: %d bytes.\n", x.name, nbytes)
}
runtime.SetFinalizer(x, freeCTensor)
return x
}

// CheckCMemLeak reports tensors that have not been released yet and the
// amount of C memory they still occupy.
func CheckCMemLeak() string {
lock.Lock()
tensors := []string{}
for n := range ExistingTensors {
tensors = append(tensors, n)
}
lock.Unlock()
var msg string
msg += "============================= C MEMORY CHECK RESULT ==================================\n"
msg += fmt.Sprintf("C memory allocated and not yet released: %v bytes\n", atomic.LoadInt64(&AllocatedMem))
msg += fmt.Sprintf("Tensors not yet released: %q\n", tensors)
msg += "======================================================================================\n"
return msg
}

// CleanUp runs the garbage collector twice, with a pause in between, so that
// finalizers of unreachable tensors run and their C memory is released.
// The optional argument is the pause between the two GC calls in milliseconds
// (default: 1000).
func CleanUp(sleepTimeOpt ...int) {
sleepTime := time.Duration(1000) // milliseconds
if len(sleepTimeOpt) > 0 {
sleepTime = time.Duration(sleepTimeOpt[0])
}
if gotch.Debug {
fmt.Print(CheckCMemLeak())
fmt.Println(">>>>>>>>>>>>>>Last runtime.GC() call at ts.CleanUp() <<<<<<<<<<<<<<<<<<<<<<<<<<<")
}
runtime.GC()
time.Sleep(time.Millisecond * sleepTime)
runtime.GC()
if gotch.Debug {
fmt.Println(">>>>>>>>>>>>>>After last GC called at ts.CleanUp(): <<<<<<<<<<<<<<<<<<<<<<<<<")
fmt.Print(CheckCMemLeak())
}
}
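
// Usage sketch for the debug helpers above (illustrative only): force
// finalizers to run, then inspect what is still held on the C side:
//
//	x := ts.MustOfSlice([]float64{1, 2, 3})
//	_ = x
//	ts.CleanUp(500)               // GC twice with a 500ms pause
//	fmt.Print(ts.CheckCMemLeak()) // report tensors not yet released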

// Ctensor returns the raw C tensor pointer value.
func (ts *Tensor) Ctensor() unsafe.Pointer {
return unsafe.Pointer(ts.ctensor)
}

// freeCTensor releases the C-allocated memory of the tensor. It is registered
// via runtime.SetFinalizer for every tensor created through newTensor.
func freeCTensor(ts *Tensor) error {
nbytes := ts.nbytes()
atomic.AddInt64(&AllocatedMem, -nbytes)
lock.Lock()
defer lock.Unlock()
delete(ExistingTensors, ts.name)
lib.AtFree(ts.ctensor)
if err := TorchErr(); err != nil {
err := fmt.Errorf("ERROR: failed to release tensor %q - %w", ts.name, err)
return err
}
if gotch.Debug {
log.Printf("INFO: Released tensor %q - C memory(%d bytes).\n", ts.name, nbytes)
}
return nil
}

func newName(nameOpt ...string) string {
var name string
if len(nameOpt) > 0 {
name = nameOpt[0]
} else {
name = fmt.Sprintf("tensor_%09d", TensorCount)
}
return name
}

// NewTensor creates a new (empty, undefined) tensor.
func NewTensor(nameOpt ...string) *Tensor {
ctensor := lib.AtNewTensor()
return newTensor(ctensor, nameOpt...)
}

// FromCtensor wraps a raw C tensor pointer into a Tensor.
func FromCtensor(ctensor unsafe.Pointer) *Tensor {
cts := (lib.Ctensor)(ctensor)
return newTensor(cts)
}

// Name returns the name of the tensor.
func (ts *Tensor) Name() string {
return ts.name
}

// Dim returns the number of dimensions of the tensor.
func (ts *Tensor) Dim() uint64 {
dim := lib.AtDim(ts.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
return dim
}

// Size returns the shape of the tensor.
//
// NOTE: C++ libtorch calls at_shape() -> t.sizes()
// and returns a slice of sizes (the shape) through the given pointer.
func (ts *Tensor) Size() ([]int64, error) {
dim := lib.AtDim(ts.ctensor)
sz := make([]int64, dim)
szPtr, err := DataAsPtr(sz)
if err != nil {
return nil, err
}
defer C.free(unsafe.Pointer(szPtr))
lib.AtShape(ts.ctensor, szPtr)
if err = TorchErr(); err != nil {
return nil, err
}
shape := decodeSize(szPtr, dim)
return shape, nil
}

// MustSize returns the shape of the tensor. It panics on error.
func (ts *Tensor) MustSize() []int64 {
shape, err := ts.Size()
if err != nil {
log.Fatal(err)
}
return shape
}

// Size1 returns the tensor size for 1D tensors.
func (ts *Tensor) Size1() (int64, error) {
shape, err := ts.Size()
if err != nil {
return 0, err
}
if len(shape) != 1 {
err = fmt.Errorf("expected one dim, got %v", len(shape))
return 0, err
}
return shape[0], nil
}

// Size2 returns the tensor size for 2D tensors.
func (ts *Tensor) Size2() ([]int64, error) {
shape, err := ts.Size()
if err != nil {
return nil, err
}
if len(shape) != 2 {
err = fmt.Errorf("expected two dims, got %v", len(shape))
return nil, err
}
return shape, nil
}

// Size3 returns the tensor size for 3D tensors.
func (ts *Tensor) Size3() ([]int64, error) {
shape, err := ts.Size()
if err != nil {
return nil, err
}
if len(shape) != 3 {
err = fmt.Errorf("expected three dims, got %v", len(shape))
return nil, err
}
return shape, nil
}

// Size4 returns the tensor size for 4D tensors.
func (ts *Tensor) Size4() ([]int64, error) {
shape, err := ts.Size()
if err != nil {
return nil, err
}
if len(shape) != 4 {
err = fmt.Errorf("expected four dims, got %v", len(shape))
return nil, err
}
return shape, nil
}
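
// Example (a minimal sketch): Size returns the full shape, while Size1..Size4
// additionally validate the expected rank:
//
//	x := ts.MustOfSlice([]int64{1, 2, 3, 4, 5, 6}).MustView([]int64{2, 3}, true)
//	fmt.Println(x.MustSize()) // [2 3]
//	shape, err := x.Size2()   // ok: 2D
//	_, err1 := x.Size1()      // error: expected one dim, got 2
//	_, _, _ = shape, err, err1
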
// nbytes calculates tensor data size in bytes.
func (ts *Tensor) nbytes() int64 {
numel := ts.Numel()
if numel == 0 {
return 0 // ts.None
}
dtype := ts.DType()
eltSizeInBytes, err := gotch.DTypeSize(dtype)
if err != nil {
log.Fatal(err)
}
nbytes := int64(numel) * int64(eltSizeInBytes)
return nbytes
}
func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
// Decode sz
// 1. Count number of elements in data
elementNum := nsize
// 2. Element size in bytes
eltSizeInBytes, err := gotch.DTypeSize(gotch.Int64)
if err != nil {
log.Fatal(err)
}
nbytes := int(eltSizeInBytes) * int(elementNum)
dataSlice := (*[1 << 30]byte)(ptr)[:nbytes:nbytes]
r := bytes.NewReader(dataSlice)
dataIn := make([]int64, nsize)
if err := binary.Read(r, nativeEndian, dataIn); err != nil {
log.Fatal(err)
}
return dataIn
}

// OfSlice creates a tensor from a slice of data.
func OfSlice(data interface{}, nameOpt ...string) (*Tensor, error) {
// Convert []int -> []int32. `binary.Write()` can't write `[]int` because it's not fixed-size!
if reflect.TypeOf(data).String() == "[]int" {
data = sliceIntToInt32(data.([]int))
}
v := reflect.ValueOf(data)
kind := v.Kind().String()
if kind != "slice" && kind != "array" {
err := fmt.Errorf("expected slice data, got %q", kind)
return nil, err
}
typ := reflect.TypeOf(data).Elem()
dataLen := v.Len()
dtype, err := gotch.ToDType(typ)
if err != nil {
return nil, err
}
shape := []int64{int64(dataLen)}
elementNum := ElementCount(shape)
eltSizeInBytes, err := gotch.DTypeSize(dtype)
if err != nil {
return nil, err
}
nbytes := int(eltSizeInBytes) * int(elementNum)
dataPtr, buff := CMalloc(nbytes)
defer C.free(unsafe.Pointer(dataPtr))
if err = EncodeTensor(buff, reflect.ValueOf(data), shape); err != nil {
return nil, err
}
cint, err := gotch.DType2CInt(dtype)
if err != nil {
return nil, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
if err = TorchErr(); err != nil {
return nil, err
}
return newTensor(ctensor, nameOpt...), nil
}
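
// Example (a minimal sketch): a 1D tensor from a Go slice; dtype is inferred
// from the element type, and []int is converted to int32 as noted above:
//
//	x, err := ts.OfSlice([]float32{0.5, 1.5, 2.5})
//	if err != nil {
//	    log.Fatal(err)
//	}
//	fmt.Println(x.MustSize()) // [3]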

// OfDataSize creates a Tensor from input byte data, shape and dtype.
func OfDataSize(data []byte, shape []int64, dtype gotch.DType, nameOpt ...string) (*Tensor, error) {
elementNum := ElementCount(shape)
eltSizeInBytes, err := gotch.DTypeSize(dtype)
if err != nil {
return nil, err
}
nbytes := int(eltSizeInBytes) * int(elementNum)
if nbytes != len(data) {
err := fmt.Errorf("data and shape mismatched for dtype (%v): byte data (%v) - shape (%v)", dtype, len(data), shape)
return nil, err
}
dataPtr, buff := CMalloc(nbytes)
defer C.free(unsafe.Pointer(dataPtr))
if err := binary.Write(buff, nativeEndian, data); err != nil {
return nil, err
}
cint, err := gotch.DType2CInt(dtype)
if err != nil {
return nil, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
if err = TorchErr(); err != nil {
return nil, err
}
return newTensor(ctensor, nameOpt...), nil
}

// MustOfDataSize creates a Tensor from input byte data with the specified
// shape and dtype. It panics on error.
func MustOfDataSize(data []byte, size []int64, dtype gotch.DType) *Tensor {
ts, err := OfDataSize(data, size, dtype)
if err != nil {
log.Fatal(err)
}
return ts
}
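
// Example (a minimal sketch): 8 raw bytes interpreted as a 2x4 uint8 tensor.
// len(data) must equal element count x element size:
//
//	data := []byte{1, 2, 3, 4, 5, 6, 7, 8}
//	x := ts.MustOfDataSize(data, []int64{2, 4}, gotch.Uint8)
//	fmt.Println(x.MustSize()) // [2 4]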

// MustOfSlice creates a tensor from a slice of data. It panics on error.
func MustOfSlice(data interface{}) *Tensor {
ts, err := OfSlice(data)
if err != nil {
log.Fatal(err)
}
return ts
}

// TensorFrom creates a tensor from a slice of data. It panics on error.
func TensorFrom(data interface{}, nameOpt ...string) *Tensor {
ts, err := OfSlice(data, nameOpt...)
if err != nil {
log.Fatal(err)
}
return ts
}

// Print prints the tensor values to the console.
//
// NOTE: printing is done from C and will print ALL elements of the tensor
// without any truncation.
func (ts *Tensor) Print() {
lib.AtPrint(ts.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
}

// NewTensorFromData creates a tensor from the given data and shape.
func NewTensorFromData(data interface{}, shape []int64, nameOpt ...string) (*Tensor, error) {
// 1. Check whether data and shape match
elementNum, err := DataDim(data)
if err != nil {
return nil, err
}
nflattend := FlattenDim(shape)
if elementNum != nflattend {
err = fmt.Errorf("number of data elements (%v) and flattened shape dimension (%v) mismatched", elementNum, nflattend)
return nil, err
}
// 2. Write raw data to C memory and get C pointer
dataPtr, err := DataAsPtr(data)
defer C.free(unsafe.Pointer(dataPtr))
if err != nil {
return nil, err
}
// 3. Create tensor with pointer and shape
dtype, err := gotch.DTypeFromData(data)
if err != nil {
return nil, err
}
eltSizeInBytes, err := gotch.DTypeSize(dtype)
if err != nil {
return nil, err
}
cint, err := gotch.DType2CInt(dtype)
if err != nil {
return nil, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
if err = TorchErr(); err != nil {
return nil, err
}
return newTensor(ctensor, nameOpt...), nil
}
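
// Example (a minimal sketch): unlike OfSlice, NewTensorFromData takes an
// explicit shape, which must flatten to the number of data elements:
//
//	x, err := ts.NewTensorFromData([]float64{1, 2, 3, 4}, []int64{2, 2})
//	if err != nil {
//	    log.Fatal(err)
//	}
//	fmt.Println(x.MustSize()) // [2 2]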

// DType returns the dtype of the tensor. It panics on error.
func (ts *Tensor) DType() gotch.DType {
cint := lib.AtScalarType(ts.ctensor)
dtype, err := gotch.CInt2DType(cint)
if err != nil {
log.Fatalf("Tensor DType error: %v\n", err)
}
return dtype
}

// Device returns the device on which the tensor data is stored.
func (ts *Tensor) Device() (gotch.Device, error) {
var (
retVal gotch.Device
err error
)
cInt := lib.AtDevice(ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
var device gotch.Device
return device.OfCInt(int32(cInt)), nil
}

// MustDevice returns the device of the tensor. It panics on error.
func (ts *Tensor) MustDevice() gotch.Device {
device, err := ts.Device()
if err != nil {
log.Fatal(err)
}
return device
}
/*
* func (ts Tensor) Eq1(other Tensor, del bool) (retVal Tensor, err error) {
*
* // Get a C null pointer
* // https://stackoverflow.com/a/2022369
* ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
* if del {
* defer ts.MustDrop()
* }
*
* lib.AtgEq1(ptr, ts.ctensor, other.ctensor)
* if err = TorchErr(); err != nil {
* return retVal, err
* }
*
* return Tensor{ctensor: *ptr}, nil
*
* }
*
* func (ts Tensor) MustEq1(other Tensor, del bool) (retVal Tensor) {
* retVal, err := ts.Eq1(other, del)
* if err != nil {
* log.Fatal(err)
* }
*
* return retVal
* }
* */

// Float64Value returns a float value on tensors holding a single element.
// An error is returned otherwise.
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
func (ts *Tensor) Float64Value(idx []int64) (float64, error) {
idxPtr, err := DataAsPtr(idx)
if err != nil {
return 0, err
}
defer C.free(unsafe.Pointer(idxPtr))
f64Val := lib.AtDoubleValueAtIndexes(ts.ctensor, idxPtr, len(idx))
if err = TorchErr(); err != nil {
return 0, err
}
return f64Val, err
}

// MustFloat64Value returns a float64 value at the given indexes. It panics on error.
func (ts *Tensor) MustFloat64Value(idx []int64) float64 {
f64Val, err := ts.Float64Value(idx)
if err != nil {
log.Fatal(err)
}
return f64Val
}
// Int64Value returns an int value on tensors holding a single element. An error is
// returned otherwise.
func (ts *Tensor) Int64Value(idx []int64) (int64, error) {
var (
retVal int64
err error
)
idxPtr, err := DataAsPtr(idx)
if err != nil {
return retVal, err
}
defer C.free(unsafe.Pointer(idxPtr))
int64Val := lib.AtInt64ValueAtIndexes(ts.ctensor, idxPtr, len(idx))
if err = TorchErr(); err != nil {
return 0, err
}
return int64Val, err
}

// MustInt64Value returns an int64 value at the given indexes. It panics on error.
func (ts *Tensor) MustInt64Value(idx []int64) int64 {
int64Val, err := ts.Int64Value(idx)
if err != nil {
log.Fatal(err)
}
return int64Val
}
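
// Example (a minimal sketch): the idx slice addresses a single element, so its
// length must match the tensor rank:
//
//	x := ts.MustOfSlice([]float64{10, 20, 30})
//	fmt.Println(x.MustFloat64Value([]int64{2})) // 30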

// RequiresGrad returns true if gradients are currently tracked for this tensor.
func (ts *Tensor) RequiresGrad() (bool, error) {
state := lib.AtRequiresGrad(ts.ctensor)
if err := TorchErr(); err != nil {
return false, err
}
return state, nil
}
func (ts *Tensor) MustRequiresGrad() bool {
state, err := ts.RequiresGrad()
if err != nil {
log.Fatal(err)
}
return state
}
// DataPtr returns the address of the first element of this tensor.
func (ts *Tensor) DataPtr() (unsafe.Pointer, error) {
datPtr := lib.AtDataPtr(ts.ctensor)
if err := TorchErr(); err != nil {
return nil, err
}
return datPtr, nil
}
// Defined returns true if the tensor is defined.
func (ts *Tensor) Defined() (bool, error) {
state := lib.AtDefined(ts.ctensor)
if err := TorchErr(); err != nil {
return false, err
}
return state, nil
}
func (ts *Tensor) MustDefined() bool {
state, err := ts.Defined()
if err != nil {
log.Fatal(err)
}
return state
}
// IsSparse returns true if the tensor is sparse.
func (ts *Tensor) IsSparse() (bool, error) {
state := lib.AtIsSparse(ts.ctensor)
if err := TorchErr(); err != nil {
return false, err
}
return state, nil
}
// ZeroGrad zeroes the gradient tensor attached to this tensor if defined.
func (ts *Tensor) ZeroGrad() {
grad := ts.MustGrad(false)
if grad.MustDefined() {
grad.Detach_()
grad.Zero_()
}
}
// Backward runs the backward pass, populating the gradient tensors for tensors
// whose gradients are tracked.
//
// Gradient tracking can be turned on via `SetRequiresGrad`.
func (ts *Tensor) Backward() error {
lib.AtBackward(ts.ctensor, 0, 0)
if err := TorchErr(); err != nil {
return err
}
return nil
}
func (ts *Tensor) MustBackward() {
if err := ts.Backward(); err != nil {
log.Fatal(err)
}
}
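
// Example (a minimal sketch; assumes the generated API provides
// MustSetRequiresGrad(requiresGrad, del) and MustGrad(del), as used elsewhere
// in this package):
//
//	x := ts.TensorFrom([]float64{3.0}).MustSetRequiresGrad(true, false)
//	y := x.MustMul(x, false) // y = x^2
//	y.MustBackward()
//	grad := x.MustGrad(false)
//	fmt.Println(grad.Float64Values()) // [6], since dy/dx = 2x
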
// RunBackward runs the backward pass over the given tensors with respect to
// `inputs`, returning one gradient tensor per input.
func RunBackward(tensors []*Tensor, inputs []*Tensor, keepGraphB bool, createGraphB bool) ([]*Tensor, error) {
// NOTE: outputs is a slice of tensors with length = len(inputs)
var outputsPtr []*lib.Ctensor
// Are they allocated contiguously??? Definitely not.
// TODO. calculate C memory size = C pointer size x n pointers
// Then C.malloc such calculated amount
// NOTE. replace with the following code and test.
/*
* ntensors := len(inputs)
* nbytes := C.size_t(ntensors) * C.size_t(unsafe.Sizeof(uintptr(0)))
* ctensorsPtr := (*[1 << 30]lib.Ctensor)(C.malloc(nbytes))
* for i :=0; i < ntensors; i++ {
* outputPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
* outputsPtr[i] = outputPtr
* }
* */
for i := 0; i < len(inputs); i++ {
outputPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(outputPtr))
outputsPtr = append(outputsPtr, outputPtr)
}
// Get first element pointer
ctensor := tensors[0].ctensor
cinput := inputs[0].ctensor
tensorsPtr := (*lib.Ctensor)(unsafe.Pointer(&ctensor))
inputsPtr := (*lib.Ctensor)(unsafe.Pointer(&cinput))
var keepGraph int = 0
if keepGraphB {
keepGraph = 1
}
var createGraph int = 0
if createGraphB {
createGraph = 1
}
lib.AtRunBackward(tensorsPtr, len(tensors), inputsPtr, len(inputs), outputsPtr[0], keepGraph, createGraph)
if err := TorchErr(); err != nil {
return nil, err
}
var oTensors []*Tensor
for i := 0; i < len(inputs); i++ {
outputPtr := outputsPtr[i]
oTensors = append(oTensors, newTensor(*outputPtr))
}
return oTensors, nil
}

// CopyDataUint8 copies `numel` elements from `self` to `dst`.
//
// NOTE: `dst` is located in Go memory. Should it be?
func (ts *Tensor) CopyDataUint8(dst []uint8, numel uint) error {
// NOTE: we must make sure that `dst` has the same length as `numel`. Otherwise,
// there will be a memory leak and/or an out-of-range error.
if len(dst) < int(numel) {
err := fmt.Errorf("CopyDataUint8 Error: length of destination slice data (%v) is smaller than number of elements to be copied (%v)", len(dst), numel)
return err
}
vs := unsafe.Pointer(&dst[0])
elt_size_in_bytes, err := gotch.DTypeSize(gotch.Uint8)
if err != nil {
return err
}
lib.AtCopyData(ts.ctensor, vs, numel, elt_size_in_bytes)
if err = TorchErr(); err != nil {
return err
}
return nil
}
func (ts *Tensor) MustCopyDataUint8(dst []uint8, numel uint) {
err := ts.CopyDataUint8(dst, numel)
if err != nil {
log.Fatal(err)
}
}

// CopyData copies `numel` elements from `self` to `dst`.
// `dst` should be a slice of the Go type equivalent to the tensor's dtype.
//
// NOTE: `dst` is located in Go memory. Should it be?
// We hand the Go pointer of the first element of the `dst` slice
// and the number of elements over to C land. This may break in the future
// if Go's pointer-passing policy changes.
func (ts *Tensor) CopyData(dst interface{}, numel uint) error {
gotype, dlen, err := DataCheck(dst)
if err != nil {
return err
}
dtype, err := gotch.ToDType(gotype)
if err != nil {
return err
}
if dlen < int(numel) {
err = fmt.Errorf("CopyData Error: length of destination slice data (%v) is smaller than number of elements to be copied (%v)", dlen, numel)
return err
}
if ts.DType() != dtype {
err = fmt.Errorf("type mismatched: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
return err
}
var vs unsafe.Pointer
switch dtype {
case gotch.Uint8:
vs = unsafe.Pointer(&dst.([]uint8)[0])
case gotch.Int8:
vs = unsafe.Pointer(&dst.([]int8)[0])
case gotch.Int16:
vs = unsafe.Pointer(&dst.([]int16)[0])
case gotch.Int:
vs = unsafe.Pointer(&dst.([]int32)[0])
case gotch.Int64:
vs = unsafe.Pointer(&dst.([]int64)[0])
case gotch.Float:
vs = unsafe.Pointer(&dst.([]float32)[0])
case gotch.Double:
vs = unsafe.Pointer(&dst.([]float64)[0])
case gotch.Bool:
vs = unsafe.Pointer(&dst.([]bool)[0])
default:
err = fmt.Errorf("Unsupported type: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
return err
}
elt_size_in_bytes, err := gotch.DTypeSize(dtype)
if err != nil {
return err
}
lib.AtCopyData(ts.ctensor, vs, numel, elt_size_in_bytes)
if err = TorchErr(); err != nil {
return err
}
return nil
}

// MustCopyData copies `numel` elements from the tensor to a slice of data.
//
// NOTE: `dst` is a slice with length = numel and a Go type equivalent to the
// tensor's DType.
func (ts *Tensor) MustCopyData(dst interface{}, numel uint) {
err := ts.CopyData(dst, numel)
if err != nil {
log.Fatal(err)
}
}
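
// Example (a minimal sketch): the destination slice must match the tensor's
// dtype and hold at least `numel` elements:
//
//	x := ts.MustOfSlice([]float32{1, 2, 3})
//	dst := make([]float32, x.Numel())
//	x.MustCopyData(dst, x.Numel())
//	fmt.Println(dst) // [1 2 3]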

// Numel returns the total number of elements stored in a tensor.
func (ts *Tensor) Numel() uint {
if !ts.MustDefined() {
return 0 // ts.None case
}
shape := ts.MustSize()
return uint(FlattenDim(shape))
}

// ShallowClone returns a new tensor that shares storage with the input tensor.
func (ts *Tensor) ShallowClone() (*Tensor, error) {
ctensor := lib.AtShallowClone(ts.ctensor)
if err := TorchErr(); err != nil {
return nil, err
}
name := fmt.Sprintf("%s_cloned", ts.name)
return newTensor(ctensor, name), nil
}
// MustShallowClone returns a new tensor that shares storage with the input
// tensor. It panics on error.
func (ts *Tensor) MustShallowClone() *Tensor {
newTs, err := ts.ShallowClone()
if err != nil {
log.Fatal(err)
}
return newTs
}
// Get gets the sub-tensor at the given index.
func (ts *Tensor) Get(index int) (*Tensor, error) {
ctensor := lib.AtGet(ts.ctensor, index)
if err := TorchErr(); err != nil {
return nil, err
}
return newTensor(ctensor), nil
}
// MustGet gets the sub-tensor at the given index. It panics on error.
func (ts *Tensor) MustGet(index int) *Tensor {
subTs, err := ts.Get(index)
if err != nil {
log.Fatal(err)
}
return subTs
}

// Copy_ copies in-place values from the source tensor to the destination tensor.
func Copy_(self, src *Tensor) {
lib.AtCopy_(self.ctensor, src.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
}

// Copy_ copies in-place values from the source tensor to the existing tensor.
func (ts *Tensor) Copy_(src *Tensor) {
lib.AtCopy_(ts.ctensor, src.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
}

// Save saves a tensor to a file.
func (ts *Tensor) Save(path string) error {
lib.AtSave(ts.ctensor, path)
if err := TorchErr(); err != nil {
return err
}
return nil
}

// MustSave saves a tensor to a file. It panics on error.
func (ts *Tensor) MustSave(path string) {
if err := ts.Save(path); err != nil {
log.Fatal(err)
}
}

// Load loads a tensor from a file.
func Load(path string, nameOpt ...string) (*Tensor, error) {
ctensor := lib.AtLoad(path)
if err := TorchErr(); err != nil {
return nil, err
}
return newTensor(ctensor, nameOpt...), nil
}

// MustLoad loads a tensor from a file. It panics on error.
func MustLoad(path string, nameOpt ...string) *Tensor {
ts, err := Load(path, nameOpt...)
if err != nil {
log.Fatal(err)
}
return ts
}
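
// Example (a minimal sketch; the file path is illustrative):
//
//	x := ts.MustOfSlice([]float64{1, 2, 3})
//	x.MustSave("/tmp/x.pt")
//	y := ts.MustLoad("/tmp/x.pt")
//	fmt.Println(y.MustSize()) // [3]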

// NamedTensor wraps a C tensor and its name.
type NamedTensor struct {
Name string
Tensor *Tensor
}

// SaveMulti saves some named tensors to a file.
//
// The file format is the same as the one used by the PyTorch C++ API.
// NOTE. This method is deprecated and will be replaced with `SaveMultiNew`.
func SaveMulti(namedTensors []NamedTensor, path string) error {
var ctensors []lib.Ctensor
var names []string
for _, ts := range namedTensors {
ctensors = append(ctensors, ts.Tensor.ctensor)
names = append(names, ts.Name)
}
lib.AtSaveMulti(ctensors, names, len(namedTensors), path)
if err := TorchErr(); err != nil {
return err
}
return nil
}

// MustSaveMulti saves some named tensors to a file. It panics on error.
//
// NOTE. This method is deprecated and will be replaced with `MustSaveMultiNew`.
func MustSaveMulti(namedTensors []NamedTensor, path string) {
err := SaveMulti(namedTensors, path)
if err != nil {
log.Fatal(err)
}
}
// LoadMulti loads some named tensors from a file
//
// The file format is the same as the one used by the PyTorch C++ API.
func LoadMulti(path string) ([]NamedTensor, error) {
var data lib.LoadData
dataPtr := lib.PStore.Set(&data)
lib.AtLoadCallback(path, dataPtr)
if err := TorchErr(); err != nil {
return nil, err
}
var namedTensors []NamedTensor
for _, v := range data.NamedCtensors {
namedTensor := NamedTensor{
Name: v.Name,
Tensor: newTensor(v.Ctensor, v.Name),
}
namedTensors = append(namedTensors, namedTensor)
}
return namedTensors, nil
}
// MustLoadMulti loads some named tensors from a file. It panics on error.
func MustLoadMulti(path string) []NamedTensor {
namedTensors, err := LoadMulti(path)
if err != nil {
log.Fatal(err)
}
return namedTensors
}
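
// Example (a minimal sketch; the file path is illustrative):
//
//	named := []ts.NamedTensor{
//	    {Name: "w", Tensor: ts.MustOfSlice([]float64{1, 2})},
//	    {Name: "b", Tensor: ts.MustOfSlice([]float64{0.5})},
//	}
//	if err := ts.SaveMulti(named, "/tmp/weights.pt"); err != nil {
//	    log.Fatal(err)
//	}
//	loaded := ts.MustLoadMulti("/tmp/weights.pt")
//	for _, nts := range loaded {
//	    fmt.Println(nts.Name, nts.Tensor.MustSize())
//	}
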
// LoadMultiWithDevice loads some named tensors from a file to a given device
//
// The file format is the same as the one used by the PyTorch C++ API.
func LoadMultiWithDevice(path string, device gotch.Device) ([]NamedTensor, error) {
var data lib.LoadData
dataPtr := lib.PStore.Set(&data)
lib.AtLoadCallbackWithDevice(path, dataPtr, device.CInt())
if err := TorchErr(); err != nil {
return nil, err
}
var namedTensors []NamedTensor
for _, v := range data.NamedCtensors {
namedTensor := NamedTensor{
Name: v.Name,
Tensor: newTensor(v.Ctensor, v.Name),
}
namedTensors = append(namedTensors, namedTensor)
}
return namedTensors, nil
}
// MustLoadMultiWithDevice loads some named tensors from a file to a given
// device. It panics on error.
func MustLoadMultiWithDevice(path string, device gotch.Device) []NamedTensor {
namedTensors, err := LoadMultiWithDevice(path, device)
if err != nil {
log.Fatal(err)
}
return namedTensors
}

// ToString returns a string representation for the tensor.
//
// lw : line width (size)
// NOTE: The representation will contain all the tensor elements, hence it may
// be huge for large tensors.
func (ts *Tensor) ToString(lw int64) (string, error) {
tensorStr := lib.AtToString(ts.ctensor, lw)
if err := TorchErr(); err != nil {
return "", err
}
return tensorStr, nil
}

// MustToString returns a string representation for the tensor. It panics on
// error.
// lw : line width (size)
func (ts *Tensor) MustToString(lw int64) string {
tensorStr, err := ts.ToString(lw)
if err != nil {
log.Fatal(err)
}
return tensorStr
}
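
// Example (a minimal sketch): lw caps the line width of the rendering:
//
//	x := ts.MustOfSlice([]float64{1, 2, 3})
//	fmt.Println(x.MustToString(80))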

// Drop drops (frees) the tensor.
func (ts *Tensor) Drop() error {
// TODO. Detect whether ctensor is a valid pointer and free only then;
// otherwise do nothing, to avoid a double-free error.
// FIXME. Get rid of this method as long as runtime.SetFinalizer() works properly.
return nil

// NOTE. The code below is unreachable and kept for reference only.
lib.AtFree(ts.ctensor)
if err := TorchErr(); err != nil {
return err
}
if gotch.Debug {
nbytes := ts.nbytes()
atomic.AddInt64(&AllocatedMem, -nbytes)
log.Printf("INFO: Released tensor %q - C memory(%d bytes).\n", ts.name, nbytes)
}
return nil
}

// MustDrop drops the tensor. It panics on error.
func (ts *Tensor) MustDrop() {
if err := ts.Drop(); err != nil {
log.Fatal(err)
}
}

// GradSetEnabled sets globally whether GradMode gradient accumulation is
// enabled or not. It returns the PREVIOUS state of Grad before setting.
func GradSetEnabled(b bool) (bool, error) {
var cbool, cretVal int
switch b {
case true:
cbool = 1
case false:
cbool = 0
}
var (
err error
state bool
)
cretVal = lib.AtGradSetEnabled(cbool)
if err = TorchErr(); err != nil {
return false, err
}
switch cretVal {
case 0:
state = false
case 1:
state = true
// case -1: // should be unreachable as the error is captured above by TorchErr()
// err = fmt.Errorf("Cannot set grad enable. \n")
// return retVal, err
// default: // should be unreachable as the error is captured above by TorchErr()
// err = fmt.Errorf("Cannot set grad enable. \n")
// return retVal, err
}
return state, nil
}
// MustGradSetEnabled sets globally whether GradMode gradient accumulation is
// enabled or not. It returns the PREVIOUS state of Grad before setting.
// It panics on error.
func MustGradSetEnabled(b bool) bool {
state, err := GradSetEnabled(b)
if err != nil {
log.Fatal(err)
}
return state
}

// NoGrad runs a closure without keeping track of gradients.
func NoGrad(fn func(), sleepTimeOpt ...int) {
CleanUp(sleepTimeOpt...)

// Switch off Grad
MustGradSetEnabled(false)

fn()

// Switch on Grad
MustGradSetEnabled(true)

CleanUp(sleepTimeOpt...)
}
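
// Example (a minimal sketch): ops run inside the closure do not track
// gradients:
//
//	x := ts.TensorFrom([]float64{1, 2, 3})
//	ts.NoGrad(func() {
//	    y := x.MustMul(x, false) // gradient is not tracked for this op
//	    _ = y
//	})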

// NoGrad1 runs a closure without keeping track of gradients and returns the
// closure's result.
func NoGrad1(fn func() interface{}) interface{} {
newTs := NewTensor()
newTs.Drop()
// Switch off Grad
prev := MustGradSetEnabled(false)
retVal := fn()
// Switch on Grad
_ = MustGradSetEnabled(prev)
return retVal
}

// NoGradGuard is a RAII guard that prevents gradient tracking until deallocated.
// It actually sets a global flag that is checked by the backend whenever an op is done on a variable.
// The guard itself saves the current status and sets it to false in the constructor,
// and restores the saved status in its destructor.
// That way it is similar to a `with torch.no_grad():` block in Python.
// Ref. https://discuss.pytorch.org/t/how-does-nogradguard-works-in-cpp/34960/2
//
// TODO: should we implement a Go `mutex` here???
type NoGradGuard struct {
enabled bool
}

// NewNoGradGuard initializes a NoGradGuard and disables gradient tracking.
func NewNoGradGuard() *NoGradGuard {
return noGradGuardInit()
}

// noGradGuardInit disables gradient tracking. Tracking will be re-enabled when
// the returned guard is dropped.
func noGradGuardInit() *NoGradGuard {
return &NoGradGuard{enabled: MustGradSetEnabled(false)}
}

// Drop re-enables gradient tracking and drops the NoGradGuard state.
func (ngg *NoGradGuard) Drop() {
ngg.enabled = true
_ = MustGradSetEnabled(ngg.enabled)
}

// Enable re-activates the guard, i.e. disables gradient tracking again.
func (ngg *NoGradGuard) Enable() {
ngg.enabled = false
_ = MustGradSetEnabled(ngg.enabled)
}
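
// Example (a minimal sketch): gradient tracking is off between NewNoGradGuard
// and Drop:
//
//	guard := ts.NewNoGradGuard()
//	// ... ops here do not track gradients ...
//	guard.Drop() // restores gradient tracking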

const (
// Do not reduce
ReductionNone int64 = 0
// Mean of losses
ReductionMean int64 = 1
// Sum of losses
ReductionSum int64 = 2
// Escape hatch in case new options become available
ReductionOther int64 = 3
)

// func (r Reduction) ToInt() int {
// switch r {
// case ReductionNone:
// return 0
// case ReductionMean:
// return 1
// case ReductionSum:
// return 2
// case ReductionOther:
// return 3
// }
//
// // NOTE. should it be panic here instead of returning -1?
// return -1
// }
// Float64Values returns values of tensor in a slice of float64.
func (ts *Tensor) Float64Values(delOpt ...bool) []float64 {
del := false
if len(delOpt) > 0 {
del = delOpt[0]
}
numel := ts.Numel()
vec := make([]float64, numel)
float64Ts := ts.MustTotype(gotch.Double, false)
float64Ts.MustCopyData(vec, numel)
// float64Ts.MustDrop()
if del {
ts.MustDrop()
}
return vec
}

// Int64Values returns values of tensor in a slice of int64.
func (ts *Tensor) Int64Values(delOpt ...bool) []int64 {
del := false
if len(delOpt) > 0 {
del = delOpt[0]
}
numel := ts.Numel()
vec := make([]int64, numel)
int64Ts := ts.MustTotype(gotch.Int64, false)
int64Ts.MustCopyData(vec, numel)
int64Ts.MustDrop()
if del {
ts.MustDrop()
}
return vec
}

// Vals returns tensor values in a slice.
// NOTE: a type assertion is needed to get the runtime type,
// e.g. res := xs.Vals().([]int64)
func (ts *Tensor) Vals() interface{} {
dtype := ts.DType()
numel := ts.Numel()
var retVal interface{}
switch dtype.Name() {
case "uint8":
retVal = make([]uint8, numel)
case "int8":
retVal = make([]int8, numel)
case "int16":
retVal = make([]int16, numel)
case "int32":
retVal = make([]int32, numel)
case "int64":
retVal = make([]int64, numel)
case "float32":
retVal = make([]float32, numel)
case "float64":
retVal = make([]float64, numel)
case "bool":
retVal = make([]bool, numel)
default:
log.Fatalf("Unsupported dtype (%v)", dtype)
}
ts.CopyData(retVal, numel)
return retVal
}

// FlatView flattens a tensor.
//
// This returns a flattened version of the given tensor. The first dimension
// is preserved as it is assumed to be the mini-batch dimension.
func (ts *Tensor) FlatView() *Tensor {
batchSize := ts.MustSize()[0]
return ts.MustView([]int64{batchSize, -1}, false)
}
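
// Example (a minimal sketch): a [4, 2, 3] tensor flattens to [4, 6], keeping
// the first (mini-batch) dimension:
//
//	x := ts.MustZeros([]int64{4, 2, 3}, gotch.Float, gotch.CPU)
//	fmt.Println(x.FlatView().MustSize()) // [4 6]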

// ZeroPad2d zero-pads the input 4D tensor on the left, right, top and bottom.
func (ts *Tensor) ZeroPad2d(left, right, top, bottom int64, del bool) (*Tensor, error) {
if ts.Dim() != 4 {
err := fmt.Errorf("expected a 4-dimension tensor, got shape %v", ts.MustSize())
return nil, err
}
return ts.ConstantPadNd([]int64{left, right, top, bottom}, del)
}

// MustZeroPad2d zero-pads the input 4D tensor. It panics on error.
func (ts *Tensor) MustZeroPad2d(left, right, top, bottom int64, del bool) *Tensor {
retVal, err := ts.ZeroPad2d(left, right, top, bottom, del)
if err != nil {
log.Fatal(err)
}
return retVal
}
// Onehot converts a tensor to a one-hot encoded version.
//
// If the input has a size [N1, N2, ..., Nk], the returned tensor has a size
// [N1, ..., Nk, labels]. The returned tensor uses float values.
// Elements of the input vector are expected to be between 0 and labels-1.
//
// NOTE: There's other `ts.OneHot` and `ts.MustOneHot` generated from Atg C++ API
func (ts *Tensor) Onehot(labels int64) *Tensor {
dims := ts.MustSize()
dims = append(dims, labels)
unsqueezeTs := ts.MustUnsqueeze(-1, false)
inputTs := unsqueezeTs.MustTotype(gotch.Int64, true)
zerosTs := MustZeros(dims, gotch.Float, gotch.CPU)
retVal := zerosTs.MustScatterValue(-1, inputTs, FloatScalar(1.0), true)
inputTs.MustDrop()
return retVal
}
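
// Example (a minimal sketch): labels in [0, labels-1] become one-hot rows,
// roughly:
//
//	x := ts.MustOfSlice([]int64{0, 2})
//	oh := x.Onehot(3)
//	oh.Print()
//	// 1 0 0
//	// 0 0 1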

// Swish computes the swish activation: x * sigmoid(x).
func (ts *Tensor) Swish() *Tensor {
sig := ts.MustSigmoid(false)
mulTs := ts.MustMul(sig, false)
sig.MustDrop()
return mulTs
}

// AvgPool2DDefault performs 2D average pooling with a square kernel and
// stride of size `ksize` and no padding.
func (ts *Tensor) AvgPool2DDefault(ksize int64, del bool) *Tensor {
return ts.MustAvgPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, false, true, []int64{1}, del)
}
// SaveMultiNew saves a slice of named tensors to the given file path.
func SaveMultiNew(namedTensors []NamedTensor, path string) error {
var (
tensors []lib.Ctensor
names []string
)
for _, nts := range namedTensors {
tensors = append(tensors, nts.Tensor.ctensor)
names = append(names, nts.Name)
}
lib.AtSaveMultiNew(tensors, names, path)
if err := TorchErr(); err != nil {
return err
}
return nil
}

// ConstantPadNdWithVal pads the tensor boundaries with the given scalar value.
func (ts *Tensor) ConstantPadNdWithVal(pad []int64, value *Scalar, del bool) (retVal *Tensor, err error) {
if del {
defer ts.MustDrop()
}
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtoConstantPadNd(ptr, ts.ctensor, pad, len(pad), value.cscalar)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = newTensor(*ptr)
return retVal, err
}

// MustConstantPadNdWithVal pads the tensor boundaries with the given scalar
// value. It panics on error.
func (ts *Tensor) MustConstantPadNdWithVal(pad []int64, value *Scalar, del bool) (retVal *Tensor) {
retVal, err := ts.ConstantPadNdWithVal(pad, value, del)
if err != nil {
log.Fatal(err)
}
return retVal
}