feat(wrapper/tensor): complete draft of wrapper/tensor

This commit is contained in:
sugarme 2020-06-11 14:11:58 +10:00
parent d6346994e7
commit 3de38ffa27

View File

@ -898,3 +898,47 @@ func NoGrad(fn interface{}) (retVal interface{}, err error) {
return retVal, nil
}
// NoGradGuard is a RAII guard that prevents gradient tracking until deallocated.
type NoGradGuard struct {
enabled bool
}
// Disables gradient tracking, this will be enabled back when the
// returned value gets deallocated.
func (ngg NoGradGuard) NoGradGuard() NoGradGuard {
return NoGradGuard{enabled: MustGradSetEnabled(false)}
}
// Drop drops the NoGradGuard state.
func (ngg NoGradGuard) Drop() {
MustGradSetEnabled(ngg.enabled)
}
// Reduction type is an enum-like type
type Reduction int
const (
// Do not reduce
ReduceNone Reduction = iota
// Mean of losses
ReduceMean
// Sum of losses
ReduceSum
// Escape hatch in case new options become available
Other
)
func (r Reduction) ToInt() (retVal int) {
switch r {
case ReduceNone:
return 0
case ReduceMean:
return 1
case ReduceSum:
return 2
case Other:
return 3
}
return
}