feat(wrapper/util): Func struct for function analysis, feat(wrapper/tensor): added more method
This commit is contained in:
parent
23150953d9
commit
d6346994e7
|
@ -1,7 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
// "fmt"
|
||||
"fmt"
|
||||
// "log"
|
||||
|
||||
wrapper "github.com/sugarme/gotch/wrapper"
|
||||
|
@ -25,6 +25,11 @@ func main() {
|
|||
xgrad = x.MustGrad()
|
||||
xgrad.Print() // [5.0] due to accumulated 2.0 + 3.0
|
||||
|
||||
isGradEnabled := wrapper.MustGradSetEnabled(false)
|
||||
fmt.Printf("Previous GradMode enabled state: %v\n", isGradEnabled)
|
||||
isGradEnabled = wrapper.MustGradSetEnabled(true)
|
||||
fmt.Printf("Previous GradMode enabled state: %v\n", isGradEnabled)
|
||||
|
||||
}
|
||||
|
||||
/* // Compute a second order derivative using run_backward.
|
||||
|
|
|
@ -311,3 +311,10 @@ func AtFree(ts Ctensor) {
|
|||
ctensor := (C.tensor)(ts)
|
||||
C.at_free(ctensor)
|
||||
}
|
||||
|
||||
//int at_grad_set_enabled(int b);
|
||||
func AtGradSetEnabled(b int) int {
|
||||
cbool := *(*C.int)(unsafe.Pointer(&b))
|
||||
cretVal := C.at_grad_set_enabled(cbool)
|
||||
return *(*int)(unsafe.Pointer(&cretVal))
|
||||
}
|
||||
|
|
|
@ -831,3 +831,70 @@ func (ts Tensor) MustDrop() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// GradSetEnabled sets globally whether GradMode gradient accumulation is enable or not.
|
||||
// It returns PREVIOUS state of Grad before setting.
|
||||
func GradSetEnabled(b bool) (retVal bool, err error) {
|
||||
|
||||
var cbool, cretVal int
|
||||
switch b {
|
||||
case true:
|
||||
cbool = 1
|
||||
case false:
|
||||
cbool = 0
|
||||
}
|
||||
|
||||
cretVal = lib.AtGradSetEnabled(cbool)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
switch cretVal {
|
||||
case 0:
|
||||
retVal = false
|
||||
break
|
||||
case 1:
|
||||
retVal = true
|
||||
break
|
||||
// case -1: // should be unreachable as error is captured above with TorchrErr()
|
||||
// err = fmt.Errorf("Cannot set grad enable. \n")
|
||||
// return retVal, err
|
||||
// default: // should be unreachable as error is captured above with TorchrErr()
|
||||
// err = fmt.Errorf("Cannot set grad enable. \n")
|
||||
// return retVal, err
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// MustGradSetEnabled sets globally whether GradMode gradient accumuation is enable or not.
|
||||
// It returns PREVIOUS state of Grad before setting. It will be panic if error
|
||||
func MustGradSetEnabled(b bool) (retVal bool) {
|
||||
retVal, err := GradSetEnabled(b)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// NoGrad runs a closure without keeping track of gradients.
|
||||
func NoGrad(fn interface{}) (retVal interface{}, err error) {
|
||||
|
||||
// Switch off Grad
|
||||
prev := MustGradSetEnabled(false)
|
||||
|
||||
// Analyze input as function. If not, throw error
|
||||
f, err := NewFunc(fn)
|
||||
if err != nil {
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// invokes the function
|
||||
retVal = f.Invoke()
|
||||
|
||||
// Switch on Grad
|
||||
_ = MustGradSetEnabled(prev)
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
// "log"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
|
||||
|
@ -364,3 +365,99 @@ func flattenData(data interface{}, round int, flat []interface{}) (f []interface
|
|||
|
||||
return flatData, nil
|
||||
}
|
||||
|
||||
// InvokeFn reflects and invokes a function of interface type.
|
||||
func InvokeFnWithArgs(fn interface{}, args ...string) {
|
||||
v := reflect.ValueOf(fn)
|
||||
rargs := make([]reflect.Value, len(args))
|
||||
for i, a := range args {
|
||||
rargs[i] = reflect.ValueOf(a)
|
||||
}
|
||||
v.Call(rargs)
|
||||
}
|
||||
|
||||
// Func struct contains information of a function
|
||||
type FuncInfo struct {
|
||||
Signature string
|
||||
InArgs []reflect.Value
|
||||
OutArgs []reflect.Value
|
||||
IsVariadic bool
|
||||
}
|
||||
|
||||
type Func struct {
|
||||
typ reflect.Type
|
||||
val reflect.Value
|
||||
meta FuncInfo
|
||||
}
|
||||
|
||||
func NewFunc(fn interface{}) (retVal Func, err error) {
|
||||
meta, err := getFuncInfo(fn)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Func{
|
||||
typ: reflect.TypeOf(fn),
|
||||
val: reflect.ValueOf(fn),
|
||||
meta: meta,
|
||||
}
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// getFuncInfo analyzes input of interface type and returns function information
|
||||
// in FuncInfo struct. It returns error if input is not a function type under
|
||||
// the hood.
|
||||
func getFuncInfo(fn interface{}) (retVal FuncInfo, err error) {
|
||||
fnVal := reflect.ValueOf(fn)
|
||||
fnTyp := reflect.TypeOf(fn)
|
||||
|
||||
// First, check whether input is a function type
|
||||
if fnVal.Kind() != reflect.Func {
|
||||
err = fmt.Errorf("Input is not a function.")
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
// get number of input and output arguments of function
|
||||
numIn := fnTyp.NumIn() // inbound parameters
|
||||
numOut := fnTyp.NumOut() // outbound parameters
|
||||
isVariadic := fnTyp.IsVariadic() // whether function is a variadic func
|
||||
fnSig := fnTyp.String() // function signature
|
||||
|
||||
// get input and ouput arguments values (reflect.Value type)
|
||||
var inArgs []reflect.Value
|
||||
var outArgs []reflect.Value
|
||||
|
||||
for i := 0; i < numIn; i++ {
|
||||
t := fnTyp.In(i) // reflect.Type
|
||||
|
||||
inArgs = append(inArgs, reflect.ValueOf(t))
|
||||
}
|
||||
|
||||
for i := 0; i < numOut; i++ {
|
||||
t := fnTyp.Out(i) // reflect.Type
|
||||
outArgs = append(outArgs, reflect.ValueOf(t))
|
||||
}
|
||||
|
||||
retVal = FuncInfo{
|
||||
Signature: fnSig,
|
||||
InArgs: inArgs,
|
||||
OutArgs: outArgs,
|
||||
IsVariadic: isVariadic,
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// Info analyzes input of interface type and returns function information
|
||||
// in FuncInfo struct. It returns error if input is not a function type under
|
||||
// the hood. It will be panic if input is not a function
|
||||
func (f *Func) Info() (retVal FuncInfo) {
|
||||
return f.meta
|
||||
}
|
||||
|
||||
func (f *Func) Invoke() interface{} {
|
||||
// call function with input parameters
|
||||
// TODO: return vals are []reflect.Value
|
||||
// How do we match them to output order of signature function
|
||||
return f.val.Call(f.meta.InArgs)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user