feat(wrapper/util): Func struct for function analysis, feat(wrapper/tensor): added more method
This commit is contained in:
parent
23150953d9
commit
d6346994e7
|
@ -1,7 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
// "fmt"
|
"fmt"
|
||||||
// "log"
|
// "log"
|
||||||
|
|
||||||
wrapper "github.com/sugarme/gotch/wrapper"
|
wrapper "github.com/sugarme/gotch/wrapper"
|
||||||
|
@ -25,6 +25,11 @@ func main() {
|
||||||
xgrad = x.MustGrad()
|
xgrad = x.MustGrad()
|
||||||
xgrad.Print() // [5.0] due to accumulated 2.0 + 3.0
|
xgrad.Print() // [5.0] due to accumulated 2.0 + 3.0
|
||||||
|
|
||||||
|
isGradEnabled := wrapper.MustGradSetEnabled(false)
|
||||||
|
fmt.Printf("Previous GradMode enabled state: %v\n", isGradEnabled)
|
||||||
|
isGradEnabled = wrapper.MustGradSetEnabled(true)
|
||||||
|
fmt.Printf("Previous GradMode enabled state: %v\n", isGradEnabled)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* // Compute a second order derivative using run_backward.
|
/* // Compute a second order derivative using run_backward.
|
||||||
|
|
|
@ -311,3 +311,10 @@ func AtFree(ts Ctensor) {
|
||||||
ctensor := (C.tensor)(ts)
|
ctensor := (C.tensor)(ts)
|
||||||
C.at_free(ctensor)
|
C.at_free(ctensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//int at_grad_set_enabled(int b);
|
||||||
|
func AtGradSetEnabled(b int) int {
|
||||||
|
cbool := *(*C.int)(unsafe.Pointer(&b))
|
||||||
|
cretVal := C.at_grad_set_enabled(cbool)
|
||||||
|
return *(*int)(unsafe.Pointer(&cretVal))
|
||||||
|
}
|
||||||
|
|
|
@ -831,3 +831,70 @@ func (ts Tensor) MustDrop() {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GradSetEnabled sets globally whether GradMode gradient accumulation is enable or not.
|
||||||
|
// It returns PREVIOUS state of Grad before setting.
|
||||||
|
func GradSetEnabled(b bool) (retVal bool, err error) {
|
||||||
|
|
||||||
|
var cbool, cretVal int
|
||||||
|
switch b {
|
||||||
|
case true:
|
||||||
|
cbool = 1
|
||||||
|
case false:
|
||||||
|
cbool = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
cretVal = lib.AtGradSetEnabled(cbool)
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch cretVal {
|
||||||
|
case 0:
|
||||||
|
retVal = false
|
||||||
|
break
|
||||||
|
case 1:
|
||||||
|
retVal = true
|
||||||
|
break
|
||||||
|
// case -1: // should be unreachable as error is captured above with TorchrErr()
|
||||||
|
// err = fmt.Errorf("Cannot set grad enable. \n")
|
||||||
|
// return retVal, err
|
||||||
|
// default: // should be unreachable as error is captured above with TorchrErr()
|
||||||
|
// err = fmt.Errorf("Cannot set grad enable. \n")
|
||||||
|
// return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustGradSetEnabled sets globally whether GradMode gradient accumuation is enable or not.
|
||||||
|
// It returns PREVIOUS state of Grad before setting. It will be panic if error
|
||||||
|
func MustGradSetEnabled(b bool) (retVal bool) {
|
||||||
|
retVal, err := GradSetEnabled(b)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
// NoGrad runs a closure without keeping track of gradients.
|
||||||
|
func NoGrad(fn interface{}) (retVal interface{}, err error) {
|
||||||
|
|
||||||
|
// Switch off Grad
|
||||||
|
prev := MustGradSetEnabled(false)
|
||||||
|
|
||||||
|
// Analyze input as function. If not, throw error
|
||||||
|
f, err := NewFunc(fn)
|
||||||
|
if err != nil {
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// invokes the function
|
||||||
|
retVal = f.Invoke()
|
||||||
|
|
||||||
|
// Switch on Grad
|
||||||
|
_ = MustGradSetEnabled(prev)
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
// "log"
|
||||||
"reflect"
|
"reflect"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
@ -364,3 +365,99 @@ func flattenData(data interface{}, round int, flat []interface{}) (f []interface
|
||||||
|
|
||||||
return flatData, nil
|
return flatData, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InvokeFn reflects and invokes a function of interface type.
|
||||||
|
func InvokeFnWithArgs(fn interface{}, args ...string) {
|
||||||
|
v := reflect.ValueOf(fn)
|
||||||
|
rargs := make([]reflect.Value, len(args))
|
||||||
|
for i, a := range args {
|
||||||
|
rargs[i] = reflect.ValueOf(a)
|
||||||
|
}
|
||||||
|
v.Call(rargs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Func struct contains information of a function
|
||||||
|
type FuncInfo struct {
|
||||||
|
Signature string
|
||||||
|
InArgs []reflect.Value
|
||||||
|
OutArgs []reflect.Value
|
||||||
|
IsVariadic bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type Func struct {
|
||||||
|
typ reflect.Type
|
||||||
|
val reflect.Value
|
||||||
|
meta FuncInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFunc(fn interface{}) (retVal Func, err error) {
|
||||||
|
meta, err := getFuncInfo(fn)
|
||||||
|
if err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Func{
|
||||||
|
typ: reflect.TypeOf(fn),
|
||||||
|
val: reflect.ValueOf(fn),
|
||||||
|
meta: meta,
|
||||||
|
}
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getFuncInfo analyzes input of interface type and returns function information
|
||||||
|
// in FuncInfo struct. It returns error if input is not a function type under
|
||||||
|
// the hood.
|
||||||
|
func getFuncInfo(fn interface{}) (retVal FuncInfo, err error) {
|
||||||
|
fnVal := reflect.ValueOf(fn)
|
||||||
|
fnTyp := reflect.TypeOf(fn)
|
||||||
|
|
||||||
|
// First, check whether input is a function type
|
||||||
|
if fnVal.Kind() != reflect.Func {
|
||||||
|
err = fmt.Errorf("Input is not a function.")
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// get number of input and output arguments of function
|
||||||
|
numIn := fnTyp.NumIn() // inbound parameters
|
||||||
|
numOut := fnTyp.NumOut() // outbound parameters
|
||||||
|
isVariadic := fnTyp.IsVariadic() // whether function is a variadic func
|
||||||
|
fnSig := fnTyp.String() // function signature
|
||||||
|
|
||||||
|
// get input and ouput arguments values (reflect.Value type)
|
||||||
|
var inArgs []reflect.Value
|
||||||
|
var outArgs []reflect.Value
|
||||||
|
|
||||||
|
for i := 0; i < numIn; i++ {
|
||||||
|
t := fnTyp.In(i) // reflect.Type
|
||||||
|
|
||||||
|
inArgs = append(inArgs, reflect.ValueOf(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < numOut; i++ {
|
||||||
|
t := fnTyp.Out(i) // reflect.Type
|
||||||
|
outArgs = append(outArgs, reflect.ValueOf(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = FuncInfo{
|
||||||
|
Signature: fnSig,
|
||||||
|
InArgs: inArgs,
|
||||||
|
OutArgs: outArgs,
|
||||||
|
IsVariadic: isVariadic,
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info analyzes input of interface type and returns function information
|
||||||
|
// in FuncInfo struct. It returns error if input is not a function type under
|
||||||
|
// the hood. It will be panic if input is not a function
|
||||||
|
func (f *Func) Info() (retVal FuncInfo) {
|
||||||
|
return f.meta
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Func) Invoke() interface{} {
|
||||||
|
// call function with input parameters
|
||||||
|
// TODO: return vals are []reflect.Value
|
||||||
|
// How do we match them to output order of signature function
|
||||||
|
return f.val.Call(f.meta.InArgs)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user