gotch/wrapper/tensor.go

945 lines
21 KiB
Go

package wrapper
//#include "stdlib.h"
//#include "stdbool.h"
import "C"
import (
"bytes"
"encoding/binary"
"fmt"
"log"
"reflect"
// "strings"
"unsafe"
gotch "github.com/sugarme/gotch"
lib "github.com/sugarme/gotch/libtch"
)
// Tensor wraps a libtch C tensor handle (lib.Ctensor) so that it can be used
// through Go methods.
type Tensor struct {
// ctensor is the underlying C handle passed to every lib.* call.
ctensor lib.Ctensor
}
// NewTensor asks libtch for a new tensor (lib.AtNewTensor) and wraps the
// returned C handle.
func NewTensor() Tensor {
	return Tensor{ctensor: lib.AtNewTensor()}
}
// Dim returns the value reported by lib.AtDim for this tensor.
//
// If TorchErr reports an error the process is terminated through log.Fatal.
func (ts Tensor) Dim() uint64 {
	ndim := lib.AtDim(ts.ctensor)
	err := TorchErr()
	if err != nil {
		log.Fatal(err)
	}
	return ndim
}
// Size return shape of the tensor
//
// NOTE: C++ libtorch calls at_shape() -> t.sizes()
// And returns a slice of sizes or shape using given pointer
// to that slice.
func (ts Tensor) Size() (retVal []int64, err error) {
dim := lib.AtDim(ts.ctensor)
sz := make([]int64, dim)
szPtr, err := DataAsPtr(sz)
if err != nil {
return retVal, err
}
defer C.free(unsafe.Pointer(szPtr))
lib.AtShape(ts.ctensor, szPtr)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = decodeSize(szPtr, dim)
return retVal, nil
}
// MustSize is like Size but terminates the process through log.Fatal when
// Size returns an error.
func (ts Tensor) MustSize() (retVal []int64) {
	shape, err := ts.Size()
	if err != nil {
		log.Fatal(err)
	}
	return shape
}
// Size1 returns the single dimension of a 1D tensor.
//
// An error is returned when the shape cannot be read or when the tensor does
// not have exactly one dimension.
func (ts Tensor) Size1() (retVal int64, err error) {
	dims, err := ts.Size()
	switch {
	case err != nil:
		return 0, err
	case len(dims) != 1:
		return 0, fmt.Errorf("Expected one dim, got %v\n", len(dims))
	}
	return dims[0], nil
}
// Size2 returns the shape of a 2D tensor.
//
// An error is returned when the shape cannot be read or when the tensor does
// not have exactly two dimensions.
func (ts Tensor) Size2() (retVal []int64, err error) {
	dims, err := ts.Size()
	switch {
	case err != nil:
		return nil, err
	case len(dims) != 2:
		return nil, fmt.Errorf("Expected two dims, got %v\n", len(dims))
	}
	return dims, nil
}
// Size3 returns the shape of a 3D tensor.
//
// An error is returned when the shape cannot be read or when the tensor does
// not have exactly three dimensions.
func (ts Tensor) Size3() (retVal []int64, err error) {
	dims, err := ts.Size()
	switch {
	case err != nil:
		return nil, err
	case len(dims) != 3:
		return nil, fmt.Errorf("Expected three dims, got %v\n", len(dims))
	}
	return dims, nil
}
// Size4 returns the shape of a 4D tensor.
//
// An error is returned when the shape cannot be read or when the tensor does
// not have exactly four dimensions.
func (ts Tensor) Size4() (retVal []int64, err error) {
	dims, err := ts.Size()
	switch {
	case err != nil:
		return nil, err
	case len(dims) != 4:
		return nil, fmt.Errorf("Expected four dims, got %v\n", len(dims))
	}
	return dims, nil
}
// decodeSize reads nsize int64 values from the memory at ptr and returns them
// as a new Go slice.
//
// The memory is viewed as raw bytes (element width taken from
// gotch.DTypeSize(gotch.Int64)) and decoded with binary.Read using
// nativeEndian. Any failure terminates the process through log.Fatal.
func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
	width, err := gotch.DTypeSize(gotch.Int64)
	if err != nil {
		log.Fatal(err)
	}

	// View exactly width*nsize bytes of the pointed-to memory; the three-index
	// slice caps capacity so nothing past the buffer is reachable.
	total := int(width) * int(nsize)
	raw := (*[1 << 30]byte)(ptr)[:total:total]

	decoded := make([]int64, nsize)
	if err := binary.Read(bytes.NewReader(raw), nativeEndian, decoded); err != nil {
		log.Fatal(err)
	}
	return decoded
}
// OfSlice creates tensor from a slice data
func OfSlice(data interface{}) (retVal Tensor, err error) {
typ, dataLen, err := DataCheck(data)
if err != nil {
return retVal, err
}
dtype, err := gotch.ToDType(typ)
if err != nil {
return retVal, err
}
shape := []int64{int64(dataLen)}
elementNum := ElementCount(shape)
eltSizeInBytes, err := gotch.DTypeSize(dtype)
if err != nil {
return retVal, err
}
nbytes := int(eltSizeInBytes) * int(elementNum)
dataPtr, buff := CMalloc(nbytes)
defer C.free(unsafe.Pointer(dataPtr))
if err = EncodeTensor(buff, reflect.ValueOf(data), shape); err != nil {
return retVal, err
}
cint, err := gotch.DType2CInt(dtype)
if err != nil {
return retVal, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor}
return retVal, nil
}
// MustOfSlice creates a tensor from a slice of data. It terminates the
// process through log.Fatal when OfSlice returns an error.
func MustOfSlice(data interface{}) (retVal Tensor) {
	t, err := OfSlice(data)
	if err != nil {
		log.Fatal(err)
	}
	return t
}
// TensorFrom creates a tensor from a slice of data by calling OfSlice. It
// terminates the process through log.Fatal when OfSlice returns an error.
func TensorFrom(data interface{}) (retVal Tensor) {
	t, err := OfSlice(data)
	if err != nil {
		log.Fatal(err)
	}
	return t
}
// Print writes the tensor values to the console.
//
// NOTE: printing happens on the C side (lib.AtPrint) and emits ALL elements
// of the tensor without truncation. The process is terminated through
// log.Fatal when TorchErr reports an error.
func (ts Tensor) Print() {
	lib.AtPrint(ts.ctensor)
	err := TorchErr()
	if err != nil {
		log.Fatal(err)
	}
}
// NewTensorFromData creates tensor from given data and shape
func NewTensorFromData(data interface{}, shape []int64) (retVal Tensor, err error) {
// 1. Check whether data and shape match
elementNum, err := DataDim(data)
if err != nil {
return retVal, err
}
nflattend := FlattenDim(shape)
if elementNum != nflattend {
err = fmt.Errorf("Number of data elements (%v) and flatten shape (%v) dimension mismatched.\n", elementNum, nflattend)
return retVal, err
}
// 2. Write raw data to C memory and get C pointer
dataPtr, err := DataAsPtr(data)
defer C.free(unsafe.Pointer(dataPtr))
if err != nil {
return retVal, err
}
// 3. Create tensor with pointer and shape
dtype, err := gotch.DTypeFromData(data)
if err != nil {
return retVal, err
}
eltSizeInBytes, err := gotch.DTypeSize(dtype)
if err != nil {
return retVal, err
}
cint, err := gotch.DType2CInt(dtype)
if err != nil {
return retVal, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
// defer C.free(unsafe.Pointer(ctensor))
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor}
return retVal, nil
}
// DType returns the gotch.DType corresponding to the scalar type reported by
// lib.AtScalarType. The process is terminated through log.Fatalf when the C
// value cannot be converted.
func (ts Tensor) DType() gotch.DType {
	kind, err := gotch.CInt2DType(lib.AtScalarType(ts.ctensor))
	if err != nil {
		log.Fatalf("Tensor DType error: %v\n", err)
	}
	return kind
}
// Device returns the gotch.Device built from the integer code reported by
// lib.AtDevice, or the error reported by TorchErr.
func (ts Tensor) Device() (retVal gotch.Device, err error) {
	code := lib.AtDevice(ts.ctensor)
	if err = TorchErr(); err != nil {
		return retVal, err
	}

	var dev gotch.Device
	retVal = dev.OfCInt(int32(code))
	return retVal, nil
}
func (ts Tensor) Eq1(other Tensor) (retVal Tensor, err error) {
// Get a C null pointer
// https://stackoverflow.com/a/2022369
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgEq1(ptr, ts.ctensor, other.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
return Tensor{ctensor: *ptr}, nil
}
// DoubleValue returns a float value on tensors holding a single element.
// An error is returned otherwise.
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
func (ts Tensor) Float64Value(idx []int64) (retVal float64, err error) {
idxPtr, err := DataAsPtr(idx)
if err != nil {
return retVal, err
}
defer C.free(unsafe.Pointer(idxPtr))
retVal = lib.AtDoubleValueAtIndexes(ts.ctensor, idxPtr, len(idx))
if err = TorchErr(); err != nil {
return retVal, err
}
return retVal, err
}
// Int64Value returns an int value on tensors holding a single element. An error is
// returned otherwise.
func (ts Tensor) Int64Value(idx []int64) (retVal int64, err error) {
idxPtr, err := DataAsPtr(idx)
if err != nil {
return retVal, err
}
defer C.free(unsafe.Pointer(idxPtr))
retVal = lib.AtInt64ValueAtIndexes(ts.ctensor, idxPtr, len(idx))
if err = TorchErr(); err != nil {
return retVal, err
}
return retVal, err
}
// RequiresGrad reports whether gradients are currently tracked for this
// tensor (lib.AtRequiresGrad), together with any error from TorchErr.
func (ts Tensor) RequiresGrad() (retVal bool, err error) {
	retVal = lib.AtRequiresGrad(ts.ctensor)
	err = TorchErr()
	return retVal, err
}
// DataPtr returns the address of the first element of this tensor
// (lib.AtDataPtr), together with any error from TorchErr.
func (ts Tensor) DataPtr() (retVal unsafe.Pointer, err error) {
	retVal = lib.AtDataPtr(ts.ctensor)
	err = TorchErr()
	return retVal, err
}
// Defined reports whether the tensor is defined (lib.AtDefined), together
// with any error from TorchErr.
func (ts Tensor) Defined() (retVal bool, err error) {
	retVal = lib.AtDefined(ts.ctensor)
	err = TorchErr()
	return retVal, err
}
// MustDefined is like Defined but terminates the process through log.Fatal
// when Defined returns an error.
func (ts Tensor) MustDefined() (retVal bool) {
	defined, err := ts.Defined()
	if err != nil {
		log.Fatal(err)
	}
	return defined
}
// IsSparse reports whether the tensor is sparse (lib.AtIsSparse), together
// with any error from TorchErr.
func (ts Tensor) IsSparse() (retVal bool, err error) {
	retVal = lib.AtIsSparse(ts.ctensor)
	err = TorchErr()
	return retVal, err
}
// ZeroGrad zeroes the gradient tensor attached to this tensor, if that
// gradient is defined. Nothing happens otherwise.
func (ts Tensor) ZeroGrad() {
	grad := ts.MustGrad()
	if !grad.MustDefined() {
		return
	}

	// Detach first, then zero the detached gradient.
	// TODO: consider chaining: grad.MustDetach_().MustZero_()
	// https://www.calhoun.io/using-functional-options-instead-of-method-chaining-in-go/
	detached := grad.MustDetach_()
	detached.MustZero_()
}
// Backward runs the backward pass, populating the gradient tensors of the
// tensors whose gradients are tracked.
//
// Gradient tracking can be turned on via `SetRequiresGrad`. The call is made
// with both trailing flags of lib.AtBackward set to 0.
func (ts Tensor) Backward() (err error) {
	lib.AtBackward(ts.ctensor, 0, 0)
	return TorchErr()
}
// MustBackward is like Backward but terminates the process through log.Fatal
// when Backward returns an error.
func (ts Tensor) MustBackward() {
	err := ts.Backward()
	if err != nil {
		log.Fatal(err)
	}
}
// RunBackward calls lib.AtRunBackward for the given tensors with respect to
// inputs and returns one result tensor per input.
//
// keepGraphB and createGraphB are forwarded to the C call as 0/1 flags.
//
// The C function receives the element counts len(tensors) and len(inputs), so
// all three arrays handed to it must be contiguous and hold that many
// handles. They are therefore built as slices here: the first element address
// of each is passed. (Previously only a copy of the first tensor/input handle
// was passed, and the output slots were separate zero-byte allocations that
// were neither large enough, nor contiguous, nor freed.)
//
// An error is returned when tensors or inputs is empty, or when TorchErr
// reports a failure.
func RunBackward(tensors []Tensor, inputs []Tensor, keepGraphB bool, createGraphB bool) (retVal []Tensor, err error) {
	if len(tensors) == 0 || len(inputs) == 0 {
		err = fmt.Errorf("RunBackward: need at least one tensor and one input, got %d tensor(s) and %d input(s)", len(tensors), len(inputs))
		return nil, err
	}

	// Contiguous arrays of C handles for the call.
	ctensors := make([]lib.Ctensor, len(tensors))
	for i := range tensors {
		ctensors[i] = tensors[i].ctensor
	}
	cinputs := make([]lib.Ctensor, len(inputs))
	for i := range inputs {
		cinputs[i] = inputs[i].ctensor
	}
	// One output slot per input, zero-initialised by make.
	coutputs := make([]lib.Ctensor, len(inputs))

	keepGraph := 0
	if keepGraphB {
		keepGraph = 1
	}
	createGraph := 0
	if createGraphB {
		createGraph = 1
	}

	lib.AtRunBackward(&ctensors[0], len(tensors), &cinputs[0], len(inputs), &coutputs[0], keepGraph, createGraph)
	if err = TorchErr(); err != nil {
		return nil, err
	}

	retVal = make([]Tensor, len(coutputs))
	for i, handle := range coutputs {
		retVal[i] = Tensor{ctensor: handle}
	}
	return retVal, nil
}
// CopyDataUint8 copies `numel` elements from the tensor into `dst`.
//
// `dst` lives in Go memory and must hold at least `numel` elements; an error
// is returned otherwise. An empty `dst` (with numel == 0) is passed to the C
// side as a nil pointer instead of panicking on `&dst[0]`; whether such a
// copy is accepted is then decided by lib.AtCopyData / TorchErr.
func (ts Tensor) CopyDataUint8(dst []uint8, numel uint) (err error) {
	// Compare in the unsigned domain so that a huge numel cannot wrap to a
	// negative int and slip past the check.
	if uint(len(dst)) < numel {
		err = fmt.Errorf("CopyDataUint8 Error: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", len(dst), numel)
		return err
	}

	var vs unsafe.Pointer
	if len(dst) > 0 {
		vs = unsafe.Pointer(&dst[0])
	}

	eltSizeInBytes, err := gotch.DTypeSize(gotch.Uint8)
	if err != nil {
		return err
	}

	lib.AtCopyData(ts.ctensor, vs, numel, eltSizeInBytes)
	if err = TorchErr(); err != nil {
		return err
	}
	return nil
}
// MustCopyDataUint8 is like CopyDataUint8 but terminates the process through
// log.Fatal when the copy returns an error.
func (ts Tensor) MustCopyDataUint8(dst []uint8, numel uint) {
	if err := ts.CopyDataUint8(dst, numel); err != nil {
		log.Fatal(err)
	}
}
// CopyData copies `numel` elements from the tensor into `dst`.
//
// `dst` must be a slice (in Go memory) whose element type maps to the
// tensor's DType and whose length is at least `numel`. Errors are returned
// when `dst` fails DataCheck, is too short, has a mismatching or unsupported
// element type, or when TorchErr reports a failure.
//
// An empty `dst` is passed to the C side as a nil pointer instead of
// panicking on `&dst[0]`; whether such a copy is accepted is then decided by
// lib.AtCopyData / TorchErr.
func (ts Tensor) CopyData(dst interface{}, numel uint) (err error) {
	gotype, dlen, err := DataCheck(dst)
	if err != nil {
		return err
	}

	dtype, err := gotch.ToDType(gotype)
	if err != nil {
		return err
	}

	// Compare in the unsigned domain so that a huge numel cannot wrap to a
	// negative int and slip past the check.
	if uint(dlen) < numel {
		err = fmt.Errorf("CopyData Error: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", dlen, numel)
		return err
	}

	tsDType := ts.DType()
	if tsDType != dtype {
		err = fmt.Errorf("Type mismatched: `dst` type: %v, tensor DType: %v", dtype, tsDType)
		return err
	}

	// Address of the first destination element; stays nil for an empty slice.
	var vs unsafe.Pointer
	switch dtype {
	case gotch.Uint8:
		if s := dst.([]uint8); len(s) > 0 {
			vs = unsafe.Pointer(&s[0])
		}
	case gotch.Int8:
		if s := dst.([]int8); len(s) > 0 {
			vs = unsafe.Pointer(&s[0])
		}
	case gotch.Int16:
		if s := dst.([]int16); len(s) > 0 {
			vs = unsafe.Pointer(&s[0])
		}
	case gotch.Int:
		if s := dst.([]int32); len(s) > 0 {
			vs = unsafe.Pointer(&s[0])
		}
	case gotch.Int64:
		if s := dst.([]int64); len(s) > 0 {
			vs = unsafe.Pointer(&s[0])
		}
	case gotch.Float:
		if s := dst.([]float32); len(s) > 0 {
			vs = unsafe.Pointer(&s[0])
		}
	case gotch.Double:
		if s := dst.([]float64); len(s) > 0 {
			vs = unsafe.Pointer(&s[0])
		}
	case gotch.Bool:
		if s := dst.([]bool); len(s) > 0 {
			vs = unsafe.Pointer(&s[0])
		}
	default:
		err = fmt.Errorf("Unsupported type: `dst` type: %v, tensor DType: %v", dtype, tsDType)
		return err
	}

	eltSizeInBytes, err := gotch.DTypeSize(dtype)
	if err != nil {
		return err
	}

	lib.AtCopyData(ts.ctensor, vs, numel, eltSizeInBytes)
	if err = TorchErr(); err != nil {
		return err
	}
	return nil
}
// MustCopyData copies `numel` elements from the tensor into `dst` and
// terminates the process through log.Fatal when CopyData returns an error.
//
// NOTE: `dst` is a slice with length = numel and a Go element type equivalent
// to the tensor DType.
func (ts Tensor) MustCopyData(dst interface{}, numel uint) {
	if err := ts.CopyData(dst, numel); err != nil {
		log.Fatal(err)
	}
}
// Numel returns the total number of elements stored in the tensor, computed
// as FlattenDim over the shape from MustSize.
func (ts Tensor) Numel() (retVal uint) {
	return uint(FlattenDim(ts.MustSize()))
}
// ShallowClone returns a new tensor that shares storage with the input
// tensor (lib.AtShallowClone).
func (ts Tensor) ShallowClone() (retVal Tensor, err error) {
	handle := lib.AtShallowClone(ts.ctensor)
	if err = TorchErr(); err != nil {
		return Tensor{}, err
	}
	return Tensor{ctensor: handle}, nil
}
// MustShallowClone returns a new tensor that shares storage with the input
// tensor. It terminates the process through log.Fatal when ShallowClone
// returns an error.
func (ts Tensor) MustShallowClone() (retVal Tensor) {
	clone, err := ts.ShallowClone()
	if err != nil {
		log.Fatal(err)
	}
	return clone
}
// Get returns the sub-tensor at the given index (lib.AtGet).
func (ts Tensor) Get(index int) (retVal Tensor, err error) {
	handle := lib.AtGet(ts.ctensor, index)
	if err = TorchErr(); err != nil {
		return Tensor{}, err
	}
	return Tensor{ctensor: handle}, nil
}
// MustGet returns the sub-tensor at the given index. It terminates the
// process through log.Fatal when Get returns an error.
func (ts Tensor) MustGet(index int) (retVal Tensor) {
	sub, err := ts.Get(index)
	if err != nil {
		log.Fatal(err)
	}
	return sub
}
// Copy_ copies, in place, the values of src into self (lib.AtCopy_) and
// returns any error reported by TorchErr.
func Copy_(self, src Tensor) (err error) {
	lib.AtCopy_(self.ctensor, src.ctensor)
	return TorchErr()
}
// MustCopy_ copies, in place, the values of src into self. It terminates the
// process through log.Fatal when Copy_ returns an error.
func MustCopy_(self, src Tensor) {
	err := Copy_(self, src)
	if err != nil {
		log.Fatal(err)
	}
}
// Save writes the tensor to the file at path (lib.AtSave) and returns any
// error reported by TorchErr.
func (ts Tensor) Save(path string) (err error) {
	lib.AtSave(ts.ctensor, path)
	return TorchErr()
}
// MustSave writes the tensor to the file at path. It terminates the process
// through log.Fatal when Save returns an error.
func (ts Tensor) MustSave(path string) {
	err := ts.Save(path)
	if err != nil {
		log.Fatal(err)
	}
}
// Load reads a tensor from the file at path (lib.AtLoad).
func Load(path string) (retVal Tensor, err error) {
	handle := lib.AtLoad(path)
	if err = TorchErr(); err != nil {
		return Tensor{}, err
	}
	return Tensor{ctensor: handle}, nil
}
// MustLoad reads a tensor from the file at path. It terminates the process
// through log.Fatal when Load returns an error.
func MustLoad(path string) (retVal Tensor) {
	loaded, err := Load(path)
	if err != nil {
		log.Fatal(err)
	}
	return loaded
}
// NamedTensor pairs a Tensor with a name. It is the element type used by
// SaveMulti and returned by LoadMulti / LoadMultiWithDevice.
type NamedTensor struct {
// Name is the name stored alongside the tensor.
Name string
// Tensor is the tensor value itself.
Tensor Tensor
}
// SaveMulti writes a set of named tensors to the file at path.
//
// The file format is the same as the one used by the PyTorch C++ API.
func SaveMulti(namedTensors []NamedTensor, path string) (err error) {
	var (
		handles []lib.Ctensor
		labels  []string
	)
	// Split the pairs into parallel slices, as lib.AtSaveMulti expects.
	for i := range namedTensors {
		handles = append(handles, namedTensors[i].Tensor.ctensor)
		labels = append(labels, namedTensors[i].Name)
	}

	lib.AtSaveMulti(handles, labels, len(namedTensors), path)
	return TorchErr()
}
// MustSaveMulti writes a set of named tensors to the file at path. It
// terminates the process through log.Fatal when SaveMulti returns an error.
func MustSaveMulti(namedTensors []NamedTensor, path string) {
	if err := SaveMulti(namedTensors, path); err != nil {
		log.Fatal(err)
	}
}
// LoadMulti reads a set of named tensors from the file at path.
//
// The file format is the same as the one used by the PyTorch C++ API. A
// lib.LoadData value is registered in lib.PStore and the resulting handle is
// passed to lib.AtLoadCallback; the entries collected in its NamedCtensors
// field are then wrapped as NamedTensor values.
func LoadMulti(path string) (retVal []NamedTensor, err error) {
	var loaded lib.LoadData
	handle := lib.PStore.Set(&loaded)

	lib.AtLoadCallback(path, handle)
	if err = TorchErr(); err != nil {
		return nil, err
	}

	for _, entry := range loaded.NamedCtensors {
		retVal = append(retVal, NamedTensor{
			Name:   entry.Name,
			Tensor: Tensor{entry.Ctensor},
		})
	}
	return retVal, nil
}
// MustLoadMulti reads a set of named tensors from the file at path. It
// terminates the process through log.Fatal when LoadMulti returns an error.
func MustLoadMulti(path string) (retVal []NamedTensor) {
	loaded, err := LoadMulti(path)
	if err != nil {
		log.Fatal(err)
	}
	return loaded
}
// LoadMultiWithDevice reads a set of named tensors from the file at path,
// passing device.CInt() to lib.AtLoadCallbackWithDevice.
//
// The file format is the same as the one used by the PyTorch C++ API. A
// lib.LoadData value is registered in lib.PStore and the entries collected in
// its NamedCtensors field are wrapped as NamedTensor values.
func LoadMultiWithDevice(path string, device gotch.Device) (retVal []NamedTensor, err error) {
	var loaded lib.LoadData
	handle := lib.PStore.Set(&loaded)

	lib.AtLoadCallbackWithDevice(path, handle, device.CInt())
	if err = TorchErr(); err != nil {
		return nil, err
	}

	for _, entry := range loaded.NamedCtensors {
		retVal = append(retVal, NamedTensor{
			Name:   entry.Name,
			Tensor: Tensor{entry.Ctensor},
		})
	}
	return retVal, nil
}
// MustLoadMultiWithDevice reads a set of named tensors from the file at path
// for the given device. It terminates the process through log.Fatal when
// LoadMultiWithDevice returns an error.
func MustLoadMultiWithDevice(path string, device gotch.Device) (retVal []NamedTensor) {
	loaded, err := LoadMultiWithDevice(path, device)
	if err != nil {
		log.Fatal(err)
	}
	return loaded
}
// ToString returns a string representation of the tensor.
//
// lw : line width (size)
// NOTE: the representation contains every tensor element and may therefore
// be huge for large tensors.
func (ts Tensor) ToString(lw int64) (retVal string, err error) {
	retVal = lib.AtToString(ts.ctensor, lw)
	err = TorchErr()
	return retVal, err
}
// MustToString returns a string representation of the tensor. It terminates
// the process through log.Fatal when ToString returns an error.
func (ts Tensor) MustToString(lw int64) (retVal string) {
	text, err := ts.ToString(lw)
	if err != nil {
		log.Fatal(err)
	}
	return text
}
// Drop frees the tensor (lib.AtFree) and returns any error reported by
// TorchErr.
func (ts Tensor) Drop() (err error) {
	lib.AtFree(ts.ctensor)
	return TorchErr()
}
// MustDrop frees the tensor. It terminates the process through log.Fatal
// when Drop returns an error.
func (ts Tensor) MustDrop() {
	err := ts.Drop()
	if err != nil {
		log.Fatal(err)
	}
}
// GradSetEnabled sets globally whether GradMode gradient accumulation is
// enabled or not. It returns the PREVIOUS state of Grad before setting.
//
// The bool is translated to 0/1 for lib.AtGradSetEnabled. A returned value of
// 1 maps to true; any other value maps to false (failures are expected to be
// surfaced by TorchErr before that point).
func GradSetEnabled(b bool) (retVal bool, err error) {
	flag := 0
	if b {
		flag = 1
	}

	prev := lib.AtGradSetEnabled(flag)
	if err = TorchErr(); err != nil {
		return false, err
	}

	return prev == 1, nil
}
// MustGradSetEnabled sets globally whether GradMode gradient accumulation is
// enabled or not and returns the PREVIOUS state of Grad before setting. It
// terminates the process through log.Fatal when GradSetEnabled returns an
// error.
func MustGradSetEnabled(b bool) (retVal bool) {
	prev, err := GradSetEnabled(b)
	if err != nil {
		log.Fatal(err)
	}
	return prev
}
// NoGrad runs a closure without keeping track of gradients.
func NoGrad(fn interface{}) (retVal interface{}, err error) {
// Switch off Grad
prev := MustGradSetEnabled(false)
// Analyze input as function. If not, throw error
f, err := NewFunc(fn)
if err != nil {
return retVal, nil
}
// invokes the function
retVal = f.Invoke()
// Switch on Grad
_ = MustGradSetEnabled(prev)
return retVal, nil
}
// NoGradGuard remembers a gradient-mode state so that it can be restored
// later by calling Drop.
//
// NOTE(review): the original description calls this a RAII guard; nothing in
// this file restores the state automatically, so callers presumably have to
// call Drop explicitly — confirm.
type NoGradGuard struct {
// enabled is the state passed back to MustGradSetEnabled by Drop.
enabled bool
}
// NoGradGuard disables gradient tracking and returns a guard holding the
// state that was in effect before the call; passing that guard to Drop
// restores it. The receiver's own state is not used.
func (ngg NoGradGuard) NoGradGuard() NoGradGuard {
	previous := MustGradSetEnabled(false)
	return NoGradGuard{enabled: previous}
}
// Drop sets the global gradient mode back to the state stored in the guard.
// The previous state returned by MustGradSetEnabled is discarded.
func (ngg NoGradGuard) Drop() {
	_ = MustGradSetEnabled(ngg.enabled)
}
// Reduction is an enum-like type selecting how losses are reduced.
type Reduction int

const (
	// ReduceNone: do not reduce.
	ReduceNone Reduction = iota
	// ReduceMean: mean of losses.
	ReduceMean
	// ReduceSum: sum of losses.
	ReduceSum
	// Other: escape hatch in case new options become available.
	Other
)

// ToInt converts the reduction to its integer code: 0 for ReduceNone, 1 for
// ReduceMean, 2 for ReduceSum and 3 for Other. Any value outside the declared
// constants yields 0.
func (r Reduction) ToInt() (retVal int) {
	if r < ReduceNone || r > Other {
		return 0
	}
	return int(r)
}