From d6346994e74a8d43c1363f651a605f2912d16875 Mon Sep 17 00:00:00 2001 From: sugarme Date: Thu, 11 Jun 2020 11:57:56 +1000 Subject: [PATCH] feat(wrapper/util): Func struct for function analysis, feat(wrapper/tensor): added more method --- example/tensor-grad/main.go | 7 ++- libtch/tensor.go | 7 +++ wrapper/tensor.go | 67 +++++++++++++++++++++++++ wrapper/util.go | 97 +++++++++++++++++++++++++++++++++++++ 4 files changed, 177 insertions(+), 1 deletion(-) diff --git a/example/tensor-grad/main.go b/example/tensor-grad/main.go index a71dd8a..17af6bc 100644 --- a/example/tensor-grad/main.go +++ b/example/tensor-grad/main.go @@ -1,7 +1,7 @@ package main import ( - // "fmt" + "fmt" // "log" wrapper "github.com/sugarme/gotch/wrapper" @@ -25,6 +25,11 @@ func main() { xgrad = x.MustGrad() xgrad.Print() // [5.0] due to accumulated 2.0 + 3.0 + isGradEnabled := wrapper.MustGradSetEnabled(false) + fmt.Printf("Previous GradMode enabled state: %v\n", isGradEnabled) + isGradEnabled = wrapper.MustGradSetEnabled(true) + fmt.Printf("Previous GradMode enabled state: %v\n", isGradEnabled) + } /* // Compute a second order derivative using run_backward. diff --git a/libtch/tensor.go b/libtch/tensor.go index c1cf34a..cc3a643 100644 --- a/libtch/tensor.go +++ b/libtch/tensor.go @@ -311,3 +311,10 @@ func AtFree(ts Ctensor) { ctensor := (C.tensor)(ts) C.at_free(ctensor) } + +//int at_grad_set_enabled(int b); +func AtGradSetEnabled(b int) int { + cbool := *(*C.int)(unsafe.Pointer(&b)) + cretVal := C.at_grad_set_enabled(cbool) + return *(*int)(unsafe.Pointer(&cretVal)) +} diff --git a/wrapper/tensor.go b/wrapper/tensor.go index f416201..c113bf0 100644 --- a/wrapper/tensor.go +++ b/wrapper/tensor.go @@ -831,3 +831,70 @@ func (ts Tensor) MustDrop() { log.Fatal(err) } } + +// GradSetEnabled sets globally whether GradMode gradient accumulation is enable or not. +// It returns PREVIOUS state of Grad before setting. +func GradSetEnabled(b bool) (retVal bool, err error) { + + var cbool, cretVal int + switch b { + case true: + cbool = 1 + case false: + cbool = 0 + } + + cretVal = lib.AtGradSetEnabled(cbool) + if err = TorchErr(); err != nil { + return retVal, err + } + + switch cretVal { + case 0: + retVal = false + break + case 1: + retVal = true + break + // case -1: // should be unreachable as error is captured above with TorchrErr() + // err = fmt.Errorf("Cannot set grad enable. \n") + // return retVal, err + // default: // should be unreachable as error is captured above with TorchrErr() + // err = fmt.Errorf("Cannot set grad enable. \n") + // return retVal, err + } + + return retVal, nil +} + +// MustGradSetEnabled sets globally whether GradMode gradient accumuation is enable or not. +// It returns PREVIOUS state of Grad before setting. It will be panic if error +func MustGradSetEnabled(b bool) (retVal bool) { + retVal, err := GradSetEnabled(b) + if err != nil { + log.Fatal(err) + } + + return retVal +} + +// NoGrad runs a closure without keeping track of gradients. +func NoGrad(fn interface{}) (retVal interface{}, err error) { + + // Switch off Grad + prev := MustGradSetEnabled(false) + + // Analyze input as function. If not, throw error + f, err := NewFunc(fn) + if err != nil { + return retVal, nil + } + + // invokes the function + retVal = f.Invoke() + + // Switch on Grad + _ = MustGradSetEnabled(prev) + + return retVal, nil +} diff --git a/wrapper/util.go b/wrapper/util.go index 44b963a..ba525d6 100644 --- a/wrapper/util.go +++ b/wrapper/util.go @@ -7,6 +7,7 @@ import ( "bytes" "encoding/binary" "fmt" + // "log" "reflect" "unsafe" @@ -364,3 +365,99 @@ func flattenData(data interface{}, round int, flat []interface{}) (f []interface return flatData, nil } + +// InvokeFn reflects and invokes a function of interface type. +func InvokeFnWithArgs(fn interface{}, args ...string) { + v := reflect.ValueOf(fn) + rargs := make([]reflect.Value, len(args)) + for i, a := range args { + rargs[i] = reflect.ValueOf(a) + } + v.Call(rargs) +} + +// Func struct contains information of a function +type FuncInfo struct { + Signature string + InArgs []reflect.Value + OutArgs []reflect.Value + IsVariadic bool +} + +type Func struct { + typ reflect.Type + val reflect.Value + meta FuncInfo +} + +func NewFunc(fn interface{}) (retVal Func, err error) { + meta, err := getFuncInfo(fn) + if err != nil { + return retVal, err + } + + retVal = Func{ + typ: reflect.TypeOf(fn), + val: reflect.ValueOf(fn), + meta: meta, + } + return retVal, nil +} + +// getFuncInfo analyzes input of interface type and returns function information +// in FuncInfo struct. It returns error if input is not a function type under +// the hood. +func getFuncInfo(fn interface{}) (retVal FuncInfo, err error) { + fnVal := reflect.ValueOf(fn) + fnTyp := reflect.TypeOf(fn) + + // First, check whether input is a function type + if fnVal.Kind() != reflect.Func { + err = fmt.Errorf("Input is not a function.") + return retVal, err + } + + // get number of input and output arguments of function + numIn := fnTyp.NumIn() // inbound parameters + numOut := fnTyp.NumOut() // outbound parameters + isVariadic := fnTyp.IsVariadic() // whether function is a variadic func + fnSig := fnTyp.String() // function signature + + // get input and ouput arguments values (reflect.Value type) + var inArgs []reflect.Value + var outArgs []reflect.Value + + for i := 0; i < numIn; i++ { + t := fnTyp.In(i) // reflect.Type + + inArgs = append(inArgs, reflect.ValueOf(t)) + } + + for i := 0; i < numOut; i++ { + t := fnTyp.Out(i) // reflect.Type + outArgs = append(outArgs, reflect.ValueOf(t)) + } + + retVal = FuncInfo{ + Signature: fnSig, + InArgs: inArgs, + OutArgs: outArgs, + IsVariadic: isVariadic, + } + + return retVal, nil +} + +// Info analyzes input of interface type and returns function information +// in FuncInfo struct. It returns error if input is not a function type under +// the hood. It will be panic if input is not a function +func (f *Func) Info() (retVal FuncInfo) { + return f.meta +} + +func (f *Func) Invoke() interface{} { + // call function with input parameters + // TODO: return vals are []reflect.Value + // How do we match them to output order of signature function + return f.val.Call(f.meta.InArgs) +}