feat(wrapper/tensor-generated-sample.go): backward and grad functions
This commit is contained in:
parent
1c1122c4ea
commit
c85aa7d6c4
39
example/tensor-grad/main.go
Normal file
39
example/tensor-grad/main.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
// "fmt"
|
||||
// "log"
|
||||
|
||||
wrapper "github.com/sugarme/gotch/wrapper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
x := wrapper.TensorFrom([]float64{2.0})
|
||||
x = x.MustSetRequiresGrad(true)
|
||||
x.ZeroGrad()
|
||||
|
||||
xy := wrapper.TensorFrom([]float64{2.0})
|
||||
xz := wrapper.TensorFrom([]float64{3.0})
|
||||
|
||||
y := x.MustMul(xy)
|
||||
z := x.MustMul(xz)
|
||||
|
||||
y.Backward()
|
||||
xgrad := x.MustGrad()
|
||||
xgrad.Print()
|
||||
z.Backward()
|
||||
xgrad = x.MustGrad()
|
||||
xgrad.Print()
|
||||
|
||||
}
|
||||
|
||||
/* // Compute a second order derivative using run_backward.
|
||||
* let mut x = Tensor::from(42.0).set_requires_grad(true);
|
||||
* let y = &x * &x * &x + &x + &x * &x;
|
||||
* x.zero_grad();
|
||||
* let dy_over_dx = Tensor::run_backward(&[y], &[&x], true, true);
|
||||
* assert_eq!(dy_over_dx.len(), 1);
|
||||
* let dy_over_dx = &dy_over_dx[0];
|
||||
* dy_over_dx.backward();
|
||||
* let dy_over_dx2 = x.grad();
|
||||
* assert_eq!(f64::from(&dy_over_dx2), 254.0); */
|
|
@ -30,3 +30,34 @@ func AtDevice(ts Ctensor) int {
|
|||
cint := C.at_device(ts)
|
||||
return *(*int)(unsafe.Pointer(&cint))
|
||||
}
|
||||
|
||||
// void atg_grad(tensor *, tensor self);
|
||||
func AtgGrad(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_grad(ptr, self)
|
||||
}
|
||||
|
||||
// void atg_detach_(tensor *, tensor self);
|
||||
func AtgDetach_(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_detach_(ptr, self)
|
||||
}
|
||||
|
||||
// void atg_zero_(tensor *, tensor self);
|
||||
func AtgZero_(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_zero_(ptr, self)
|
||||
}
|
||||
|
||||
// void atg_set_requires_grad(tensor *, tensor self, int r);
|
||||
func AtgSetRequiresGrad(ptr *Ctensor, self Ctensor, r int) {
|
||||
cr := *(*C.int)(unsafe.Pointer(&r))
|
||||
C.atg_set_requires_grad(ptr, self, cr)
|
||||
}
|
||||
|
||||
// void atg_mul(tensor *, tensor self, tensor other);
|
||||
func AtgMul(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_mul(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_add(tensor *, tensor self, tensor other);
|
||||
func AtgAdd(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_add(ptr, self, other)
|
||||
}
|
||||
|
|
|
@ -113,3 +113,41 @@ func AtRequiresGrad(ts Ctensor) bool {
|
|||
retVal := C.at_requires_grad((C.tensor)(ts))
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// int at_defined(tensor);
|
||||
func AtDefined(ts Ctensor) bool {
|
||||
retVal := C.at_defined((C.tensor)(ts))
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// int at_is_sparse(tensor);
|
||||
func AtIsSparse(ts Ctensor) bool {
|
||||
retVal := C.at_is_sparse((C.tensor)(ts))
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// void at_backward(tensor, int, int);
|
||||
func AtBackward(ts Ctensor, keepGraph int, createGraph int) {
|
||||
ctensor := (C.tensor)(ts)
|
||||
ckeepGraph := *(*C.int)(unsafe.Pointer(&keepGraph))
|
||||
ccreateGraph := *(*C.int)(unsafe.Pointer(&createGraph))
|
||||
|
||||
C.at_backward(ctensor, ckeepGraph, ccreateGraph)
|
||||
}
|
||||
|
||||
/*
|
||||
* void at_run_backward(tensor *tensors,
|
||||
* int ntensors,
|
||||
* tensor *inputs,
|
||||
* int ninputs,
|
||||
* tensor *outputs,
|
||||
* int keep_graph,
|
||||
* int create_graph);
|
||||
* */
|
||||
func AtRunBackward(tensorsPtr *Ctensor, ntensors int, inputsPtr *Ctensor, ninputs int, outputsPtr *Ctensor, keepGraph int, createGraph int) {
|
||||
cntensors := *(*C.int)(unsafe.Pointer(&ntensors))
|
||||
cninputs := *(*C.int)(unsafe.Pointer(&ninputs))
|
||||
ckeepGraph := *(*C.int)(unsafe.Pointer(&keepGraph))
|
||||
ccreateGraph := *(*C.int)(unsafe.Pointer(&createGraph))
|
||||
C.at_run_backward(tensorsPtr, cntensors, inputsPtr, cninputs, outputsPtr, ckeepGraph, ccreateGraph)
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ package wrapper
|
|||
import "C"
|
||||
|
||||
import (
|
||||
"log"
|
||||
"unsafe"
|
||||
|
||||
gt "github.com/sugarme/gotch"
|
||||
|
@ -38,3 +39,165 @@ func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) {
|
|||
|
||||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustMatMul(other Tensor) (retVal Tensor) {
|
||||
retVal, err := ts.Matmul(other)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Grad() (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgGrad(ptr, ts.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustGrad() (retVal Tensor) {
|
||||
retVal, err := ts.Grad()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Detach_() (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgDetach_(ptr, ts.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustDetach_() (retVal Tensor) {
|
||||
retVal, err := ts.Detach_()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Zero_() (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgZero_(ptr, ts.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustZero_() (retVal Tensor) {
|
||||
retVal, err := ts.Zero_()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) SetRequiresGrad(rb bool) (retVal Tensor, err error) {
|
||||
var r int = 0
|
||||
if rb {
|
||||
r = 1
|
||||
}
|
||||
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgSetRequiresGrad(ptr, ts.ctensor, r)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustSetRequiresGrad(rb bool) (retVal Tensor) {
|
||||
retVal, err := ts.SetRequiresGrad(rb)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Mul(other Tensor) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgMul(ptr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustMul(other Tensor) (retVal Tensor) {
|
||||
retVal, err := ts.Mul(other)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Add(other Tensor) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgAdd(ptr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) {
|
||||
retVal, err := ts.Add(other)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts *Tensor) AddG(other Tensor) (err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgAdd(ptr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ts = &Tensor{ctensor: *ptr}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustAddG(other Tensor) {
|
||||
err := ts.AddG(other)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -181,6 +181,14 @@ func OfSlice(data interface{}) (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func TensorFrom(data interface{}) (retVal Tensor) {
|
||||
retVal, err := OfSlice(data)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Print prints tensor values to console.
|
||||
//
|
||||
// NOTE: it is printed from C and will print ALL elements of tensor
|
||||
|
@ -340,3 +348,106 @@ func (ts Tensor) DataPtr() (retVal unsafe.Pointer, err error) {
|
|||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// Defined returns true is the tensor is defined.
|
||||
func (ts Tensor) Defined() (retVal bool, err error) {
|
||||
retVal = lib.AtDefined(ts.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustDefined() (retVal bool) {
|
||||
retVal, err := ts.Defined()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// IsSparse returns true is the tensor is spare.
|
||||
func (ts Tensor) IsSparse() (retVal bool, err error) {
|
||||
retVal = lib.AtIsSparse(ts.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// ZeroGrad zeroes the gradient tensor attached to this tensor if defined.
|
||||
func (ts Tensor) ZeroGrad() {
|
||||
grad := ts.MustGrad()
|
||||
if grad.MustDefined() {
|
||||
// TODO: can we chain them?
|
||||
// grad.MustDetach_().MustZero_()
|
||||
// https://www.calhoun.io/using-functional-options-instead-of-method-chaining-in-go/
|
||||
detach := grad.MustDetach_()
|
||||
_ = detach.MustZero_()
|
||||
}
|
||||
}
|
||||
|
||||
// Backward runs the backward pass, populating the gradient tensors for tensors
|
||||
// which gradients are tracked.
|
||||
//
|
||||
// Gradients tracking can be turned on via `SetRequiresGrad`.
|
||||
func (ts Tensor) Backward() (err error) {
|
||||
lib.AtBackward(ts.ctensor, 0, 0)
|
||||
if err = TorchErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustBackward() {
|
||||
if err := ts.Backward(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// RunBackward runs the backward ...
|
||||
func RunBackward(tensors []Tensor, inputs []Tensor, keepGraphB bool, createGraphB bool) (retVal []Tensor, err error) {
|
||||
// NOTE: outputs is a slice of tensors with length = len(inputs)
|
||||
// outputsPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
// defer C.free(unsafe.Pointer(outputsPtr))
|
||||
var outputsPtr []*lib.Ctensor
|
||||
// TODO: Are they allocated continouslly???
|
||||
for i := 0; i < len(inputs); i++ {
|
||||
outputPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
// defer C.free(unsafe.Pointer(outputPtr))
|
||||
outputsPtr = append(outputsPtr, outputPtr)
|
||||
// retVal = append(retVal, Tensor{ctensor: *outputPtr})
|
||||
}
|
||||
|
||||
// Get first element pointer
|
||||
ctensor := tensors[0].ctensor
|
||||
cinput := inputs[0].ctensor
|
||||
tensorsPtr := (*lib.Ctensor)(unsafe.Pointer(&ctensor))
|
||||
inputsPtr := (*lib.Ctensor)(unsafe.Pointer(&cinput))
|
||||
var keepGraph int = 0
|
||||
if keepGraphB {
|
||||
keepGraph = 1
|
||||
}
|
||||
var createGraph int = 0
|
||||
if createGraphB {
|
||||
createGraph = 1
|
||||
}
|
||||
|
||||
lib.AtRunBackward(tensorsPtr, len(tensors), inputsPtr, len(inputs), outputsPtr[0], keepGraph, createGraph)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
for i := 0; i < len(inputs); i++ {
|
||||
outputPtr := outputsPtr[i]
|
||||
retVal = append(retVal, Tensor{ctensor: *outputPtr})
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user