feat(wrapper/tensor-generated-sample.go): backward and grad functions

sugarme 2020-06-08 07:31:07 +10:00
parent 1c1122c4ea
commit c85aa7d6c4
5 changed files with 382 additions and 0 deletions


@@ -0,0 +1,39 @@
package main
import (
// "fmt"
// "log"
wrapper "github.com/sugarme/gotch/wrapper"
)
func main() {
x := wrapper.TensorFrom([]float64{2.0})
x = x.MustSetRequiresGrad(true)
x.ZeroGrad()
xy := wrapper.TensorFrom([]float64{2.0})
xz := wrapper.TensorFrom([]float64{3.0})
// y = x * 2 and z = x * 3, so dy/dx = 2 and dz/dx = 3.
y := x.MustMul(xy)
z := x.MustMul(xz)
// First backward pass: x's gradient holds dy/dx = 2.
y.MustBackward()
xgrad := x.MustGrad()
xgrad.Print()
// Second backward pass accumulates dz/dx into the same gradient: 2 + 3 = 5.
z.MustBackward()
xgrad = x.MustGrad()
xgrad.Print()
}
/* // Compute a second order derivative using run_backward.
* let mut x = Tensor::from(42.0).set_requires_grad(true);
* let y = &x * &x * &x + &x + &x * &x;
* x.zero_grad();
* let dy_over_dx = Tensor::run_backward(&[y], &[&x], true, true);
* assert_eq!(dy_over_dx.len(), 1);
* let dy_over_dx = &dy_over_dx[0];
* dy_over_dx.backward();
* let dy_over_dx2 = x.grad();
* assert_eq!(f64::from(&dy_over_dx2), 254.0); */
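The tch-rs snippet above can be reproduced with the RunBackward wrapper added later in this commit. A rough, untested Go translation (variable names and the printed value are illustrative, not part of the commit):

package main

import (
	wrapper "github.com/sugarme/gotch/wrapper"
)

func main() {
	x := wrapper.TensorFrom([]float64{42.0})
	x = x.MustSetRequiresGrad(true)
	// y = x^3 + x + x^2
	y := x.MustMul(x).MustMul(x).MustAdd(x).MustAdd(x.MustMul(x))
	x.ZeroGrad()
	// keepGraph and createGraph are both true so that the first-order
	// gradient itself stays part of a differentiable graph.
	dyOverDx, err := wrapper.RunBackward([]wrapper.Tensor{y}, []wrapper.Tensor{x}, true, true)
	if err != nil {
		panic(err)
	}
	dyOverDx[0].MustBackward()
	dyOverDx2 := x.MustGrad()
	dyOverDx2.Print() // d²y/dx² = 6x + 2 = 254 at x = 42
}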


@@ -30,3 +30,34 @@ func AtDevice(ts Ctensor) int {
cint := C.at_device(ts)
return *(*int)(unsafe.Pointer(&cint))
}
// void atg_grad(tensor *, tensor self);
func AtgGrad(ptr *Ctensor, self Ctensor) {
C.atg_grad(ptr, self)
}
// void atg_detach_(tensor *, tensor self);
func AtgDetach_(ptr *Ctensor, self Ctensor) {
C.atg_detach_(ptr, self)
}
// void atg_zero_(tensor *, tensor self);
func AtgZero_(ptr *Ctensor, self Ctensor) {
C.atg_zero_(ptr, self)
}
// void atg_set_requires_grad(tensor *, tensor self, int r);
func AtgSetRequiresGrad(ptr *Ctensor, self Ctensor, r int) {
cr := *(*C.int)(unsafe.Pointer(&r))
C.atg_set_requires_grad(ptr, self, cr)
}
// void atg_mul(tensor *, tensor self, tensor other);
func AtgMul(ptr *Ctensor, self Ctensor, other Ctensor) {
C.atg_mul(ptr, self, other)
}
// void atg_add(tensor *, tensor self, tensor other);
func AtgAdd(ptr *Ctensor, self Ctensor, other Ctensor) {
C.atg_add(ptr, self, other)
}
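Each of these bindings follows the same out-parameter convention: the caller supplies a writable slot for one C tensor handle, the atg_* call stores the result handle into it, and the caller reads it back. A minimal sketch of driving AtgMul directly, assuming self and other already hold valid Ctensor handles (the higher-level wrapper shown below does the same thing):

// Reserve space for one returned tensor handle.
ptr := (*Ctensor)(unsafe.Pointer(C.malloc(C.size_t(unsafe.Sizeof(uintptr(0))))))
defer C.free(unsafe.Pointer(ptr))
AtgMul(ptr, self, other) // C writes the product's handle into *ptr
product := *ptr          // product is the new Ctensor handle
_ = product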


@@ -113,3 +113,41 @@ func AtRequiresGrad(ts Ctensor) bool {
retVal := C.at_requires_grad((C.tensor)(ts))
return *(*bool)(unsafe.Pointer(&retVal))
}
// int at_defined(tensor);
func AtDefined(ts Ctensor) bool {
retVal := C.at_defined((C.tensor)(ts))
return *(*bool)(unsafe.Pointer(&retVal))
}
// int at_is_sparse(tensor);
func AtIsSparse(ts Ctensor) bool {
retVal := C.at_is_sparse((C.tensor)(ts))
return *(*bool)(unsafe.Pointer(&retVal))
}
// void at_backward(tensor, int, int);
func AtBackward(ts Ctensor, keepGraph int, createGraph int) {
ctensor := (C.tensor)(ts)
ckeepGraph := *(*C.int)(unsafe.Pointer(&keepGraph))
ccreateGraph := *(*C.int)(unsafe.Pointer(&createGraph))
C.at_backward(ctensor, ckeepGraph, ccreateGraph)
}
/*
* void at_run_backward(tensor *tensors,
* int ntensors,
* tensor *inputs,
* int ninputs,
* tensor *outputs,
* int keep_graph,
* int create_graph);
* */
func AtRunBackward(tensorsPtr *Ctensor, ntensors int, inputsPtr *Ctensor, ninputs int, outputsPtr *Ctensor, keepGraph int, createGraph int) {
cntensors := *(*C.int)(unsafe.Pointer(&ntensors))
cninputs := *(*C.int)(unsafe.Pointer(&ninputs))
ckeepGraph := *(*C.int)(unsafe.Pointer(&keepGraph))
ccreateGraph := *(*C.int)(unsafe.Pointer(&createGraph))
C.at_run_backward(tensorsPtr, cntensors, inputsPtr, cninputs, outputsPtr, ckeepGraph, ccreateGraph)
}
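The two int arguments of AtBackward are C-style booleans (0 or 1) for keep_graph and create_graph, and the same convention applies to AtRunBackward. A one-line sketch, assuming ts holds a valid Ctensor:

AtBackward(ts, 1, 0) // keep_graph = 1 so the graph can be backpropagated again; create_graph = 0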


@@ -5,6 +5,7 @@ package wrapper
import "C"
import (
"log"
"unsafe"
gt "github.com/sugarme/gotch"
@@ -38,3 +39,165 @@ func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) {
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) MustMatMul(other Tensor) (retVal Tensor) {
retVal, err := ts.Matmul(other)
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) Grad() (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgGrad(ptr, ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) MustGrad() (retVal Tensor) {
retVal, err := ts.Grad()
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) Detach_() (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgDetach_(ptr, ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) MustDetach_() (retVal Tensor) {
retVal, err := ts.Detach_()
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) Zero_() (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgZero_(ptr, ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) MustZero_() (retVal Tensor) {
retVal, err := ts.Zero_()
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) SetRequiresGrad(rb bool) (retVal Tensor, err error) {
var r int = 0
if rb {
r = 1
}
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgSetRequiresGrad(ptr, ts.ctensor, r)
if err = TorchErr(); err != nil {
return retVal, err
}
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) MustSetRequiresGrad(rb bool) (retVal Tensor) {
retVal, err := ts.SetRequiresGrad(rb)
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) Mul(other Tensor) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgMul(ptr, ts.ctensor, other.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) MustMul(other Tensor) (retVal Tensor) {
retVal, err := ts.Mul(other)
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) Add(other Tensor) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgAdd(ptr, ts.ctensor, other.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) {
retVal, err := ts.Add(other)
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts *Tensor) AddG(other Tensor) (err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgAdd(ptr, ts.ctensor, other.ctensor)
if err = TorchErr(); err != nil {
return err
}
// Update the receiver in place; reassigning the local pointer would not be
// visible to the caller.
ts.ctensor = *ptr
return nil
}
func (ts *Tensor) MustAddG(other Tensor) {
err := ts.AddG(other)
if err != nil {
log.Fatal(err)
}
}
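Add returns a brand-new tensor, while AddG repoints the receiver at the freshly computed sum (the addition itself still happens on the C side; it is not an in-place C add). A small illustrative sketch with hypothetical values:

a := wrapper.TensorFrom([]float64{1.0})
b := wrapper.TensorFrom([]float64{2.0})
c := a.MustAdd(b) // c wraps a new C tensor holding 3.0; a still holds 1.0
a.MustAddG(b)     // a itself now wraps a new C tensor holding 3.0
c.Print()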


@@ -181,6 +181,14 @@ func OfSlice(data interface{}) (retVal Tensor, err error) {
return retVal, nil
}
func TensorFrom(data interface{}) (retVal Tensor) {
retVal, err := OfSlice(data)
if err != nil {
log.Fatal(err)
}
return retVal
}
// Print prints the tensor's values to the console.
//
// NOTE: printing is done on the C side and outputs ALL elements of the tensor.
@@ -340,3 +348,106 @@ func (ts Tensor) DataPtr() (retVal unsafe.Pointer, err error) {
return retVal, nil
}
// Defined returns true if the tensor is defined.
func (ts Tensor) Defined() (retVal bool, err error) {
retVal = lib.AtDefined(ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
return retVal, nil
}
func (ts Tensor) MustDefined() (retVal bool) {
retVal, err := ts.Defined()
if err != nil {
log.Fatal(err)
}
return retVal
}
// IsSparse returns true if the tensor is sparse.
func (ts Tensor) IsSparse() (retVal bool, err error) {
retVal = lib.AtIsSparse(ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
return retVal, nil
}
// ZeroGrad zeroes the gradient tensor attached to this tensor if defined.
func (ts Tensor) ZeroGrad() {
grad := ts.MustGrad()
if grad.MustDefined() {
// TODO: can we chain them?
// grad.MustDetach_().MustZero_()
// https://www.calhoun.io/using-functional-options-instead-of-method-chaining-in-go/
detach := grad.MustDetach_()
_ = detach.MustZero_()
}
}
// Backward runs the backward pass, populating the gradient tensors of the
// tensors whose gradients are tracked.
//
// Gradient tracking can be turned on via `SetRequiresGrad`.
func (ts Tensor) Backward() (err error) {
lib.AtBackward(ts.ctensor, 0, 0)
if err = TorchErr(); err != nil {
return err
}
return nil
}
func (ts Tensor) MustBackward() {
if err := ts.Backward(); err != nil {
log.Fatal(err)
}
}
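// NOTE: Backward calls AtBackward with keep_graph = 0 and create_graph = 0, so
// the autograd graph is freed once the pass finishes. Calling Backward twice on
// the same output is therefore expected to fail, e.g. (hypothetical sketch):
//
//	y := x.MustMul(x)
//	y.MustBackward() // ok; the graph behind y is freed afterwards
//	// y.MustBackward() // would fail: the graph has already been freed
//
// Backward passes on different outputs that share a leaf, as in the sample
// file above, are fine.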
// RunBackward runs the backward pass over the given output tensors and returns
// the gradients with respect to the input tensors (one gradient tensor per input).
func RunBackward(tensors []Tensor, inputs []Tensor, keepGraphB bool, createGraphB bool) (retVal []Tensor, err error) {
// NOTE: the C side reads and writes contiguous arrays of tensor handles, so
// the Go handles are copied into C-allocated blocks instead of passing the
// addresses of Go locals.
elemSize := unsafe.Sizeof(uintptr(0))
tensorsPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(C.size_t(uintptr(len(tensors)) * elemSize))))
defer C.free(unsafe.Pointer(tensorsPtr))
for i, t := range tensors {
*(*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(tensorsPtr)) + uintptr(i)*elemSize)) = t.ctensor
}
inputsPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(C.size_t(uintptr(len(inputs)) * elemSize))))
defer C.free(unsafe.Pointer(inputsPtr))
for i, t := range inputs {
*(*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(inputsPtr)) + uintptr(i)*elemSize)) = t.ctensor
}
// outputs is filled by C with len(inputs) gradient tensor handles.
outputsPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(C.size_t(uintptr(len(inputs)) * elemSize))))
defer C.free(unsafe.Pointer(outputsPtr))
var keepGraph int = 0
if keepGraphB {
keepGraph = 1
}
var createGraph int = 0
if createGraphB {
createGraph = 1
}
lib.AtRunBackward(tensorsPtr, len(tensors), inputsPtr, len(inputs), outputsPtr, keepGraph, createGraph)
if err = TorchErr(); err != nil {
return retVal, err
}
for i := 0; i < len(inputs); i++ {
outputPtr := *(*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(outputsPtr)) + uintptr(i)*elemSize))
retVal = append(retVal, Tensor{ctensor: outputPtr})
}
return retVal, nil
}
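A rough caller-side sketch of the simplest case, computing dy/dx for a single input (values are illustrative). Unlike Backward, RunBackward hands the gradients back to the caller, one tensor per input, instead of accumulating them into each input's grad:

x := wrapper.TensorFrom([]float64{3.0}).MustSetRequiresGrad(true)
y := x.MustMul(x) // y = x^2
grads, err := wrapper.RunBackward([]wrapper.Tensor{y}, []wrapper.Tensor{x}, false, false)
if err != nil {
	panic(err)
}
grads[0].Print() // dy/dx = 2x = 6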