feat(wrapper/tensor): complete draft of wrapper/tensor
This commit is contained in:
parent
d6346994e7
commit
3de38ffa27
|
@ -898,3 +898,47 @@ func NoGrad(fn interface{}) (retVal interface{}, err error) {
|
|||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// NoGradGuard is a RAII guard that prevents gradient tracking until deallocated.
|
||||
type NoGradGuard struct {
|
||||
enabled bool
|
||||
}
|
||||
|
||||
// Disables gradient tracking, this will be enabled back when the
|
||||
// returned value gets deallocated.
|
||||
func (ngg NoGradGuard) NoGradGuard() NoGradGuard {
|
||||
return NoGradGuard{enabled: MustGradSetEnabled(false)}
|
||||
}
|
||||
|
||||
// Drop drops the NoGradGuard state.
|
||||
func (ngg NoGradGuard) Drop() {
|
||||
MustGradSetEnabled(ngg.enabled)
|
||||
}
|
||||
|
||||
// Reduction type is an enum-like type
|
||||
type Reduction int
|
||||
|
||||
const (
|
||||
// Do not reduce
|
||||
ReduceNone Reduction = iota
|
||||
// Mean of losses
|
||||
ReduceMean
|
||||
// Sum of losses
|
||||
ReduceSum
|
||||
// Escape hatch in case new options become available
|
||||
Other
|
||||
)
|
||||
|
||||
func (r Reduction) ToInt() (retVal int) {
|
||||
switch r {
|
||||
case ReduceNone:
|
||||
return 0
|
||||
case ReduceMean:
|
||||
return 1
|
||||
case ReduceSum:
|
||||
return 2
|
||||
case Other:
|
||||
return 3
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user