2022-03-12 07:20:20 +00:00
|
|
|
|
package ts
|
2020-05-28 17:58:23 +01:00
|
|
|
|
|
2020-06-10 07:13:47 +01:00
|
|
|
|
//#include "stdlib.h"
|
|
|
|
|
//#include "stdbool.h"
|
2020-06-21 14:37:42 +01:00
|
|
|
|
//#include<stdio.h>
|
2020-05-28 17:58:23 +01:00
|
|
|
|
import "C"
|
|
|
|
|
|
|
|
|
|
import (
|
2020-06-02 04:07:35 +01:00
|
|
|
|
"bytes"
|
|
|
|
|
"encoding/binary"
|
2020-05-30 03:36:49 +01:00
|
|
|
|
"fmt"
|
2020-06-02 04:07:35 +01:00
|
|
|
|
"log"
|
2020-05-28 17:58:23 +01:00
|
|
|
|
"reflect"
|
2023-07-04 14:26:20 +01:00
|
|
|
|
"runtime"
|
|
|
|
|
"sync"
|
|
|
|
|
"sync/atomic"
|
|
|
|
|
"time"
|
2020-06-02 04:07:35 +01:00
|
|
|
|
"unsafe"
|
2020-05-28 17:58:23 +01:00
|
|
|
|
|
2020-05-30 00:04:47 +01:00
|
|
|
|
gotch "github.com/sugarme/gotch"
|
2020-05-28 17:58:23 +01:00
|
|
|
|
lib "github.com/sugarme/gotch/libtch"
|
|
|
|
|
)
|
|
|
|
|
|
2023-07-04 14:26:20 +01:00
|
|
|
|
// Package-level bookkeeping for tensors/scalars created through gotch.
//
// Counters are updated with sync/atomic (see newTensor/freeCTensor);
// the maps are guarded by `lock`.
var (
	TensorCount  int64 // incremental counting created tensors (updated via atomic ops)
	ScalarCount  int64 // incremental counting created scalars
	AllocatedMem int64 // bytes - keeping track of memory created and still occupied by gotch/tensor (excluding mem allocated by libtorch at C side)

	ExistingTensors map[string]struct{} = make(map[string]struct{}) // keep track of existing tensors by name; guarded by lock
	ExistingScalars map[string]struct{} = make(map[string]struct{}) // keep track of existing scalar by name; guarded by lock

	lock sync.Mutex // protects ExistingTensors/ExistingScalars
)
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
// NOTE. None is an undefined tensor.
//
// It can be used in optional tensor parameter where 'None' value used.
// `ts.MustDefined()` function is used for checking 'null'.
var None = NewTensor()
|
|
|
|
|
|
|
|
|
|
// bigStruct is deliberately large (100KB) so that any Tensor holding a
// pointer to one is heap-allocated and therefore visible to the Go
// garbage collector (see the Tensor type comment below).
type bigStruct struct {
	lots [1e5]byte // 100k - always on host memory.
}
|
|
|
|
|
|
2023-07-23 03:36:55 +01:00
|
|
|
|
// Tensor is a Go wrapper of a "C tensor pointer" - 8 Bytes (64-bits OS)
// or 4 Bytes (32-bits OS).
// `ctensor` is just a "C pointer" to `torch::Tensor` (torch::Tensor *lib.Ctensor)
//
// NOTE.Tensor should be big enough to be in heap memory.
// (yes, we choose to place tensor consistently in heap memory so that
// it can be targeted by Go garbage collector).
//
// For heap allocation see. https://stackoverflow.com/questions/10866195
type Tensor struct {
	d       *bigStruct  // forces heap allocation; see bigStruct
	name    string      // unique registry key in ExistingTensors
	ctensor lib.Ctensor // C pointer to the underlying torch::Tensor
}
|
|
|
|
|
|
2023-07-04 14:26:20 +01:00
|
|
|
|
func newTensor(ctensor lib.Ctensor, nameOpt ...string) *Tensor {
|
|
|
|
|
if len(nameOpt) == 0 {
|
|
|
|
|
nameOpt = []string{}
|
|
|
|
|
}
|
|
|
|
|
name := newName(nameOpt...)
|
|
|
|
|
|
|
|
|
|
x := new(Tensor)
|
|
|
|
|
x.ctensor = ctensor
|
|
|
|
|
x.name = name
|
|
|
|
|
x.d = new(bigStruct)
|
|
|
|
|
|
|
|
|
|
atomic.AddInt64(&TensorCount, 1)
|
|
|
|
|
nbytes := x.nbytes()
|
|
|
|
|
atomic.AddInt64(&AllocatedMem, nbytes)
|
|
|
|
|
lock.Lock()
|
|
|
|
|
ExistingTensors[name] = struct{}{}
|
|
|
|
|
lock.Unlock()
|
|
|
|
|
|
|
|
|
|
if gotch.Debug {
|
|
|
|
|
log.Printf("INFO: Added tensor %q - Allocated memory: %d bytes.\n", x.name, nbytes)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
runtime.SetFinalizer(x, freeCTensor)
|
|
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func CheckCMemLeak() string {
|
|
|
|
|
tensors := []string{}
|
2023-07-07 13:30:08 +01:00
|
|
|
|
lock.Lock()
|
2023-07-04 14:26:20 +01:00
|
|
|
|
for n := range ExistingTensors {
|
|
|
|
|
tensors = append(tensors, n)
|
|
|
|
|
}
|
2023-07-07 13:30:08 +01:00
|
|
|
|
memUsed := AllocatedMem
|
|
|
|
|
lock.Unlock()
|
|
|
|
|
|
2023-07-04 14:26:20 +01:00
|
|
|
|
var msg string
|
|
|
|
|
msg += fmt.Sprintf("============================= C MEMORY CHECK RESULT ==================================\n")
|
2023-07-07 13:30:08 +01:00
|
|
|
|
msg += fmt.Sprintf("C memory allocated not been released: %v bytes\n", memUsed)
|
2023-07-04 14:26:20 +01:00
|
|
|
|
msg += fmt.Sprintf("Tensors not been released: %q\n", tensors)
|
|
|
|
|
msg += fmt.Sprintf("======================================================================================\n")
|
2023-07-07 13:30:08 +01:00
|
|
|
|
|
2023-07-04 14:26:20 +01:00
|
|
|
|
return msg
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-07 13:30:08 +01:00
|
|
|
|
// CleanUp runs the garbage collector twice, pausing in between so that
// finalizers queued by the first pass can run before the second.
//
// sleepTimeOpt optionally overrides the pause length in milliseconds
// (default 1000ms).
func CleanUp(sleepTimeOpt ...int) {
	delay := time.Duration(1000) // default pause: 1000ms
	if len(sleepTimeOpt) > 0 {
		delay = time.Duration(sleepTimeOpt[0])
	}

	runtime.GC()
	time.Sleep(time.Millisecond * delay)
	runtime.GC()
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
// Ctensor return C pointer value.
//
// The returned pointer aliases the underlying torch::Tensor; it is only
// valid while this *Tensor is alive (before its finalizer runs).
func (ts *Tensor) Ctensor() unsafe.Pointer {
	return unsafe.Pointer(ts.ctensor)
}
|
|
|
|
|
|
2023-07-04 14:26:20 +01:00
|
|
|
|
// freeCTensor releases the C allocated memory backing ts.
//
// It is installed as the runtime finalizer in newTensor, but can also be
// called directly; calling it more than once is safe because ts.ctensor
// is set to nil after a successful free.
func freeCTensor(ts *Tensor) error {
	lock.Lock()
	defer lock.Unlock()

	// Just return if it has been deleted previously!
	if unsafe.Pointer(ts.ctensor) == nil {
		return nil
	}

	// Update accounting before freeing.
	// NOTE(review): if lib.AtFree fails below, AllocatedMem and
	// ExistingTensors have already been updated even though the C memory
	// was not released — confirm whether this is intentional.
	nbytes := ts.nbytes()
	atomic.AddInt64(&AllocatedMem, -nbytes)
	delete(ExistingTensors, ts.name)

	lib.AtFree(ts.ctensor)
	if err := TorchErr(); err != nil {
		err := fmt.Errorf("ERROR: failed to release tensor %q - %w", ts.name, err)
		return err
	}

	if gotch.Debug {
		log.Printf("INFO: Released tensor %q - C memory(%d bytes).\n", ts.name, nbytes)
	}

	// IMPORTANT. make it nil so won't double free.
	ts.ctensor = nil

	return nil
}
|
|
|
|
|
|
|
|
|
|
func newName(nameOpt ...string) string {
|
|
|
|
|
var name string
|
|
|
|
|
if len(nameOpt) > 0 {
|
|
|
|
|
name = nameOpt[0]
|
|
|
|
|
} else {
|
2023-07-05 14:56:48 +01:00
|
|
|
|
name = fmt.Sprintf("tensor_%09d", TensorCount)
|
2023-07-04 14:26:20 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return name
|
|
|
|
|
}
|
2020-08-05 01:31:01 +01:00
|
|
|
|
|
2020-05-28 17:58:23 +01:00
|
|
|
|
// NewTensor creates a new tensor
|
2023-07-04 14:26:20 +01:00
|
|
|
|
func NewTensor(nameOpt ...string) *Tensor {
|
2020-05-28 17:58:23 +01:00
|
|
|
|
ctensor := lib.AtNewTensor()
|
2023-07-04 14:26:20 +01:00
|
|
|
|
|
|
|
|
|
return newTensor(ctensor, nameOpt...)
|
2020-05-28 17:58:23 +01:00
|
|
|
|
}
|
|
|
|
|
|
2022-02-16 00:39:27 +00:00
|
|
|
|
func FromCtensor(ctensor unsafe.Pointer) *Tensor {
|
|
|
|
|
cts := (lib.Ctensor)(ctensor)
|
2023-07-04 14:26:20 +01:00
|
|
|
|
|
|
|
|
|
return newTensor(cts)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Name returns the registered name of this tensor (the key used in
// ExistingTensors).
func (ts *Tensor) Name() string {
	return ts.name
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) Dim() uint64 {
|
|
|
|
|
dim := lib.AtDim(ts.ctensor)
|
2020-06-03 02:03:38 +01:00
|
|
|
|
if err := TorchErr(); err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return dim
|
2020-06-01 08:37:05 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-06-02 04:07:35 +01:00
|
|
|
|
// Size return shape of the tensor
//
// NOTE: C++ libtorch calls at_shape() -> t.sizes()
// And returns a slice of sizes or shape using given pointer
// to that slice.
func (ts *Tensor) Size() ([]int64, error) {
	dim := lib.AtDim(ts.ctensor)

	// Allocate a C buffer big enough for `dim` int64 values, let the C
	// side fill it, then decode it back into a Go slice.
	sz := make([]int64, dim)
	szPtr, err := DataAsPtr(sz)
	if err != nil {
		return nil, err
	}
	defer C.free(unsafe.Pointer(szPtr))

	lib.AtShape(ts.ctensor, szPtr)
	if err = TorchErr(); err != nil {
		return nil, err
	}

	shape := decodeSize(szPtr, dim)

	return shape, nil
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) MustSize() []int64 {
|
|
|
|
|
shape, err := ts.Size()
|
2020-06-08 04:28:07 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
2020-07-22 06:56:30 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return shape
|
2020-06-08 04:28:07 +01:00
|
|
|
|
}
|
|
|
|
|
|
2023-07-23 03:17:27 +01:00
|
|
|
|
// Stride returns the per-dimension strides of the tensor (mirrors Size,
// but calls lib.AtStride instead of lib.AtShape).
func (ts *Tensor) Stride() ([]int64, error) {
	dim := lib.AtDim(ts.ctensor)

	// Allocate a C buffer for `dim` int64 values, have C fill it with
	// strides, then decode it back into a Go slice.
	sz := make([]int64, dim)
	szPtr, err := DataAsPtr(sz)
	if err != nil {
		return nil, err
	}
	defer C.free(unsafe.Pointer(szPtr))

	lib.AtStride(ts.ctensor, szPtr)
	if err = TorchErr(); err != nil {
		return nil, err
	}

	strides := decodeSize(szPtr, dim)

	return strides, nil
}
|
|
|
|
|
|
|
|
|
|
func (ts *Tensor) MustStride() []int64 {
|
|
|
|
|
strides, err := ts.Stride()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return strides
|
|
|
|
|
}
|
|
|
|
|
|
2020-06-06 04:20:00 +01:00
|
|
|
|
// Size1 returns the tensor size for 1D tensors.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) Size1() (int64, error) {
|
2020-06-06 04:20:00 +01:00
|
|
|
|
shape, err := ts.Size()
|
|
|
|
|
if err != nil {
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return 0, err
|
2020-06-06 04:20:00 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-06-02 04:07:35 +01:00
|
|
|
|
if len(shape) != 1 {
|
|
|
|
|
err = fmt.Errorf("Expected one dim, got %v\n", len(shape))
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return 0, err
|
2020-06-02 04:07:35 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return shape[0], nil
|
2020-06-01 08:37:05 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-06-02 04:07:35 +01:00
|
|
|
|
// Size2 returns the tensor size for 2D tensors.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) Size2() ([]int64, error) {
|
2020-06-06 04:20:00 +01:00
|
|
|
|
shape, err := ts.Size()
|
|
|
|
|
if err != nil {
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return nil, err
|
2020-06-06 04:20:00 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-06-02 04:07:35 +01:00
|
|
|
|
if len(shape) != 2 {
|
|
|
|
|
err = fmt.Errorf("Expected two dims, got %v\n", len(shape))
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return nil, err
|
2020-06-02 04:07:35 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return shape, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Size3 returns the tensor size for 3D tensors.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) Size3() ([]int64, error) {
|
2020-06-06 04:20:00 +01:00
|
|
|
|
shape, err := ts.Size()
|
|
|
|
|
if err != nil {
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return nil, err
|
2020-06-06 04:20:00 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-06-02 04:07:35 +01:00
|
|
|
|
if len(shape) != 3 {
|
|
|
|
|
err = fmt.Errorf("Expected three dims, got %v\n", len(shape))
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return nil, err
|
2020-06-02 04:07:35 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return shape, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Size4 returns the tensor size for 4D tensors.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) Size4() ([]int64, error) {
|
2020-06-06 04:20:00 +01:00
|
|
|
|
shape, err := ts.Size()
|
|
|
|
|
if err != nil {
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return nil, err
|
2020-06-06 04:20:00 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-06-02 04:07:35 +01:00
|
|
|
|
if len(shape) != 4 {
|
|
|
|
|
err = fmt.Errorf("Expected four dims, got %v\n", len(shape))
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return nil, err
|
2020-06-02 04:07:35 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return shape, nil
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-04 14:26:20 +01:00
|
|
|
|
// nbytes calculates tensor data size in bytes.
//
// A zero element count means an undefined tensor (ts.None), which
// occupies no tracked memory.
func (ts *Tensor) nbytes() int64 {
	numel := ts.Numel()
	if numel == 0 {
		return 0 // ts.None
	}

	return int64(numel * ts.DType().Size())
}
|
|
|
|
|
|
2020-06-02 04:07:35 +01:00
|
|
|
|
// decodeSize reads `nsize` int64 values from the C buffer at ptr
// (written by lib.AtShape/lib.AtStride) into a Go slice.
func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
	dtype := gotch.Int64 // tensor size dtype = int64
	nbytes := int(nsize) * int(dtype.Size())

	// View the C memory as a Go byte slice without copying (the 1<<30
	// array type is only a bound for the slice expression, not an allocation).
	dataSlice := (*[1 << 30]byte)(ptr)[:nbytes:nbytes]
	r := bytes.NewReader(dataSlice)
	dataIn := make([]int64, nsize)
	if err := binary.Read(r, nativeEndian, dataIn); err != nil {
		log.Fatal(err)
	}

	return dataIn
}
|
|
|
|
|
|
2023-07-06 15:01:23 +01:00
|
|
|
|
// TensorOptions constructs options to build/rebuild tensor.
type TensorOptions struct {
	Name      string      // optional tensor name; empty means auto-generated
	DType     gotch.DType // element type preference
	Quantized bool        // whether a quantized dtype is requested
	// TODO. can expand as needed
}
|
|
|
|
|
|
|
|
|
|
// TensorOpt is a functional option that mutates a TensorOptions value.
type TensorOpt func(*TensorOptions)
|
|
|
|
|
|
|
|
|
|
func DefaultTensorOptions() *TensorOptions {
|
|
|
|
|
return &TensorOptions{
|
|
|
|
|
Name: "",
|
|
|
|
|
DType: gotch.Float,
|
|
|
|
|
Quantized: false,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func WithName(v string) TensorOpt {
|
|
|
|
|
return func(o *TensorOptions) {
|
|
|
|
|
o.Name = v
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func WithDType(v gotch.DType) TensorOpt {
|
|
|
|
|
return func(o *TensorOptions) {
|
|
|
|
|
o.DType = v
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func WithQuantized(v bool) TensorOpt {
|
|
|
|
|
return func(o *TensorOptions) {
|
|
|
|
|
o.Quantized = v
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2020-06-03 03:07:08 +01:00
|
|
|
|
// OfSlice creates tensor from a slice data
//
// The slice's element kind determines the tensor dtype (via
// gotch.GoKind2DType, honouring the DType/Quantized options). The data
// is copied into C memory first, then handed to libtorch.
func OfSlice(data interface{}, opts ...TensorOpt) (*Tensor, error) {
	o := DefaultTensorOptions()
	for _, opt := range opts {
		opt(o)
	}

	// convert []int -> int32. `binary.Write()` can't write `[]int` because it's not fixed-size!
	if reflect.TypeOf(data).String() == "[]int" {
		data = sliceIntToInt32(data.([]int))
	}

	v := reflect.ValueOf(data)
	kind := v.Kind().String()
	if kind != "slice" && kind != "array" {
		err := fmt.Errorf("Expected slice data. Got %q", kind)
		return nil, err
	}

	elementKind := reflect.TypeOf(data).Elem().Kind()
	dataLen := v.Len()

	dtype, err := gotch.GoKind2DType(elementKind, gotch.HalfDTypePref(o.DType), gotch.WithQuantized(o.Quantized))
	if err != nil {
		return nil, err
	}

	// Always a 1D tensor: shape = [len(data)].
	shape := []int64{int64(dataLen)}
	elementNum := ElementCount(shape)

	nbytes := int(dtype.Size()) * elementNum

	// Copy the Go data into freshly allocated C memory; libtorch copies
	// it again internally, so the buffer is freed on return.
	dataPtr, buff := CMalloc(nbytes)
	defer C.free(unsafe.Pointer(dataPtr))

	if err = EncodeTensor(buff, reflect.ValueOf(data), shape); err != nil {
		return nil, err
	}

	ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(dtype.Size()), int(dtype.CKind()))
	if err = TorchErr(); err != nil {
		return nil, err
	}

	return newTensor(ctensor, o.Name), nil
	// return newTensor(ctensor), nil
}
|
2020-05-30 00:04:47 +01:00
|
|
|
|
|
2020-11-18 02:07:08 +00:00
|
|
|
|
// OfDataSize creates Tensor from input byte data, shape and dtype.
//
// len(data) must equal ElementCount(shape) * dtype.Size(); otherwise an
// error is returned.
func OfDataSize(data []byte, shape []int64, dtype gotch.DType, opts ...TensorOpt) (*Tensor, error) {
	o := DefaultTensorOptions()
	for _, opt := range opts {
		opt(o)
	}

	elementNum := ElementCount(shape)

	nbytes := elementNum * int(dtype.Size())

	if nbytes != len(data) {
		err := fmt.Errorf("data and shape mismatched for dtype (%v): byte data (%v) - shape (%v).\n", dtype, len(data), shape)
		return nil, err
	}

	// Copy the raw bytes into C memory; libtorch copies internally, so
	// the buffer is freed on return.
	dataPtr, buff := CMalloc(nbytes)
	defer C.free(unsafe.Pointer(dataPtr))

	if err := binary.Write(buff, nativeEndian, data); err != nil {
		return nil, err
	}

	ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), dtype.Size(), int(dtype.CKind()))
	if err := TorchErr(); err != nil {
		return nil, err
	}

	return newTensor(ctensor, o.Name), nil
	// return newTensor(ctensor), nil
}
|
|
|
|
|
|
|
|
|
|
// MustOfDataSize create Tensor from input byte data and specified shape and dtype
|
|
|
|
|
// or panic if error
|
2023-07-06 15:01:23 +01:00
|
|
|
|
func MustOfDataSize(data []byte, size []int64, dtype gotch.DType, opts ...TensorOpt) *Tensor {
|
|
|
|
|
ts, err := OfDataSize(data, size, dtype, opts...)
|
2020-11-16 12:37:44 +00:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ts
|
|
|
|
|
}
|
|
|
|
|
|
2020-06-10 07:13:47 +01:00
|
|
|
|
// MustOfSlice create a tensor from slice of data. It will be panic if error.
|
2023-07-06 15:01:23 +01:00
|
|
|
|
func MustOfSlice(data interface{}, opts ...TensorOpt) *Tensor {
|
|
|
|
|
ts, err := OfSlice(data, opts...)
|
2020-06-10 07:13:47 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return ts
|
2020-06-10 07:13:47 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TensorFrom create a tensor from slice of data. It will be panic if error.
|
2023-07-06 15:01:23 +01:00
|
|
|
|
func TensorFrom(data interface{}, opts ...TensorOpt) *Tensor {
|
|
|
|
|
ts, err := OfSlice(data, opts...)
|
2020-06-07 22:31:07 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return ts
|
2020-06-07 22:31:07 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-05-30 02:15:36 +01:00
|
|
|
|
// Print prints tensor values to console.
//
// NOTE: it is printed from C and will print ALL elements of tensor
// with no truncation at all.
func (ts *Tensor) Print() {
	lib.AtPrint(ts.ctensor)
	if err := TorchErr(); err != nil {
		log.Fatal(err)
	}
}
|
2020-05-30 06:39:56 +01:00
|
|
|
|
|
|
|
|
|
// NewTensorFromData creates tensor from given data and shape
//
// The dtype is inferred from the Go data (gotch.DTypeFromData); the
// number of elements in data must match the flattened shape.
func NewTensorFromData(data interface{}, shape []int64, opts ...TensorOpt) (*Tensor, error) {
	o := DefaultTensorOptions()
	for _, opt := range opts {
		opt(o)
	}

	// 1. Check whether data and shape match
	elementNum, err := DataDim(data)
	if err != nil {
		return nil, err
	}

	nflattend := FlattenDim(shape)

	if elementNum != nflattend {
		err = fmt.Errorf("Number of data elements (%v) and flatten shape (%v) dimension mismatched.\n", elementNum, nflattend)
		return nil, err
	}

	// 2. Write raw data to C memory and get C pointer
	// NOTE. the defer is registered before the error check; C.free on a
	// nil pointer is a harmless no-op, so this is safe.
	dataPtr, err := DataAsPtr(data)
	defer C.free(unsafe.Pointer(dataPtr))
	if err != nil {
		return nil, err
	}

	// 3. Create tensor with pointer and shape
	dtype, err := gotch.DTypeFromData(data)
	if err != nil {
		return nil, err
	}

	ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), dtype.Size(), int(dtype.CKind()))
	if err = TorchErr(); err != nil {
		return nil, err
	}

	return newTensor(ctensor, o.Name), nil
}
|
2020-06-02 10:29:24 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) DType() gotch.DType {
|
2020-06-02 10:29:24 +01:00
|
|
|
|
cint := lib.AtScalarType(ts.ctensor)
|
|
|
|
|
|
2023-07-06 15:01:23 +01:00
|
|
|
|
return gotch.CKind2DType(cint)
|
2020-06-02 10:29:24 +01:00
|
|
|
|
}
|
2020-06-04 04:36:20 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) Device() (gotch.Device, error) {
|
|
|
|
|
var (
|
|
|
|
|
retVal gotch.Device
|
|
|
|
|
err error
|
|
|
|
|
)
|
2020-06-06 09:12:42 +01:00
|
|
|
|
cInt := lib.AtDevice(ts.ctensor)
|
|
|
|
|
|
|
|
|
|
if err = TorchErr(); err != nil {
|
|
|
|
|
return retVal, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var device gotch.Device
|
|
|
|
|
|
|
|
|
|
return device.OfCInt(int32(cInt)), nil
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) MustDevice() gotch.Device {
|
|
|
|
|
device, err := ts.Device()
|
2020-08-01 07:33:30 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return device
|
2020-08-01 07:33:30 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-07-22 06:56:30 +01:00
|
|
|
|
/*
|
|
|
|
|
* func (ts Tensor) Eq1(other Tensor, del bool) (retVal Tensor, err error) {
|
|
|
|
|
*
|
|
|
|
|
* // Get a C null pointer
|
|
|
|
|
* // https://stackoverflow.com/a/2022369
|
|
|
|
|
* ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
|
|
|
* if del {
|
|
|
|
|
* defer ts.MustDrop()
|
|
|
|
|
* }
|
|
|
|
|
*
|
|
|
|
|
* lib.AtgEq1(ptr, ts.ctensor, other.ctensor)
|
|
|
|
|
* if err = TorchErr(); err != nil {
|
|
|
|
|
* return retVal, err
|
|
|
|
|
* }
|
|
|
|
|
*
|
|
|
|
|
* return Tensor{ctensor: *ptr}, nil
|
|
|
|
|
*
|
|
|
|
|
* }
|
|
|
|
|
*
|
|
|
|
|
* func (ts Tensor) MustEq1(other Tensor, del bool) (retVal Tensor) {
|
|
|
|
|
* retVal, err := ts.Eq1(other, del)
|
|
|
|
|
* if err != nil {
|
|
|
|
|
* log.Fatal(err)
|
|
|
|
|
* }
|
|
|
|
|
*
|
|
|
|
|
* return retVal
|
|
|
|
|
* }
|
|
|
|
|
* */
|
2023-07-06 15:01:23 +01:00
|
|
|
|
|
2020-06-14 03:51:38 +01:00
|
|
|
|
// Float64Value returns a float value on tensors holding a single element.
// An error is returned otherwise.
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
func (ts *Tensor) Float64Value(idx []int64) (float64, error) {
	// Copy the index list into C memory for the C call; freed on return.
	idxPtr, err := DataAsPtr(idx)
	if err != nil {
		return 0, err
	}
	defer C.free(unsafe.Pointer(idxPtr))

	f64Val := lib.AtDoubleValueAtIndexes(ts.ctensor, idxPtr, len(idx))
	if err = TorchErr(); err != nil {
		return 0, err
	}

	// err is nil here.
	return f64Val, err
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) MustFloat64Value(idx []int64) float64 {
|
|
|
|
|
f64Val, err := ts.Float64Value(idx)
|
2020-06-15 16:59:41 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return f64Val
|
2020-06-15 16:59:41 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-06-06 09:12:42 +01:00
|
|
|
|
// Int64Value returns an int value on tensors holding a single element. An error is
|
|
|
|
|
// returned otherwise.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) Int64Value(idx []int64) (int64, error) {
|
2020-06-06 09:12:42 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
var (
|
|
|
|
|
retVal int64
|
|
|
|
|
err error
|
|
|
|
|
)
|
2020-06-06 09:12:42 +01:00
|
|
|
|
idxPtr, err := DataAsPtr(idx)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return retVal, err
|
|
|
|
|
}
|
|
|
|
|
defer C.free(unsafe.Pointer(idxPtr))
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
int64Val := lib.AtInt64ValueAtIndexes(ts.ctensor, idxPtr, len(idx))
|
2020-06-06 09:12:42 +01:00
|
|
|
|
if err = TorchErr(); err != nil {
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return 0, err
|
2020-06-06 09:12:42 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return int64Val, err
|
2020-06-06 09:12:42 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) MustInt64Value(idx []int64) int64 {
|
|
|
|
|
int64Val, err := ts.Int64Value(idx)
|
2020-06-30 11:01:01 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return int64Val
|
2020-06-30 11:01:01 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-06-06 09:12:42 +01:00
|
|
|
|
// RequiresGrad returns true if gradient are currently tracked for this tensor.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) RequiresGrad() (bool, error) {
|
|
|
|
|
state := lib.AtRequiresGrad(ts.ctensor)
|
2020-06-06 09:12:42 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
if err := TorchErr(); err != nil {
|
|
|
|
|
return false, err
|
2020-06-06 09:12:42 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return state, nil
|
2020-06-06 09:12:42 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) MustRequiresGrad() bool {
|
|
|
|
|
state, err := ts.RequiresGrad()
|
2020-06-17 07:24:27 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return state
|
2020-06-17 07:24:27 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-06-06 09:12:42 +01:00
|
|
|
|
// DataPtr returns the address of the first element of this tensor.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) DataPtr() (unsafe.Pointer, error) {
|
2020-06-06 09:12:42 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
datPtr := lib.AtDataPtr(ts.ctensor)
|
2020-06-06 09:12:42 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
if err := TorchErr(); err != nil {
|
|
|
|
|
return nil, err
|
2020-06-06 09:12:42 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return datPtr, nil
|
2020-06-06 09:12:42 +01:00
|
|
|
|
}
|
2020-06-07 22:31:07 +01:00
|
|
|
|
|
2023-07-23 04:10:24 +01:00
|
|
|
|
func (ts *Tensor) MustDataPtr() unsafe.Pointer {
|
|
|
|
|
p, err := ts.DataPtr()
|
|
|
|
|
if err != nil {
|
|
|
|
|
panic(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return p
|
|
|
|
|
}
|
|
|
|
|
|
2020-06-07 22:31:07 +01:00
|
|
|
|
// Defined returns true is the tensor is defined.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) Defined() (bool, error) {
|
|
|
|
|
state := lib.AtDefined(ts.ctensor)
|
2020-06-07 22:31:07 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
if err := TorchErr(); err != nil {
|
|
|
|
|
return false, err
|
2020-06-07 22:31:07 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return state, nil
|
2020-06-07 22:31:07 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) MustDefined() bool {
|
|
|
|
|
state, err := ts.Defined()
|
2020-06-07 22:31:07 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return state
|
2020-06-07 22:31:07 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// IsSparse returns true is the tensor is spare.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) IsSparse() (bool, error) {
|
|
|
|
|
state := lib.AtIsSparse(ts.ctensor)
|
2020-06-07 22:31:07 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
if err := TorchErr(); err != nil {
|
|
|
|
|
return false, err
|
2020-06-07 22:31:07 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return state, nil
|
2020-06-07 22:31:07 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ZeroGrad zeroes the gradient tensor attached to this tensor if defined.
//
// The gradient is detached from the graph before being zeroed in place.
func (ts *Tensor) ZeroGrad() {
	grad := ts.MustGrad(false)
	if grad.MustDefined() {
		grad.Detach_()
		grad.Zero_()
	}
}
|
|
|
|
|
|
|
|
|
|
// Backward runs the backward pass, populating the gradient tensors for tensors
|
|
|
|
|
// which gradients are tracked.
|
|
|
|
|
//
|
|
|
|
|
// Gradients tracking can be turned on via `SetRequiresGrad`.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) Backward() error {
|
2020-06-07 22:31:07 +01:00
|
|
|
|
lib.AtBackward(ts.ctensor, 0, 0)
|
2020-10-31 08:25:32 +00:00
|
|
|
|
if err := TorchErr(); err != nil {
|
2020-06-07 22:31:07 +01:00
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) MustBackward() {
|
2020-06-07 22:31:07 +01:00
|
|
|
|
if err := ts.Backward(); err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// RunBackward runs the backward pass over `tensors`, computing gradients
// with respect to `inputs`, and returns one gradient tensor per input.
//
// NOTE(review): the code below passes &outputsPtr[0] / first-element
// pointers to C and indexes outputsPtr[0]; an empty `inputs` slice would
// panic with an index-out-of-range — callers must pass at least one input.
func RunBackward(tensors []*Tensor, inputs []*Tensor, keepGraphB bool, createGraphB bool) ([]*Tensor, error) {
	// NOTE: outputs is a slice of tensors with length = len(inputs)
	var outputsPtr []*lib.Ctensor
	// Are they allocated contigously??? Definitely not.
	// TODO. calculate C memory size = C pointer size x n pointers
	// Then C.malloc such calculated amount
	// NOTE. replace with the following code and test.
	/*
	 * ntensors := len(inputs)
	 * nbytes := C.size_t(ntensors) * C.size_t(unsafe.Sizeof(uintptr(0)))
	 * ctensorsPtr := (*[1 << 30]lib.Ctensor)(C.malloc(nbytes))
	 * for i :=0; i < ntensors; i++ {
	 * outputPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	 * outputsPtr[i] = outputPtr
	 * }
	 * */
	// Allocate one C slot per input for the resulting gradient pointers.
	// The defers run at function return, after the results are consumed.
	for i := 0; i < len(inputs); i++ {
		outputPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
		defer C.free(unsafe.Pointer(outputPtr))
		outputsPtr = append(outputsPtr, outputPtr)
	}

	// Get first element pointer
	ctensor := tensors[0].ctensor
	cinput := inputs[0].ctensor
	tensorsPtr := (*lib.Ctensor)(unsafe.Pointer(&ctensor))
	inputsPtr := (*lib.Ctensor)(unsafe.Pointer(&cinput))
	// Convert the Go bools to the C-style ints the binding expects.
	var keepGraph int = 0
	if keepGraphB {
		keepGraph = 1
	}
	var createGraph int = 0
	if createGraphB {
		createGraph = 1
	}

	lib.AtRunBackward(tensorsPtr, len(tensors), inputsPtr, len(inputs), outputsPtr[0], keepGraph, createGraph)
	if err := TorchErr(); err != nil {
		return nil, err
	}

	// Wrap each returned gradient pointer in a managed *Tensor.
	var oTensors []*Tensor
	for i := 0; i < len(inputs); i++ {
		outputPtr := outputsPtr[i]
		oTensors = append(oTensors, newTensor(*outputPtr))
	}

	return oTensors, nil
}
|
2020-06-08 04:28:07 +01:00
|
|
|
|
|
|
|
|
|
// CopyDataUint8 copies `numel` elements from `self` to `dst`.
|
|
|
|
|
//
|
|
|
|
|
// NOTE: `dst` located in Go memory. Should it be?
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) CopyDataUint8(dst []uint8, numel uint) error {
|
2020-06-08 04:28:07 +01:00
|
|
|
|
// NOTE: we must make sure that `dst` has same len as `numel`. Otherwise,
|
|
|
|
|
// there will be memory leak and or out of range error.
|
|
|
|
|
if len(dst) < int(numel) {
|
2020-10-31 08:25:32 +00:00
|
|
|
|
err := fmt.Errorf("CopyDataUint8 Error: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", len(dst), numel)
|
2020-06-08 04:28:07 +01:00
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
vs := unsafe.Pointer(&dst[0])
|
2023-07-06 15:01:23 +01:00
|
|
|
|
dtype := gotch.Uint8
|
|
|
|
|
lib.AtCopyData(ts.ctensor, vs, numel, dtype.Size())
|
|
|
|
|
if err := TorchErr(); err != nil {
|
2020-06-08 04:28:07 +01:00
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) MustCopyDataUint8(dst []uint8, numel uint) {
|
2020-06-08 04:28:07 +01:00
|
|
|
|
err := ts.CopyDataUint8(dst, numel)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// CopyData copies `numel` elements from `self` to `dst`.
|
|
|
|
|
// `dst` should be a slice of Go type equivalent to tensor type.
|
|
|
|
|
//
|
|
|
|
|
// NOTE: `dst` located in Go memory. Should it be?
|
2020-07-17 02:22:04 +01:00
|
|
|
|
// We will render Go pointer of first element of `dst` slice
|
|
|
|
|
// and number of elements to C land. This may break in the future
|
|
|
|
|
// if Go policy changes.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) CopyData(dst interface{}, numel uint) error {
|
2023-07-06 15:01:23 +01:00
|
|
|
|
dtype, dlen, err := DataCheck(dst)
|
2020-06-08 04:28:07 +01:00
|
|
|
|
if dlen < int(numel) {
|
2023-07-06 15:01:23 +01:00
|
|
|
|
err = fmt.Errorf("ts.CopyData() failed: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", dlen, numel)
|
2020-06-08 04:28:07 +01:00
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if ts.DType() != dtype {
|
2023-07-06 15:01:23 +01:00
|
|
|
|
err = fmt.Errorf("ts.CopyData() failed: Type mismatched: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
|
2020-06-08 04:28:07 +01:00
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-06 15:01:23 +01:00
|
|
|
|
// Get data pointer
|
|
|
|
|
dataPtr := reflect.ValueOf(dst).UnsafePointer()
|
2020-06-08 04:28:07 +01:00
|
|
|
|
|
2023-07-06 15:01:23 +01:00
|
|
|
|
lib.AtCopyData(ts.ctensor, dataPtr, numel, dtype.Size())
|
2020-06-08 04:28:07 +01:00
|
|
|
|
if err = TorchErr(); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
2020-06-08 05:37:37 +01:00
|
|
|
|
// MustCopyData copies number of elements from tensor to a slice of data
|
|
|
|
|
//
|
|
|
|
|
// NOTE: `dst` is a slice with length = numel and Go type equavalent to tensor
|
|
|
|
|
// DType
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) MustCopyData(dst interface{}, numel uint) {
|
2020-06-08 04:28:07 +01:00
|
|
|
|
err := ts.CopyData(dst, numel)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Numel returns the total number of elements stored in a tensor.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) Numel() uint {
|
2023-07-04 14:26:20 +01:00
|
|
|
|
if !ts.MustDefined() {
|
|
|
|
|
return 0 // ts.None case
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
shape := ts.MustSize()
|
2020-06-08 04:28:07 +01:00
|
|
|
|
return uint(FlattenDim(shape))
|
|
|
|
|
}
|
2020-06-08 07:13:23 +01:00
|
|
|
|
|
2020-07-28 07:08:40 +01:00
|
|
|
|
// ShallowClone returns a new tensor that share storage with the input tensor.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) ShallowClone() (*Tensor, error) {
|
2020-06-08 07:13:23 +01:00
|
|
|
|
ctensor := lib.AtShallowClone(ts.ctensor)
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
if err := TorchErr(); err != nil {
|
|
|
|
|
return nil, err
|
2020-06-08 07:13:23 +01:00
|
|
|
|
}
|
|
|
|
|
|
2023-07-04 14:26:20 +01:00
|
|
|
|
name := fmt.Sprintf("%s_cloned", ts.name)
|
|
|
|
|
return newTensor(ctensor, name), nil
|
2020-06-08 07:13:23 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// MustShallowClone returns a new tensor that share storage with the input
|
|
|
|
|
// tensor. It will panic if error occurred
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) MustShallowClone() *Tensor {
|
|
|
|
|
newTs, err := ts.ShallowClone()
|
2020-06-08 07:13:23 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return newTs
|
2020-06-08 07:13:23 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Get gets the sub-tensor at the given index.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) Get(index int) (*Tensor, error) {
|
2020-06-08 07:13:23 +01:00
|
|
|
|
|
|
|
|
|
ctensor := lib.AtGet(ts.ctensor, index)
|
2020-10-31 08:25:32 +00:00
|
|
|
|
if err := TorchErr(); err != nil {
|
|
|
|
|
return nil, err
|
2020-06-08 07:13:23 +01:00
|
|
|
|
}
|
|
|
|
|
|
2023-07-04 14:26:20 +01:00
|
|
|
|
return newTensor(ctensor), nil
|
2020-06-08 07:13:23 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// MustGet gets the sub-tensor at the given index. It will panic if error
|
|
|
|
|
// occurred.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) MustGet(index int) *Tensor {
|
|
|
|
|
|
|
|
|
|
subTs, err := ts.Get(index)
|
2020-06-08 07:13:23 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
2020-10-31 08:25:32 +00:00
|
|
|
|
|
|
|
|
|
return subTs
|
2020-06-08 07:13:23 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Copy_ copies in-place values from the argument tensor to the input tensor.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func Copy_(self, src *Tensor) {
|
2020-06-08 07:13:23 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
lib.AtCopy_(self.ctensor, src.ctensor)
|
|
|
|
|
if err := TorchErr(); err != nil {
|
2020-06-14 03:51:38 +01:00
|
|
|
|
log.Fatal(err)
|
2020-06-08 07:13:23 +01:00
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2020-06-14 03:51:38 +01:00
|
|
|
|
// Copy_ copies in-place values from the argument tensor to existing tensor
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) Copy_(src *Tensor) {
|
2020-06-14 03:51:38 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
lib.AtCopy_(ts.ctensor, src.ctensor)
|
|
|
|
|
if err := TorchErr(); err != nil {
|
2020-06-08 07:13:23 +01:00
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
}
|
2020-06-08 08:06:35 +01:00
|
|
|
|
|
|
|
|
|
// Save saves a tensor to a file.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) Save(path string) error {
|
2020-06-08 08:06:35 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
lib.AtSave(ts.ctensor, path)
|
|
|
|
|
if err := TorchErr(); err != nil {
|
2020-06-08 08:06:35 +01:00
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// MustSave saves a tensor to a file. It will panic if error
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) MustSave(path string) {
|
2020-06-08 08:06:35 +01:00
|
|
|
|
if err := ts.Save(path); err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Load loads a tensor from a file.
|
2023-07-04 14:26:20 +01:00
|
|
|
|
func Load(path string, nameOpt ...string) (*Tensor, error) {
|
2020-06-08 08:06:35 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
ctensor := lib.AtLoad(path)
|
|
|
|
|
if err := TorchErr(); err != nil {
|
|
|
|
|
return nil, err
|
2020-06-08 08:06:35 +01:00
|
|
|
|
}
|
|
|
|
|
|
2023-07-04 14:26:20 +01:00
|
|
|
|
return newTensor(ctensor, nameOpt...), nil
|
2020-06-08 08:06:35 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// MustLoad loads a tensor to a file. It will panic if error
|
2023-07-04 14:26:20 +01:00
|
|
|
|
func MustLoad(path string, nameOpt ...string) *Tensor {
|
|
|
|
|
ts, err := Load(path, nameOpt...)
|
2020-06-08 08:06:35 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return ts
|
2020-06-08 08:06:35 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
// NamedTensor wraps C tensor and its name
type NamedTensor struct {
	Name   string  // name associated with the tensor (e.g. a weight key in a saved file)
	Tensor *Tensor // the wrapped tensor
}
|
|
|
|
|
|
|
|
|
|
// SaveMulti saves some named tensors to a file
|
|
|
|
|
//
|
|
|
|
|
// The file format is the same as the one used by the PyTorch C++ API.
|
2020-11-18 02:07:08 +00:00
|
|
|
|
// NOTE. This method is depreciated and will be replaced with `SaveMultiNew`
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func SaveMulti(namedTensors []NamedTensor, path string) error {
|
2020-06-10 07:13:47 +01:00
|
|
|
|
var ctensors []lib.Ctensor
|
2020-06-08 08:06:35 +01:00
|
|
|
|
var names []string
|
|
|
|
|
|
|
|
|
|
for _, ts := range namedTensors {
|
|
|
|
|
ctensors = append(ctensors, ts.Tensor.ctensor)
|
|
|
|
|
names = append(names, ts.Name)
|
|
|
|
|
}
|
|
|
|
|
|
2020-06-10 07:13:47 +01:00
|
|
|
|
lib.AtSaveMulti(ctensors, names, len(namedTensors), path)
|
2020-10-31 08:25:32 +00:00
|
|
|
|
if err := TorchErr(); err != nil {
|
2020-06-10 07:13:47 +01:00
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
2020-06-08 08:06:35 +01:00
|
|
|
|
return nil
|
|
|
|
|
}
|
2020-06-10 07:13:47 +01:00
|
|
|
|
|
|
|
|
|
// MustSaveMulti saves some named tensors to a file. It will panic if error
|
2020-11-18 02:07:08 +00:00
|
|
|
|
//
|
|
|
|
|
// NOTE. This method is depreciated and will be replaced with `MustSaveMultiNew`
|
2020-06-10 07:13:47 +01:00
|
|
|
|
func MustSaveMulti(namedTensors []NamedTensor, path string) {
|
|
|
|
|
err := SaveMulti(namedTensors, path)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// LoadMulti loads some named tensors from a file
|
|
|
|
|
//
|
|
|
|
|
// The file format is the same as the one used by the PyTorch C++ API.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func LoadMulti(path string) ([]NamedTensor, error) {
|
2020-06-10 07:13:47 +01:00
|
|
|
|
|
2020-06-10 08:38:14 +01:00
|
|
|
|
var data lib.LoadData
|
|
|
|
|
dataPtr := lib.PStore.Set(&data)
|
|
|
|
|
lib.AtLoadCallback(path, dataPtr)
|
2020-10-31 08:25:32 +00:00
|
|
|
|
if err := TorchErr(); err != nil {
|
|
|
|
|
return nil, err
|
2020-06-10 07:13:47 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
var namedTensors []NamedTensor
|
2020-06-10 09:31:07 +01:00
|
|
|
|
for _, v := range data.NamedCtensors {
|
|
|
|
|
namedTensor := NamedTensor{
|
|
|
|
|
Name: v.Name,
|
2023-07-04 14:26:20 +01:00
|
|
|
|
Tensor: newTensor(v.Ctensor, v.Name),
|
2020-06-10 09:31:07 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
namedTensors = append(namedTensors, namedTensor)
|
2020-06-10 09:31:07 +01:00
|
|
|
|
}
|
2020-06-10 07:13:47 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return namedTensors, nil
|
2020-06-10 07:13:47 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// MustLoadMulti loads some named tensors from a file. It will panic if error
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func MustLoadMulti(path string) []NamedTensor {
|
|
|
|
|
namedTensors, err := LoadMulti(path)
|
2020-06-10 07:13:47 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return namedTensors
|
2020-06-10 07:13:47 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-06-10 22:24:10 +01:00
|
|
|
|
// LoadMultiWithDevice loads some named tensors from a file to a given device
|
2020-06-10 22:03:59 +01:00
|
|
|
|
//
|
|
|
|
|
// The file format is the same as the one used by the PyTorch C++ API.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func LoadMultiWithDevice(path string, device gotch.Device) ([]NamedTensor, error) {
|
2020-06-10 22:03:59 +01:00
|
|
|
|
var data lib.LoadData
|
|
|
|
|
dataPtr := lib.PStore.Set(&data)
|
2020-06-10 07:13:47 +01:00
|
|
|
|
|
2020-06-10 22:03:59 +01:00
|
|
|
|
lib.AtLoadCallbackWithDevice(path, dataPtr, device.CInt())
|
2020-10-31 08:25:32 +00:00
|
|
|
|
if err := TorchErr(); err != nil {
|
|
|
|
|
return nil, err
|
2020-06-10 07:13:47 +01:00
|
|
|
|
}
|
2020-06-10 22:03:59 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
var namedTensors []NamedTensor
|
2020-06-10 22:03:59 +01:00
|
|
|
|
for _, v := range data.NamedCtensors {
|
|
|
|
|
namedTensor := NamedTensor{
|
|
|
|
|
Name: v.Name,
|
2023-07-04 14:26:20 +01:00
|
|
|
|
Tensor: newTensor(v.Ctensor, v.Name),
|
2020-06-10 22:03:59 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
namedTensors = append(namedTensors, namedTensor)
|
2020-06-10 22:03:59 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return namedTensors, nil
|
2020-06-10 22:03:59 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// MustLoadMulti loads some named tensors from a file. It will panic if error
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func MustLoadMultiWithDevice(path string, device gotch.Device) []NamedTensor {
|
|
|
|
|
namedTensors, err := LoadMultiWithDevice(path, device)
|
2020-06-10 22:03:59 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return namedTensors
|
2020-06-10 07:13:47 +01:00
|
|
|
|
}
|
2020-06-10 22:24:10 +01:00
|
|
|
|
|
|
|
|
|
// ToString returns a string representation for the tensor.
|
|
|
|
|
//
|
|
|
|
|
// lw : line width (size)
|
|
|
|
|
// NOTE: The representation will contain all the tensor element hence may be huge for
|
|
|
|
|
// large tensors.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) ToString(lw int64) (string, error) {
|
|
|
|
|
tensorStr := lib.AtToString(ts.ctensor, lw)
|
|
|
|
|
if err := TorchErr(); err != nil {
|
|
|
|
|
return "", err
|
2020-06-10 22:24:10 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return tensorStr, nil
|
2020-06-10 22:24:10 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// MustToString returns a string representation for the tensor. It will be panic
|
|
|
|
|
// if error.
|
2020-06-17 02:23:00 +01:00
|
|
|
|
// lw : line width (size)
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) MustToString(lw int64) string {
|
|
|
|
|
tensorStr, err := ts.ToString(lw)
|
2020-06-10 22:24:10 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return tensorStr
|
2020-06-10 22:24:10 +01:00
|
|
|
|
}
|
2020-06-10 22:37:09 +01:00
|
|
|
|
|
|
|
|
|
// Drop drops (frees) the tensor
func (ts *Tensor) Drop() error {
	// Delegates to freeCTensor (defined elsewhere in this package), which
	// releases the underlying C tensor and updates package bookkeeping.
	return freeCTensor(ts)
}
|
|
|
|
|
|
|
|
|
|
// MustDrop drops the tensor. It will be panic if error
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) MustDrop() {
|
2020-06-10 22:37:09 +01:00
|
|
|
|
if err := ts.Drop(); err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
}
|
2020-06-11 02:57:56 +01:00
|
|
|
|
|
|
|
|
|
// GradSetEnabled sets globally whether GradMode gradient accumulation is enable or not.
|
|
|
|
|
// It returns PREVIOUS state of Grad before setting.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func GradSetEnabled(b bool) (bool, error) {
|
2020-06-11 02:57:56 +01:00
|
|
|
|
|
|
|
|
|
var cbool, cretVal int
|
|
|
|
|
switch b {
|
|
|
|
|
case true:
|
|
|
|
|
cbool = 1
|
|
|
|
|
case false:
|
|
|
|
|
cbool = 0
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
var (
|
|
|
|
|
err error
|
|
|
|
|
state bool
|
|
|
|
|
)
|
2020-06-11 02:57:56 +01:00
|
|
|
|
cretVal = lib.AtGradSetEnabled(cbool)
|
|
|
|
|
if err = TorchErr(); err != nil {
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return false, err
|
2020-06-11 02:57:56 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch cretVal {
|
|
|
|
|
case 0:
|
2020-10-31 08:25:32 +00:00
|
|
|
|
state = false
|
2020-06-11 02:57:56 +01:00
|
|
|
|
break
|
|
|
|
|
case 1:
|
2020-10-31 08:25:32 +00:00
|
|
|
|
state = true
|
2020-06-11 02:57:56 +01:00
|
|
|
|
break
|
|
|
|
|
// case -1: // should be unreachable as error is captured above with TorchrErr()
|
|
|
|
|
// err = fmt.Errorf("Cannot set grad enable. \n")
|
|
|
|
|
// return retVal, err
|
|
|
|
|
// default: // should be unreachable as error is captured above with TorchrErr()
|
|
|
|
|
// err = fmt.Errorf("Cannot set grad enable. \n")
|
|
|
|
|
// return retVal, err
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return state, nil
|
2020-06-11 02:57:56 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// MustGradSetEnabled sets globally whether GradMode gradient accumuation is enable or not.
|
|
|
|
|
// It returns PREVIOUS state of Grad before setting. It will be panic if error
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func MustGradSetEnabled(b bool) bool {
|
|
|
|
|
state, err := GradSetEnabled(b)
|
2020-06-11 02:57:56 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return state
|
2020-06-11 02:57:56 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NoGrad runs a closure without keeping track of gradients.
|
2023-07-05 15:20:11 +01:00
|
|
|
|
func NoGrad(fn func()) {
|
2020-06-11 02:57:56 +01:00
|
|
|
|
// Switch off Grad
|
2023-07-04 14:26:20 +01:00
|
|
|
|
MustGradSetEnabled(false)
|
2020-06-11 02:57:56 +01:00
|
|
|
|
|
2023-07-04 14:26:20 +01:00
|
|
|
|
fn()
|
2020-06-11 02:57:56 +01:00
|
|
|
|
|
|
|
|
|
// Switch on Grad
|
2023-07-04 14:26:20 +01:00
|
|
|
|
MustGradSetEnabled(true)
|
2020-06-11 02:57:56 +01:00
|
|
|
|
}
|
2020-06-11 05:11:58 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func NoGrad1(fn func() interface{}) interface{} {
|
2020-07-02 07:26:54 +01:00
|
|
|
|
newTs := NewTensor()
|
|
|
|
|
newTs.Drop()
|
|
|
|
|
|
|
|
|
|
// Switch off Grad
|
|
|
|
|
prev := MustGradSetEnabled(false)
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
retVal := fn()
|
2020-07-02 07:26:54 +01:00
|
|
|
|
|
|
|
|
|
// Switch on Grad
|
|
|
|
|
_ = MustGradSetEnabled(prev)
|
|
|
|
|
|
|
|
|
|
return retVal
|
|
|
|
|
}
|
|
|
|
|
|
2020-06-11 05:11:58 +01:00
|
|
|
|
// NoGradGuard is a RAII guard that prevents gradient tracking until deallocated.
// It actually sets a global flag that is checked by the backend whenever an op is done on a variable.
// The guard itself saved the current status and set it to false in the constructor.
// And restore the saved status in it’s destructor.
// That way it is similar to a with torch.no_grad(): block in python.
// Ref. https://discuss.pytorch.org/t/how-does-nogradguard-works-in-cpp/34960/2
//
// TODO: should we implement Go `mutex` here???
type NoGradGuard struct {
	// enabled holds the grad-mode value the guard last applied
	// (see Drop and Enable below).
	enabled bool
}
|
|
|
|
|
|
2020-06-17 11:29:31 +01:00
|
|
|
|
// Init NoGradGuard and disables gradient tracking
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func NewNoGradGuard() *NoGradGuard {
|
2020-06-17 11:29:31 +01:00
|
|
|
|
return noGradGuardInit()
|
|
|
|
|
}
|
|
|
|
|
|
2020-06-11 05:11:58 +01:00
|
|
|
|
// Disables gradient tracking, this will be enabled back when the
|
|
|
|
|
// returned value gets deallocated.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func noGradGuardInit() *NoGradGuard {
|
|
|
|
|
return &NoGradGuard{enabled: MustGradSetEnabled(false)}
|
2020-06-11 05:11:58 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Drop drops the NoGradGuard state.
|
2020-06-17 11:29:31 +01:00
|
|
|
|
func (ngg *NoGradGuard) Drop() {
|
|
|
|
|
ngg.enabled = true
|
|
|
|
|
_ = MustGradSetEnabled(ngg.enabled)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (ngg *NoGradGuard) Enable() {
|
|
|
|
|
ngg.enabled = false
|
|
|
|
|
_ = MustGradSetEnabled(ngg.enabled)
|
2020-06-11 05:11:58 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Loss-reduction modes passed to loss functions; presumably these mirror
// libtorch's at::Reduction values — confirm against the C API.
const (
	// Do not reduce
	ReductionNone int64 = 0
	// Mean of losses
	ReductionMean int64 = 1
	// Sum of losses
	ReductionSum int64 = 2
	// Escape hatch in case new options become available
	ReductionOther int64 = 3
)
|
|
|
|
|
|
2022-03-12 04:47:15 +00:00
|
|
|
|
// func (r Reduction) ToInt() int {
|
|
|
|
|
// switch r {
|
|
|
|
|
// case ReductionNone:
|
|
|
|
|
// return 0
|
|
|
|
|
// case ReductionMean:
|
|
|
|
|
// return 1
|
|
|
|
|
// case ReductionSum:
|
|
|
|
|
// return 2
|
|
|
|
|
// case ReductionOther:
|
|
|
|
|
// return 3
|
|
|
|
|
// }
|
|
|
|
|
//
|
|
|
|
|
// // NOTE. should it be panic here instead of returning -1?
|
|
|
|
|
// return -1
|
|
|
|
|
// }
|
2020-06-21 01:57:29 +01:00
|
|
|
|
|
2020-07-22 06:56:30 +01:00
|
|
|
|
// Float64Values returns values of tensor in a slice of float64.
//
// If delOpt[0] is true, the receiver tensor is dropped after the values
// have been copied out.
func (ts *Tensor) Float64Values(delOpt ...bool) []float64 {
	del := false
	if len(delOpt) > 0 {
		del = delOpt[0]
	}
	numel := ts.Numel()
	vec := make([]float64, numel)

	// Convert to Double first so the raw element copy has a known width.
	float64Ts := ts.MustTotype(gotch.Double, false)

	float64Ts.MustCopyData(vec, numel)
	// float64Ts.MustDrop()
	// NOTE(review): the explicit drop above was commented out, while
	// Int64Values still drops its temporary — presumably the intermediate
	// tensor is now reclaimed by a finalizer; confirm, otherwise this leaks.

	if del {
		ts.MustDrop()
	}

	return vec
}
|
2020-06-29 08:56:11 +01:00
|
|
|
|
|
2020-07-25 10:02:48 +01:00
|
|
|
|
// Int64Values returns values of tensor in a slice of int64.
|
2022-03-12 04:47:15 +00:00
|
|
|
|
func (ts *Tensor) Int64Values(delOpt ...bool) []int64 {
|
|
|
|
|
del := false
|
|
|
|
|
if len(delOpt) > 0 {
|
|
|
|
|
del = delOpt[0]
|
|
|
|
|
}
|
2020-07-25 10:02:48 +01:00
|
|
|
|
numel := ts.Numel()
|
|
|
|
|
vec := make([]int64, numel)
|
|
|
|
|
|
|
|
|
|
int64Ts := ts.MustTotype(gotch.Int64, false)
|
|
|
|
|
|
|
|
|
|
int64Ts.MustCopyData(vec, numel)
|
|
|
|
|
int64Ts.MustDrop()
|
|
|
|
|
|
2022-03-12 04:47:15 +00:00
|
|
|
|
if del {
|
|
|
|
|
ts.MustDrop()
|
|
|
|
|
}
|
|
|
|
|
|
2020-07-25 10:02:48 +01:00
|
|
|
|
return vec
|
|
|
|
|
}
|
|
|
|
|
|
2020-07-10 06:28:37 +01:00
|
|
|
|
// Vals returns tensor values in a slice
|
|
|
|
|
// NOTE: need a type insersion to get runtime type
|
|
|
|
|
// E.g. res := xs.Vals().([]int64)
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) Vals() interface{} {
|
2020-07-10 06:28:37 +01:00
|
|
|
|
dtype := ts.DType()
|
2023-07-06 15:01:23 +01:00
|
|
|
|
numel := int(ts.Numel())
|
2020-07-10 06:28:37 +01:00
|
|
|
|
|
2023-07-06 15:01:23 +01:00
|
|
|
|
typ, err := dtype.GoType()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
dataSlice := reflect.MakeSlice(reflect.SliceOf(typ), numel, numel).Interface()
|
|
|
|
|
|
|
|
|
|
ts.CopyData(dataSlice, uint(numel))
|
|
|
|
|
|
|
|
|
|
return dataSlice
|
2020-07-10 06:28:37 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-06-29 08:56:11 +01:00
|
|
|
|
// FlatView flattens a tensor.
|
|
|
|
|
//
|
|
|
|
|
// This returns a flattened version of the given tensor. The first dimension
|
|
|
|
|
// is preserved as it is assumed to be the mini-batch dimension.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) FlatView() *Tensor {
|
2020-06-29 08:56:11 +01:00
|
|
|
|
batchSize := ts.MustSize()[0]
|
|
|
|
|
return ts.MustView([]int64{batchSize, -1}, false)
|
|
|
|
|
}
|
2020-07-06 03:02:38 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) ZeroPad2d(left, right, top, bottom int64, del bool) (*Tensor, error) {
|
2020-07-06 03:02:38 +01:00
|
|
|
|
if ts.Dim() != 4 {
|
2020-10-31 08:25:32 +00:00
|
|
|
|
err := fmt.Errorf("Expected a 4 dimension tensor, got %v\n", ts.MustSize())
|
|
|
|
|
return nil, err
|
2020-07-06 03:02:38 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ts.ConstantPadNd([]int64{left, right, top, bottom}, del)
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) MustZeroPad2d(left, right, top, bottom int64, del bool) *Tensor {
|
2020-07-06 03:02:38 +01:00
|
|
|
|
retVal, err := ts.ZeroPad2d(left, right, top, bottom, del)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal
|
|
|
|
|
}
|
|
|
|
|
|
2020-07-27 02:31:42 +01:00
|
|
|
|
// Onehot converts a tensor to a one-hot encoded version.
//
// If the input has a size [N1, N2, ..., Nk], the returned tensor has a size
// [N1, ..., Nk, labels]. The returned tensor uses float values.
// Elements of the input vector are expected to be between 0 and labels-1.
//
// NOTE: There's other `ts.OneHot` and `ts.MustOneHot` generated from Atg C++ API
func (ts *Tensor) Onehot(labels int64) *Tensor {
	dims := ts.MustSize()
	dims = append(dims, labels)
	// Append a trailing dimension so each element indexes one slot on the
	// new label axis; `true` drops the unsqueezed temporary after Totype.
	unsqueezeTs := ts.MustUnsqueeze(-1, false)
	inputTs := unsqueezeTs.MustTotype(gotch.Int64, true)

	zerosTs := MustZeros(dims, gotch.Float, gotch.CPU)
	// Write 1.0 at each label index; `true` drops zerosTs afterwards.
	retVal := zerosTs.MustScatterValue(-1, inputTs, FloatScalar(1.0), true)
	inputTs.MustDrop()

	return retVal
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) Swish() *Tensor {
|
2020-07-06 03:02:38 +01:00
|
|
|
|
sig := ts.MustSigmoid(false)
|
2020-10-31 08:25:32 +00:00
|
|
|
|
mulTs := ts.MustMul(sig, false)
|
2020-07-06 03:02:38 +01:00
|
|
|
|
sig.MustDrop()
|
2020-10-31 08:25:32 +00:00
|
|
|
|
return mulTs
|
2020-07-06 03:02:38 +01:00
|
|
|
|
}
|
2020-07-06 09:32:43 +01:00
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) AvgPool2DDefault(ksize int64, del bool) *Tensor {
|
2020-11-02 11:35:25 +00:00
|
|
|
|
return ts.MustAvgPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, false, true, []int64{1}, del)
|
2020-07-06 09:32:43 +01:00
|
|
|
|
}
|
2020-10-29 12:40:24 +00:00
|
|
|
|
|
|
|
|
|
// SaveMultiNew saves a slice of named tensors to the given file path.
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func SaveMultiNew(namedTensors []NamedTensor, path string) error {
|
2020-10-29 12:40:24 +00:00
|
|
|
|
var (
|
|
|
|
|
tensors []lib.Ctensor
|
|
|
|
|
names []string
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
for _, nts := range namedTensors {
|
|
|
|
|
tensors = append(tensors, nts.Tensor.ctensor)
|
|
|
|
|
names = append(names, nts.Name)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
lib.AtSaveMultiNew(tensors, names, path)
|
2020-10-31 08:25:32 +00:00
|
|
|
|
if err := TorchErr(); err != nil {
|
2020-10-29 12:40:24 +00:00
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
2022-01-17 10:41:16 +00:00
|
|
|
|
|
|
|
|
|
func (ts *Tensor) ConstantPadNdWithVal(pad []int64, value *Scalar, del bool) (retVal *Tensor, err error) {
|
|
|
|
|
if del {
|
|
|
|
|
defer ts.MustDrop()
|
|
|
|
|
}
|
|
|
|
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
|
|
|
|
|
|
|
|
lib.AtoConstantPadNd(ptr, ts.ctensor, pad, len(pad), value.cscalar)
|
|
|
|
|
if err = TorchErr(); err != nil {
|
|
|
|
|
return retVal, err
|
|
|
|
|
}
|
2023-07-04 14:26:20 +01:00
|
|
|
|
|
|
|
|
|
retVal = newTensor(*ptr)
|
2022-01-17 10:41:16 +00:00
|
|
|
|
|
|
|
|
|
return retVal, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (ts *Tensor) MustConstantPadNdWithVal(pad []int64, value *Scalar, del bool) (retVal *Tensor) {
|
|
|
|
|
|
|
|
|
|
retVal, err := ts.ConstantPadNdWithVal(pad, value, del)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal
|
|
|
|
|
}
|
2023-07-11 06:35:36 +01:00
|
|
|
|
|
|
|
|
|
// TT. Added some torch.cuda APIs for handling CUDA qt
|
|
|
|
|
|
|
|
|
|
// CudaCurrentDevice get device index of current CUDA device.
|
|
|
|
|
func CudaCurrentDevice() (int, error) {
|
|
|
|
|
currentDeviceIndex := lib.AtcGetDevice()
|
|
|
|
|
if err := TorchErr(); err != nil {
|
|
|
|
|
err = fmt.Errorf("ts.CudaCurrentDevice() failed: %w\n", err)
|
|
|
|
|
return -99, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return currentDeviceIndex, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// CudaSetDevice set new cuda device index and returns previous cuda index.
|
|
|
|
|
func CudaSetDevice(cudaDeviceIndex int) (int, error) {
|
|
|
|
|
currentDeviceIndex, err := CudaCurrentDevice()
|
|
|
|
|
if err != nil {
|
|
|
|
|
err = fmt.Errorf("ts.CudaSetDevice() failed: %w\n", err)
|
|
|
|
|
return -99, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
lib.AtcSetDevice(cudaDeviceIndex)
|
|
|
|
|
if err := TorchErr(); err != nil {
|
|
|
|
|
err = fmt.Errorf("ts.CudaSetDevice() failed: %w\n", err)
|
|
|
|
|
return -99, err
|
|
|
|
|
}
|
|
|
|
|
return currentDeviceIndex, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// CudaSynchronize waits for all kernels in all streams on a CUDA device to complete.
|
|
|
|
|
func CudaSynchronize(cudaDeviceIndexOpt ...int) error {
|
|
|
|
|
var cudaDeviceIndex int
|
|
|
|
|
var err error
|
|
|
|
|
if len(cudaDeviceIndexOpt) > 0 {
|
|
|
|
|
cudaDeviceIndex = cudaDeviceIndexOpt[0]
|
|
|
|
|
} else {
|
|
|
|
|
cudaDeviceIndex, err = CudaCurrentDevice()
|
|
|
|
|
if err != nil {
|
|
|
|
|
err := fmt.Errorf("ts.CudaSynchronize() failed: %w\n", err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
lib.AtcSynchronize(int64(cudaDeviceIndex))
|
|
|
|
|
|
|
|
|
|
return TorchErr()
|
|
|
|
|
}
|