feat(wrapper/util): Func struct for function analysis, feat(wrapper/tensor): added more method

This commit is contained in:
sugarme 2020-06-11 11:57:56 +10:00
parent 23150953d9
commit d6346994e7
4 changed files with 177 additions and 1 deletions

View File

@ -1,7 +1,7 @@
package main
import (
// "fmt"
"fmt"
// "log"
wrapper "github.com/sugarme/gotch/wrapper"
@ -25,6 +25,11 @@ func main() {
xgrad = x.MustGrad()
xgrad.Print() // [5.0] due to accumulated 2.0 + 3.0
isGradEnabled := wrapper.MustGradSetEnabled(false)
fmt.Printf("Previous GradMode enabled state: %v\n", isGradEnabled)
isGradEnabled = wrapper.MustGradSetEnabled(true)
fmt.Printf("Previous GradMode enabled state: %v\n", isGradEnabled)
}
/* // Compute a second order derivative using run_backward.

View File

@ -311,3 +311,10 @@ func AtFree(ts Ctensor) {
ctensor := (C.tensor)(ts)
C.at_free(ctensor)
}
//int at_grad_set_enabled(int b);
func AtGradSetEnabled(b int) int {
cbool := *(*C.int)(unsafe.Pointer(&b))
cretVal := C.at_grad_set_enabled(cbool)
return *(*int)(unsafe.Pointer(&cretVal))
}

View File

@ -831,3 +831,70 @@ func (ts Tensor) MustDrop() {
log.Fatal(err)
}
}
// GradSetEnabled sets globally whether GradMode gradient accumulation is enable or not.
// It returns PREVIOUS state of Grad before setting.
func GradSetEnabled(b bool) (retVal bool, err error) {
var cbool, cretVal int
switch b {
case true:
cbool = 1
case false:
cbool = 0
}
cretVal = lib.AtGradSetEnabled(cbool)
if err = TorchErr(); err != nil {
return retVal, err
}
switch cretVal {
case 0:
retVal = false
break
case 1:
retVal = true
break
// case -1: // should be unreachable as error is captured above with TorchrErr()
// err = fmt.Errorf("Cannot set grad enable. \n")
// return retVal, err
// default: // should be unreachable as error is captured above with TorchrErr()
// err = fmt.Errorf("Cannot set grad enable. \n")
// return retVal, err
}
return retVal, nil
}
// MustGradSetEnabled sets globally whether GradMode gradient accumuation is enable or not.
// It returns PREVIOUS state of Grad before setting. It will be panic if error
func MustGradSetEnabled(b bool) (retVal bool) {
retVal, err := GradSetEnabled(b)
if err != nil {
log.Fatal(err)
}
return retVal
}
// NoGrad runs a closure without keeping track of gradients.
func NoGrad(fn interface{}) (retVal interface{}, err error) {
// Switch off Grad
prev := MustGradSetEnabled(false)
// Analyze input as function. If not, throw error
f, err := NewFunc(fn)
if err != nil {
return retVal, nil
}
// invokes the function
retVal = f.Invoke()
// Switch on Grad
_ = MustGradSetEnabled(prev)
return retVal, nil
}

View File

@ -7,6 +7,7 @@ import (
"bytes"
"encoding/binary"
"fmt"
// "log"
"reflect"
"unsafe"
@ -364,3 +365,99 @@ func flattenData(data interface{}, round int, flat []interface{}) (f []interface
return flatData, nil
}
// InvokeFn reflects and invokes a function of interface type.
func InvokeFnWithArgs(fn interface{}, args ...string) {
v := reflect.ValueOf(fn)
rargs := make([]reflect.Value, len(args))
for i, a := range args {
rargs[i] = reflect.ValueOf(a)
}
v.Call(rargs)
}
// Func struct contains information of a function
type FuncInfo struct {
Signature string
InArgs []reflect.Value
OutArgs []reflect.Value
IsVariadic bool
}
type Func struct {
typ reflect.Type
val reflect.Value
meta FuncInfo
}
func NewFunc(fn interface{}) (retVal Func, err error) {
meta, err := getFuncInfo(fn)
if err != nil {
return retVal, err
}
retVal = Func{
typ: reflect.TypeOf(fn),
val: reflect.ValueOf(fn),
meta: meta,
}
return retVal, nil
}
// getFuncInfo analyzes input of interface type and returns function information
// in FuncInfo struct. It returns error if input is not a function type under
// the hood.
func getFuncInfo(fn interface{}) (retVal FuncInfo, err error) {
fnVal := reflect.ValueOf(fn)
fnTyp := reflect.TypeOf(fn)
// First, check whether input is a function type
if fnVal.Kind() != reflect.Func {
err = fmt.Errorf("Input is not a function.")
return retVal, err
}
// get number of input and output arguments of function
numIn := fnTyp.NumIn() // inbound parameters
numOut := fnTyp.NumOut() // outbound parameters
isVariadic := fnTyp.IsVariadic() // whether function is a variadic func
fnSig := fnTyp.String() // function signature
// get input and ouput arguments values (reflect.Value type)
var inArgs []reflect.Value
var outArgs []reflect.Value
for i := 0; i < numIn; i++ {
t := fnTyp.In(i) // reflect.Type
inArgs = append(inArgs, reflect.ValueOf(t))
}
for i := 0; i < numOut; i++ {
t := fnTyp.Out(i) // reflect.Type
outArgs = append(outArgs, reflect.ValueOf(t))
}
retVal = FuncInfo{
Signature: fnSig,
InArgs: inArgs,
OutArgs: outArgs,
IsVariadic: isVariadic,
}
return retVal, nil
}
// Info analyzes input of interface type and returns function information
// in FuncInfo struct. It returns error if input is not a function type under
// the hood. It will be panic if input is not a function
func (f *Func) Info() (retVal FuncInfo) {
return f.meta
}
func (f *Func) Invoke() interface{} {
// call function with input parameters
// TODO: return vals are []reflect.Value
// How do we match them to output order of signature function
return f.val.Call(f.meta.InArgs)
}