feat(tensor): added more tensor API, WIP(example/nn, mnist/nn)
This commit is contained in:
parent
a636372144
commit
3f569fdcee
|
@ -84,8 +84,8 @@ func runLinear() {
|
|||
loss.MustBackward()
|
||||
|
||||
ts.NoGrad(func() {
|
||||
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
ws.Add_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
bs.Add_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
})
|
||||
|
||||
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
|
||||
|
|
61
example/nn/main.go
Normal file
61
example/nn/main.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func testOptimizer() {
|
||||
|
||||
var data []float64
|
||||
for i := 0; i < 15; i++ {
|
||||
data = append(data, float64(i))
|
||||
}
|
||||
xs, err := ts.NewTensorFromData(data, []int64{int64(len(data)), 1})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ys := xs.MustMul1(ts.FloatScalar(0.42)).MustAdd1(ts.FloatScalar(1.337))
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
|
||||
opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
|
||||
if err != nil {
|
||||
log.Fatal("Failed building SGD optimizer")
|
||||
}
|
||||
|
||||
cfg := nn.LinearConfig{
|
||||
WsInit: nn.NewConstInit(0.001),
|
||||
BsInit: nn.NewConstInit(0.001),
|
||||
Bias: true,
|
||||
}
|
||||
|
||||
linear := nn.NewLinear(vs.Root(), 1, 1, cfg)
|
||||
loss := xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
|
||||
initialLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
fmt.Printf("Initial Loss: %.3f\n", initialLoss)
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
loss = xs.Apply(linear)
|
||||
// loss = linear.Forward(xs)
|
||||
loss = loss.MustMseLoss(ys, ts.ReductionMean.ToInt())
|
||||
|
||||
fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}))
|
||||
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
fmt.Printf("Bs: %.3f - Bs Grad: %.3f\n", linear.Bs.MustView([]int64{-1}).MustFloat64Value([]int64{0}), linear.Bs.MustGrad().MustFloat64Value([]int64{0}))
|
||||
fmt.Printf("Ws: %.3f - Ws Grad: %.3f\n", linear.Ws.MustView([]int64{-1}).MustFloat64Value([]int64{0}), linear.Ws.MustGrad().MustFloat64Value([]int64{0}))
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func main() {
|
||||
testOptimizer()
|
||||
}
|
45
example/sgd/main.go
Normal file
45
example/sgd/main.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func myModule(p nn.Path, dim int64) ts.Module {
|
||||
x1 := p.Zeros("x1", []int64{dim})
|
||||
x2 := p.Zeros("x1", []int64{dim})
|
||||
|
||||
return nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||
return xs.MustMul(x1).MustAdd(xs.MustExp().MustMul(x2))
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
|
||||
m := myModule(vs.Root(), 7)
|
||||
|
||||
opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
xs := ts.MustZeros([]int64{7}, gotch.Float.CInt(), gotch.CPU.CInt())
|
||||
ys := ts.MustZeros([]int64{7}, gotch.Float.CInt(), gotch.CPU.CInt())
|
||||
|
||||
loss := m.Forward(xs).MustSub(ys).MustPow(ts.IntScalar(2)).MustSum(gotch.Float.CInt())
|
||||
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
fmt.Printf("Loss: %v\n", loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}))
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -77,6 +77,11 @@ func AtgAdd_(ptr *Ctensor, self Ctensor, other Ctensor) {
|
|||
C.atg_add_(ptr, self, other)
|
||||
}
|
||||
|
||||
// id atg_add1(tensor *, tensor self, scalar other);
|
||||
func AtgAdd1(ptr *Ctensor, self Ctensor, other Cscalar) {
|
||||
C.atg_add1(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_totype(tensor *, tensor self, int scalar_type);
|
||||
func AtgTotype(ptr *Ctensor, self Ctensor, scalar_type int32) {
|
||||
cscalar_type := *(*C.int)(unsafe.Pointer(&scalar_type))
|
||||
|
@ -267,3 +272,47 @@ func AtgT(ptr *Ctensor, self Ctensor) {
|
|||
func AtgT_(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_t_(ptr, self)
|
||||
}
|
||||
|
||||
// void atg_mse_loss(tensor *, tensor self, tensor target, int64_t reduction);
|
||||
func AtgMseLoss(ptr *Ctensor, self Ctensor, target Ctensor, reduction int) {
|
||||
creduction := *(*C.int64_t)(unsafe.Pointer(&reduction))
|
||||
|
||||
C.atg_mse_loss(ptr, self, target, creduction)
|
||||
}
|
||||
|
||||
// void atg_exp(tensor *, tensor self);
|
||||
func AtgExp(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_exp(ptr, self)
|
||||
}
|
||||
|
||||
// void atg_exp_(tensor *, tensor self);
|
||||
func AtgExp_(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_exp_(ptr, self)
|
||||
}
|
||||
|
||||
// void atg_pow(tensor *, tensor self, scalar exponent);
|
||||
func AtgPow(ptr *Ctensor, self Ctensor, exponent Cscalar) {
|
||||
C.atg_pow(ptr, self, exponent)
|
||||
}
|
||||
|
||||
// void atg_sum(tensor *, tensor self, int dtype);
|
||||
func AtgSum(ptr *Ctensor, self Ctensor, dtype int32) {
|
||||
cdtype := *(*C.int)(unsafe.Pointer(&dtype))
|
||||
|
||||
C.atg_sum(ptr, self, cdtype)
|
||||
}
|
||||
|
||||
// void atg_sub(tensor *, tensor self, tensor other);
|
||||
func AtgSub(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_sub(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_sub1(tensor *, tensor self, scalar other);
|
||||
func AtgSub1(ptr *Ctensor, self Ctensor, other Cscalar) {
|
||||
C.atg_sub1(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_sub_(tensor *, tensor self, tensor other);
|
||||
func AtgSub_(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_sub_(ptr, self, other)
|
||||
}
|
||||
|
|
35
nn/func.go
Normal file
35
nn/func.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package nn
|
||||
|
||||
// Layers defined by closure
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
type Func struct {
|
||||
f func(ts.Tensor) ts.Tensor
|
||||
}
|
||||
|
||||
func NewFunc(fn func(ts.Tensor) ts.Tensor) (retVal Func) {
|
||||
return Func{f: fn}
|
||||
}
|
||||
|
||||
// Implement Module interface for Func:
|
||||
// ====================================
|
||||
func (fn Func) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||
return fn.f(xs)
|
||||
}
|
||||
|
||||
type FuncT struct {
|
||||
f func(ts.Tensor, bool) ts.Tensor
|
||||
}
|
||||
|
||||
func NewFuncT(fn func(ts.Tensor, bool) ts.Tensor) (retVal FuncT) {
|
||||
return FuncT{f: fn}
|
||||
}
|
||||
|
||||
// Implement Module interface for Func:
|
||||
// ====================================
|
||||
func (fn FuncT) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||
return fn.f(xs, train)
|
||||
}
|
66
nn/optimizer_test.go
Normal file
66
nn/optimizer_test.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package nn_test
|
||||
|
||||
import (
|
||||
// "reflect"
|
||||
"fmt"
|
||||
"log"
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func TestOptimizer(t *testing.T) {
|
||||
|
||||
var data []float32
|
||||
for i := 0; i < 15; i++ {
|
||||
data = append(data, float32(i))
|
||||
}
|
||||
xs, err := ts.NewTensorFromData(data, []int64{int64(len(data)), 1})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ys := xs.MustMul1(ts.FloatScalar(0.42)).MustAdd1(ts.FloatScalar(1.337))
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
|
||||
opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
|
||||
if err != nil {
|
||||
t.Errorf("Failed building SGD optimizer")
|
||||
}
|
||||
|
||||
cfg := nn.LinearConfig{
|
||||
WsInit: nn.NewConstInit(float64(0.0)),
|
||||
BsInit: nn.NewConstInit(float64(0.0)),
|
||||
Bias: true,
|
||||
}
|
||||
|
||||
linear := nn.NewLinear(vs.Root(), 1, 1, cfg)
|
||||
|
||||
loss := xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
|
||||
|
||||
initialLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
|
||||
wantLoss := float64(1.0)
|
||||
|
||||
if initialLoss < wantLoss {
|
||||
t.Errorf("Expect initial loss > %v, got %v", wantLoss, initialLoss)
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
|
||||
|
||||
opt.BackwardStep(loss)
|
||||
fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}))
|
||||
}
|
||||
|
||||
loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
|
||||
finalLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
fmt.Printf("Final loss: %v\n", finalLoss)
|
||||
|
||||
if finalLoss > 0.25 {
|
||||
t.Errorf("Expect initial loss < 0.25, got %v", finalLoss)
|
||||
}
|
||||
}
|
68
tensor/macro.go
Normal file
68
tensor/macro.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package tensor
|
||||
|
||||
// TODO: implement tensor.From macro
|
||||
/*
|
||||
* macro_rules! from_tensor {
|
||||
* ($typ:ident, $zero:expr, $kind:ident) => {
|
||||
* impl From<&Tensor> for Vec<$typ> {
|
||||
* fn from(tensor: &Tensor) -> Vec<$typ> {
|
||||
* let numel = tensor.numel();
|
||||
* let mut vec = vec![$zero; numel as usize];
|
||||
* tensor.to_kind(Kind::$kind).copy_data(&mut vec, numel);
|
||||
* vec
|
||||
* }
|
||||
* }
|
||||
*
|
||||
* impl From<&Tensor> for Vec<Vec<$typ>> {
|
||||
* fn from(tensor: &Tensor) -> Vec<Vec<$typ>> {
|
||||
* let first_dim = tensor.size()[0];
|
||||
* (0..first_dim)
|
||||
* .map(|i| Vec::<$typ>::from(tensor.get(i)))
|
||||
* .collect()
|
||||
* }
|
||||
* }
|
||||
*
|
||||
* impl From<&Tensor> for Vec<Vec<Vec<$typ>>> {
|
||||
* fn from(tensor: &Tensor) -> Vec<Vec<Vec<$typ>>> {
|
||||
* let first_dim = tensor.size()[0];
|
||||
* (0..first_dim)
|
||||
* .map(|i| Vec::<Vec<$typ>>::from(tensor.get(i)))
|
||||
* .collect()
|
||||
* }
|
||||
* }
|
||||
*
|
||||
* impl From<&Tensor> for $typ {
|
||||
* fn from(tensor: &Tensor) -> $typ {
|
||||
* let numel = tensor.numel();
|
||||
* if numel != 1 {
|
||||
* panic!("expected exactly one element, got {}", numel)
|
||||
* }
|
||||
* Vec::from(tensor)[0]
|
||||
* }
|
||||
* }
|
||||
*
|
||||
* impl From<Tensor> for Vec<$typ> {
|
||||
* fn from(tensor: Tensor) -> Vec<$typ> {
|
||||
* Vec::<$typ>::from(&tensor)
|
||||
* }
|
||||
* }
|
||||
*
|
||||
* impl From<Tensor> for Vec<Vec<$typ>> {
|
||||
* fn from(tensor: Tensor) -> Vec<Vec<$typ>> {
|
||||
* Vec::<Vec<$typ>>::from(&tensor)
|
||||
* }
|
||||
* }
|
||||
*
|
||||
* impl From<Tensor> for Vec<Vec<Vec<$typ>>> {
|
||||
* fn from(tensor: Tensor) -> Vec<Vec<Vec<$typ>>> {
|
||||
* Vec::<Vec<Vec<$typ>>>::from(&tensor)
|
||||
* }
|
||||
* }
|
||||
*
|
||||
* impl From<Tensor> for $typ {
|
||||
* fn from(tensor: Tensor) -> $typ {
|
||||
* $typ::from(&tensor)
|
||||
* }
|
||||
* }
|
||||
* };
|
||||
* } */
|
|
@ -229,24 +229,36 @@ func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Add_(other Tensor) (err error) {
|
||||
func (ts Tensor) Add_(other Tensor) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgAdd_(ptr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return err
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (ts Tensor) MustAdd_(other Tensor) {
|
||||
err := ts.Add_(other)
|
||||
func (ts Tensor) Add1(other Scalar) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgAdd1(ptr, ts.ctensor, other.cscalar)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustAdd1(other Scalar) (retVal Tensor) {
|
||||
retVal, err := ts.Add1(other)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) AddG(other Tensor) (err error) {
|
||||
|
@ -809,3 +821,176 @@ func (ts Tensor) T_() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) MseLoss(target Tensor, reduction int) (retVal Tensor, err error) {
|
||||
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgMseLoss(ptr, ts.ctensor, target.ctensor, reduction)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustMseLoss(target Tensor, reduction int) (retVal Tensor) {
|
||||
retVal, err := ts.MseLoss(target, reduction)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Exp() (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgExp(ptr, ts.ctensor)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustExp() (retVal Tensor) {
|
||||
retVal, err := ts.Exp()
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Exp_() {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgExp(ptr, ts.ctensor)
|
||||
err := TorchErr()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) Pow(exponent Scalar) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgPow(ptr, ts.ctensor, exponent.cscalar)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustPow(exponent Scalar) (retVal Tensor) {
|
||||
retVal, err := ts.Pow(exponent)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Sum(dtype int32) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgSum(ptr, ts.ctensor, dtype)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustSum(dtype int32) (retVal Tensor) {
|
||||
retVal, err := ts.Sum(dtype)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Sub(other Tensor) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgSub(ptr, ts.ctensor, other.ctensor)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustSub(other Tensor) (retVal Tensor) {
|
||||
retVal, err := ts.Sub(other)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Sub1(other Scalar) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgSub1(ptr, ts.ctensor, other.cscalar)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustSub1(other Scalar) (retVal Tensor) {
|
||||
retVal, err := ts.Sub1(other)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Sub_(other Tensor) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgSub_(ptr, ts.ctensor, other.ctensor)
|
||||
err := TorchErr()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -967,24 +967,24 @@ type Reduction int
|
|||
|
||||
const (
|
||||
// Do not reduce
|
||||
ReduceNone Reduction = iota
|
||||
ReductionNone Reduction = iota
|
||||
// Mean of losses
|
||||
ReduceMean
|
||||
ReductionMean
|
||||
// Sum of losses
|
||||
ReduceSum
|
||||
ReductionSum
|
||||
// Escape hatch in case new options become available
|
||||
Other
|
||||
ReductionOther
|
||||
)
|
||||
|
||||
func (r Reduction) ToInt() (retVal int) {
|
||||
switch r {
|
||||
case ReduceNone:
|
||||
case ReductionNone:
|
||||
return 0
|
||||
case ReduceMean:
|
||||
case ReductionMean:
|
||||
return 1
|
||||
case ReduceSum:
|
||||
case ReductionSum:
|
||||
return 2
|
||||
case Other:
|
||||
case ReductionOther:
|
||||
return 3
|
||||
}
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue
Block a user