234 lines
4.7 KiB
Go
234 lines
4.7 KiB
Go
// NOTE: this is a sample for OCaml generated code for `tensor-generated.go`
|
|
package wrapper
|
|
|
|
// #include "stdlib.h"
|
|
import "C"
|
|
|
|
import (
|
|
"log"
|
|
"unsafe"
|
|
|
|
gt "github.com/sugarme/gotch"
|
|
lib "github.com/sugarme/gotch/libtch"
|
|
)
|
|
|
|
func (ts Tensor) To(device gt.Device) (retVal Tensor, err error) {
|
|
|
|
// TODO: how to get pointer to CUDA memory???
|
|
// C.cuMemAlloc((*C.ulonglong)(cudaPtr), 1) // 0 byte is invalid
|
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
defer C.free(unsafe.Pointer(ptr))
|
|
|
|
lib.AtgTo((*lib.Ctensor)(ptr), ts.ctensor, int(device.CInt()))
|
|
|
|
if err = TorchErr(); err != nil {
|
|
return retVal, err
|
|
}
|
|
|
|
return Tensor{ctensor: *ptr}, nil
|
|
}
|
|
|
|
func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) {
|
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
defer C.free(unsafe.Pointer(ptr))
|
|
lib.AtgMatmul(ptr, ts.ctensor, other.ctensor)
|
|
|
|
if err = TorchErr(); err != nil {
|
|
return retVal, err
|
|
}
|
|
|
|
return Tensor{ctensor: *ptr}, nil
|
|
}
|
|
|
|
func (ts Tensor) MustMatMul(other Tensor) (retVal Tensor) {
|
|
retVal, err := ts.Matmul(other)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return retVal
|
|
}
|
|
|
|
func (ts Tensor) Grad() (retVal Tensor, err error) {
|
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
defer C.free(unsafe.Pointer(ptr))
|
|
lib.AtgGrad(ptr, ts.ctensor)
|
|
|
|
if err = TorchErr(); err != nil {
|
|
return retVal, err
|
|
}
|
|
|
|
return Tensor{ctensor: *ptr}, nil
|
|
}
|
|
|
|
func (ts Tensor) MustGrad() (retVal Tensor) {
|
|
retVal, err := ts.Grad()
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return retVal
|
|
}
|
|
|
|
func (ts Tensor) Detach_() (retVal Tensor, err error) {
|
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
defer C.free(unsafe.Pointer(ptr))
|
|
lib.AtgDetach_(ptr, ts.ctensor)
|
|
|
|
if err = TorchErr(); err != nil {
|
|
return retVal, err
|
|
}
|
|
|
|
return Tensor{ctensor: *ptr}, nil
|
|
}
|
|
|
|
func (ts Tensor) MustDetach_() (retVal Tensor) {
|
|
retVal, err := ts.Detach_()
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return retVal
|
|
}
|
|
|
|
func (ts Tensor) Zero_() (retVal Tensor, err error) {
|
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
defer C.free(unsafe.Pointer(ptr))
|
|
lib.AtgZero_(ptr, ts.ctensor)
|
|
|
|
if err = TorchErr(); err != nil {
|
|
return retVal, err
|
|
}
|
|
|
|
return Tensor{ctensor: *ptr}, nil
|
|
}
|
|
|
|
func (ts Tensor) MustZero_() (retVal Tensor) {
|
|
retVal, err := ts.Zero_()
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return retVal
|
|
}
|
|
|
|
func (ts Tensor) SetRequiresGrad(rb bool) (retVal Tensor, err error) {
|
|
var r int = 0
|
|
if rb {
|
|
r = 1
|
|
}
|
|
|
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
defer C.free(unsafe.Pointer(ptr))
|
|
|
|
lib.AtgSetRequiresGrad(ptr, ts.ctensor, r)
|
|
|
|
if err = TorchErr(); err != nil {
|
|
return retVal, err
|
|
}
|
|
|
|
return Tensor{ctensor: *ptr}, nil
|
|
}
|
|
|
|
func (ts Tensor) MustSetRequiresGrad(rb bool) (retVal Tensor) {
|
|
retVal, err := ts.SetRequiresGrad(rb)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return retVal
|
|
}
|
|
|
|
func (ts Tensor) Mul(other Tensor) (retVal Tensor, err error) {
|
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
defer C.free(unsafe.Pointer(ptr))
|
|
lib.AtgMul(ptr, ts.ctensor, other.ctensor)
|
|
|
|
if err = TorchErr(); err != nil {
|
|
return retVal, err
|
|
}
|
|
|
|
return Tensor{ctensor: *ptr}, nil
|
|
}
|
|
|
|
func (ts Tensor) MustMul(other Tensor) (retVal Tensor) {
|
|
retVal, err := ts.Mul(other)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return retVal
|
|
}
|
|
|
|
func (ts Tensor) Add(other Tensor) (retVal Tensor, err error) {
|
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
defer C.free(unsafe.Pointer(ptr))
|
|
lib.AtgAdd(ptr, ts.ctensor, other.ctensor)
|
|
|
|
if err = TorchErr(); err != nil {
|
|
return retVal, err
|
|
}
|
|
|
|
return Tensor{ctensor: *ptr}, nil
|
|
}
|
|
|
|
func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) {
|
|
retVal, err := ts.Add(other)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return retVal
|
|
}
|
|
|
|
func (ts Tensor) AddG(other Tensor) (err error) {
|
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
defer C.free(unsafe.Pointer(ptr))
|
|
lib.AtgAdd(ptr, ts.ctensor, other.ctensor)
|
|
|
|
if err = TorchErr(); err != nil {
|
|
return err
|
|
}
|
|
|
|
ts = Tensor{ctensor: *ptr}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ts Tensor) MustAddG(other Tensor) {
|
|
err := ts.AddG(other)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
// Totype casts type of tensor to a new tensor with specified DType
|
|
func (ts Tensor) Totype(dtype gt.DType) (retVal Tensor, err error) {
|
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
defer C.free(unsafe.Pointer(ptr))
|
|
cint, err := gt.DType2CInt(dtype)
|
|
if err != nil {
|
|
return retVal, err
|
|
}
|
|
|
|
lib.AtgTotype(ptr, ts.ctensor, cint)
|
|
if err = TorchErr(); err != nil {
|
|
return retVal, err
|
|
}
|
|
|
|
retVal = Tensor{ctensor: *ptr}
|
|
|
|
return retVal, nil
|
|
}
|
|
|
|
// Totype casts type of tensor to a new tensor with specified DType. It will
|
|
// panic if error
|
|
func (ts Tensor) MustTotype(dtype gt.DType) (retVal Tensor) {
|
|
retVal, err := ts.Totype(dtype)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return retVal
|
|
}
|