gotch/ts/patch.go
Goncalves Henriques, Andre (UG - Computer Science) 9257404edd Move the name of the module
2024-04-21 15:15:00 +01:00

584 lines
15 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package ts
// #include "stdlib.h"
import "C"
import (
"log"
"unsafe"
lib "git.andr3h3nriqu3s.com/andr3/gotch/libtch"
)
// NOTE. This is a temporarily patched to make it run.
// TODO. make change at generator for []*Tensor input
func (ts *Tensor) Lstm(hxData []*Tensor, paramsData []*Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c *Tensor, err error) {
// NOTE: `atg_lstm` will create 3 consecutive Ctensors in memory of C land. The first
// Ctensor will have address given by `ctensorPtr1` here.
// The next pointers can be calculated based on `ctensorPtr1`
ctensorPtr1 := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
ctensorPtr2 := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(ctensorPtr1)) + unsafe.Sizeof(ctensorPtr1)))
ctensorPtr3 := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(ctensorPtr2)) + unsafe.Sizeof(ctensorPtr1)))
var chxData []lib.Ctensor
for _, t := range hxData {
chxData = append(chxData, t.ctensor)
}
var cparamsData []lib.Ctensor
for _, t := range paramsData {
cparamsData = append(cparamsData, t.ctensor)
}
var chasBiases int32 = 0
if hasBiases {
chasBiases = 1
}
var ctrain int32 = 0
if train {
ctrain = 1
}
var cbidirectional int32 = 0
if bidirectional {
cbidirectional = 1
}
var cbatchFirst int32 = 0
if batchFirst {
cbatchFirst = 1
}
lib.AtgLstm(ctensorPtr1, ts.ctensor, chxData, len(hxData), cparamsData, len(paramsData), chasBiases, numLayers, dropout, ctrain, cbidirectional, cbatchFirst)
err = TorchErr()
if err != nil {
return output, h, c, err
}
return newTensor(*ctensorPtr1), newTensor(*ctensorPtr2), newTensor(*ctensorPtr3), nil
}
// MustLstm calls Lstm with the same arguments and returns its three tensors.
// If Lstm returns an error, the error is handed to log.Fatal.
func (ts *Tensor) MustLstm(hxData []*Tensor, paramsData []*Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c *Tensor) {
	var err error
	if output, h, c, err = ts.Lstm(hxData, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst); err != nil {
		log.Fatal(err)
	}
	return output, h, c
}
func (ts *Tensor) Gru(hx *Tensor, paramsData []*Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h *Tensor, err error) {
// NOTE: `atg_gru` will create 2 consecutive Ctensors in memory of C land.
// The first Ctensor will have address given by `ctensorPtr1` here.
// The next pointer can be calculated based on `ctensorPtr1`
ctensorPtr1 := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
ctensorPtr2 := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(ctensorPtr1)) + unsafe.Sizeof(ctensorPtr1)))
var cparamsData []lib.Ctensor
for _, t := range paramsData {
cparamsData = append(cparamsData, t.ctensor)
}
var chasBiases int32 = 0
if hasBiases {
chasBiases = 1
}
var ctrain int32 = 0
if train {
ctrain = 1
}
var cbidirectional int32 = 0
if bidirectional {
cbidirectional = 1
}
var cbatchFirst int32 = 0
if batchFirst {
cbatchFirst = 1
}
lib.AtgGru(ctensorPtr1, ts.ctensor, hx.ctensor, cparamsData, len(paramsData), chasBiases, numLayers, dropout, ctrain, cbidirectional, cbatchFirst)
err = TorchErr()
if err != nil {
return output, h, err
}
return newTensor(*ctensorPtr1), newTensor(*ctensorPtr2), nil
}
// MustGru calls Gru with the same arguments and returns its two tensors.
// If Gru returns an error, the error is handed to log.Fatal.
func (ts *Tensor) MustGru(hx *Tensor, paramsData []*Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h *Tensor) {
	var err error
	if output, h, err = ts.Gru(hx, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst); err != nil {
		log.Fatal(err)
	}
	return output, h
}
func (ts *Tensor) TopK(k int64, dim int64, largest bool, sorted bool) (ts1, ts2 *Tensor, err error) {
// NOTE: `lib.AtgTopk` will return 2 tensors in C memory. First tensor pointer
// is given by ctensorPtr1
ctensorPtr1 := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
ctensorPtr2 := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(ctensorPtr1)) + unsafe.Sizeof(ctensorPtr1)))
var clargest int32 = 0
if largest {
clargest = 1
}
var csorted int32 = 0
if sorted {
csorted = 1
}
lib.AtgTopk(ctensorPtr1, ts.ctensor, k, dim, clargest, csorted)
err = TorchErr()
if err != nil {
return ts1, ts2, err
}
return newTensor(*ctensorPtr1), newTensor(*ctensorPtr2), nil
}
// MustTopK calls TopK with the same arguments and returns its two tensors.
// If TopK returns an error, the error is handed to log.Fatal.
func (ts *Tensor) MustTopK(k int64, dim int64, largest bool, sorted bool) (ts1, ts2 *Tensor) {
	var err error
	if ts1, ts2, err = ts.TopK(k, dim, largest, sorted); err != nil {
		log.Fatal(err)
	}
	return ts1, ts2
}
// NOTE. `NLLLoss` is a version of `NllLoss` in tensor-generated
// with default weight, reduction and ignoreIndex
func (ts *Tensor) NLLLoss(target *Tensor, del bool) (retVal *Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
if del {
defer ts.MustDrop()
}
reduction := int64(1) // Mean of loss
ignoreIndex := int64(-100)
// defer C.free(unsafe.Pointer(ptr))
lib.AtgNllLoss(ptr, ts.ctensor, target.ctensor, nil, reduction, ignoreIndex)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = newTensor(*ptr)
return retVal, nil
}
// MustNLLLoss calls NLLLoss with the same arguments and returns its tensor.
// If NLLLoss returns an error, the error is handed to log.Fatal.
func (ts *Tensor) MustNLLLoss(target *Tensor, del bool) (retVal *Tensor) {
	var err error
	if retVal, err = ts.NLLLoss(target, del); err != nil {
		log.Fatal(err)
	}
	return retVal
}
// NOTE: the following 9 APIs are missing from `tensor-generated.go` with
// pattern of **return tensor pointer**: `tensor *atg_FUNCTION_NAME()`.
// The returned tensor pointer is actually the address of the FIRST element of a
// vector of C tensor pointers. Each next pointer is calculated from the previous one.
// In C land, the end of the vector is detected by checking for a **NULL** element.
//
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
// tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim);
// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
// tensor *atg_nonzero_numpy(tensor self);
// tensor *atg_split(tensor self, int64_t split_size, int64_t dim);
// tensor *atg_split_with_sizes(tensor self, int64_t *split_sizes_data, int split_sizes_len, int64_t dim);
// tensor *atg_unbind(tensor self, int64_t dim);
// tensor *atg_where(tensor condition);
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
func AlignTensors(tensors []*Tensor) (retVal []*Tensor, err error) {
var ctensors []lib.Ctensor
for _, t := range tensors {
ctensors = append(ctensors, t.ctensor)
}
ctensorsPtr := lib.AtgAlignTensors(ctensors, len(ctensors))
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
// MustAlignTensors calls AlignTensors and returns its result. If AlignTensors
// returns an error, the error is handed to log.Fatal.
//
// When del is true, MustDrop is deferred for every input tensor before the
// call is made.
func MustAlignTensors(tensors []*Tensor, del bool) (retVal []*Tensor) {
	if del {
		for i := range tensors {
			defer tensors[i].MustDrop()
		}
	}

	var err error
	if retVal, err = AlignTensors(tensors); err != nil {
		log.Fatal(err)
	}
	return retVal
}
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
func BroadcastTensors(tensors []*Tensor) (retVal []*Tensor, err error) {
var ctensors []lib.Ctensor
for _, t := range tensors {
ctensors = append(ctensors, t.ctensor)
}
ctensorsPtr := lib.AtgBroadcastTensors(ctensors, len(ctensors))
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
// MustBroadcastTensors calls BroadcastTensors and returns its result. If
// BroadcastTensors returns an error, the error is handed to log.Fatal.
//
// When del is true, MustDrop is deferred for every input tensor before the
// call is made.
func MustBroadcastTensors(tensors []*Tensor, del bool) (retVal []*Tensor) {
	if del {
		for i := range tensors {
			defer tensors[i].MustDrop()
		}
	}

	var err error
	if retVal, err = BroadcastTensors(tensors); err != nil {
		log.Fatal(err)
	}
	return retVal
}
// tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim);
func (ts *Tensor) Chunk(chunks int64, dim int64) (retVal []*Tensor, err error) {
ctensorsPtr := lib.AtgChunk(ts.ctensor, chunks, dim)
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
// MustChunk calls Chunk and returns its result. If Chunk returns an error, the
// error is handed to log.Fatal.
//
// When del is true, ts.MustDrop is deferred before the call is made.
func (ts *Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []*Tensor) {
	if del {
		defer ts.MustDrop()
	}

	var err error
	if retVal, err = ts.Chunk(chunks, dim); err != nil {
		log.Fatal(err)
	}
	return retVal
}
// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
func Meshgrid(tensors []*Tensor) (retVal []*Tensor, err error) {
var ctensors []lib.Ctensor
for _, t := range tensors {
ctensors = append(ctensors, t.ctensor)
}
ctensorsPtr := lib.AtgMeshgrid(ctensors, len(ctensors))
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
// MustMeshgrid calls Meshgrid and returns its result. If Meshgrid returns an
// error, the error is handed to log.Fatal.
func MustMeshgrid(tensors []*Tensor) (retVal []*Tensor) {
	var err error
	if retVal, err = Meshgrid(tensors); err != nil {
		log.Fatal(err)
	}
	return retVal
}
// tensor *atg_nonzero_numpy(tensor self);
func (ts *Tensor) NonzeroNumpy() (retVal []*Tensor, err error) {
ctensorsPtr := lib.AtgNonzeroNumpy(ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
// MustNonzeroNumpy calls NonzeroNumpy and returns its result. If NonzeroNumpy
// returns an error, the error is handed to log.Fatal.
//
// When del is true, ts.MustDrop is deferred before the call is made.
func (ts *Tensor) MustNonzeroNumpy(del bool) (retVal []*Tensor) {
	if del {
		defer ts.MustDrop()
	}

	var err error
	if retVal, err = ts.NonzeroNumpy(); err != nil {
		log.Fatal(err)
	}
	return retVal
}
// Split splits tensor into chunks
//
// Parameters:
// - splitSize size of a single chunk
// - dim dimension along which to split the tensor.
//
// Ref. https://pytorch.org/docs/stable/generated/torch.split.html
func (ts *Tensor) Split(splitSize, dim int64) (retVal []*Tensor, err error) {
ctensorsPtr := lib.AtgSplit(ts.ctensor, splitSize, dim)
if err = TorchErr(); err != nil {
return retVal, err
}
// NOTE: ctensorsPtr is a c-pointer to a vector of tensors. The first
// C tensor is the `ctensorsPtr` value. The next pointer will be
// calculated from there. The vector of tensors will end if the calculated
// pointer value is `null`.
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
// MustSplit calls Split and returns its result. If Split returns an error, the
// error is handed to log.Fatal.
//
// When del is true, ts.MustDrop is deferred before the call is made.
func (ts *Tensor) MustSplit(splitSize, dim int64, del bool) (retVal []*Tensor) {
	if del {
		defer ts.MustDrop()
	}

	var err error
	if retVal, err = ts.Split(splitSize, dim); err != nil {
		log.Fatal(err)
	}
	return retVal
}
// SplitWithSizes splits tensor into chunks
//
// Parameters:
// - splitSizes slice of sizes for each chunk
// - dim dimension along which to split the tensor.
//
// Ref. https://pytorch.org/docs/stable/generated/torch.split.html
func (ts *Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []*Tensor, err error) {
ctensorsPtr := lib.AtgSplitWithSizes(ts.ctensor, splitSizes, len(splitSizes), dim)
if err = TorchErr(); err != nil {
return retVal, err
}
// NOTE: ctensorsPtr is a c-pointer to a vector of tensors. The first
// C tensor is the `ctensorsPtr` value. The next pointer will be
// calculated from there. The vector of tensors will end if the calculated
// pointer value is `null`.
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
// MustSplitWithSizes calls SplitWithSizes and returns its result. If
// SplitWithSizes returns an error, the error is handed to log.Fatal.
//
// When del is true, ts.MustDrop is deferred before the call is made.
func (ts *Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (retVal []*Tensor) {
	if del {
		defer ts.MustDrop()
	}

	var err error
	if retVal, err = ts.SplitWithSizes(splitSizes, dim); err != nil {
		log.Fatal(err)
	}
	return retVal
}
// tensor *atg_unbind(tensor self, int64_t dim);
func (ts *Tensor) Unbind(dim int64) (retVal []*Tensor, err error) {
ctensorsPtr := lib.AtgUnbind(ts.ctensor, dim)
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
// MustUnbind calls Unbind and returns its result. If Unbind returns an error,
// the error is handed to log.Fatal.
//
// When del is true, ts.MustDrop is deferred before the call is made.
func (ts *Tensor) MustUnbind(dim int64, del bool) (retVal []*Tensor) {
	if del {
		defer ts.MustDrop()
	}

	var err error
	if retVal, err = ts.Unbind(dim); err != nil {
		log.Fatal(err)
	}
	return retVal
}
// tensor *atg_where(tensor condition);
func Where(condition Tensor) (retVal []*Tensor, err error) {
ctensorsPtr := lib.AtgWhere(condition.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
// MustWhere calls Where and returns its result. If Where returns an error, the
// error is handed to log.Fatal.
//
// When del is true, condition.MustDrop is deferred before the call is made.
func MustWhere(condition Tensor, del bool) (retVal []*Tensor) {
	if del {
		defer condition.MustDrop()
	}

	var err error
	if retVal, err = Where(condition); err != nil {
		log.Fatal(err)
	}
	return retVal
}
// NOTE. patches for APIs `atg_` missing in tensor/ but existing in lib
// ====================================================================
// // void atg_lstsq(tensor *, tensor self, tensor A);
// func (ts *Tensor) Lstsq(a *Tensor, del bool) (retVal *Tensor, err error) {
// if del {
// defer ts.MustDrop()
// }
// ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
//
// lib.AtgLstsq(ptr, ts.ctensor, a.ctensor)
// if err = TorchErr(); err != nil {
// return retVal, err
// }
// retVal = &Tensor{ctensor: *ptr}
//
// return retVal, err
// }
//
// func (ts *Tensor) MustLstsq(a *Tensor, del bool) (retVal *Tensor) {
// retVal, err := ts.Lstsq(a, del)
// if err != nil {
// log.Fatal(err)
// }
//
// return retVal
// }