From 3f569fdcee17bd2b806914a22d0a45e50585d693 Mon Sep 17 00:00:00 2001
From: sugarme
Date: Fri, 19 Jun 2020 15:26:42 +1000
Subject: [PATCH] feat(tensor): added more tensor API, WIP(example/nn, mnist/nn)

---
 example/mnist/linear.go           |   4 +-
 example/nn/main.go                |  61 +++++++++
 example/sgd/main.go               |  45 +++++++
 libtch/c-generated-sample.go      |  49 ++++++++
 nn/func.go                        |  35 ++++++
 nn/optimizer_test.go              |  66 ++++++++++
 tensor/macro.go                   |  68 ++++++++++
 tensor/tensor-generated-sample.go | 201 ++++++++++++++++++++++++++++--
 tensor/tensor.go                  |  16 +--
 9 files changed, 527 insertions(+), 18 deletions(-)
 create mode 100644 example/nn/main.go
 create mode 100644 example/sgd/main.go
 create mode 100644 nn/func.go
 create mode 100644 nn/optimizer_test.go
 create mode 100644 tensor/macro.go

diff --git a/example/mnist/linear.go b/example/mnist/linear.go
index e690fdc..c818f2b 100644
--- a/example/mnist/linear.go
+++ b/example/mnist/linear.go
@@ -84,8 +84,8 @@ func runLinear() {
 		loss.MustBackward()
 
 		ts.NoGrad(func() {
-			ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
-			bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
+			ws.Add_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
+			bs.Add_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
 		})
 
 		testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
diff --git a/example/nn/main.go b/example/nn/main.go
new file mode 100644
index 0000000..b62b9ed
--- /dev/null
+++ b/example/nn/main.go
@@ -0,0 +1,61 @@
+package main
+
+import (
+	"fmt"
+	"log"
+
+	"github.com/sugarme/gotch"
+	"github.com/sugarme/gotch/nn"
+	ts "github.com/sugarme/gotch/tensor"
+)
+
+func testOptimizer() {
+
+	var data []float64
+	for i := 0; i < 15; i++ {
+		data = append(data, float64(i))
+	}
+	xs, err := ts.NewTensorFromData(data, []int64{int64(len(data)), 1})
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	ys := xs.MustMul1(ts.FloatScalar(0.42)).MustAdd1(ts.FloatScalar(1.337))
+
+	vs := nn.NewVarStore(gotch.CPU)
+
+	opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
+	if err != nil {
+		log.Fatal("Failed building SGD optimizer")
+	}
+
+	cfg := nn.LinearConfig{
+		WsInit: nn.NewConstInit(0.001),
+		BsInit: nn.NewConstInit(0.001),
+		Bias:   true,
+	}
+
+	linear := nn.NewLinear(vs.Root(), 1, 1, cfg)
+	loss := xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
+	initialLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
+	fmt.Printf("Initial Loss: %.3f\n", initialLoss)
+
+	for i := 0; i < 50; i++ {
+		loss = xs.Apply(linear)
+		// loss = linear.Forward(xs)
+		loss = loss.MustMseLoss(ys, ts.ReductionMean.ToInt())
+
+		fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}))
+
+		opt.BackwardStep(loss)
+
+		fmt.Printf("Bs: %.3f - Bs Grad: %.3f\n", linear.Bs.MustView([]int64{-1}).MustFloat64Value([]int64{0}), linear.Bs.MustGrad().MustFloat64Value([]int64{0}))
+		fmt.Printf("Ws: %.3f - Ws Grad: %.3f\n", linear.Ws.MustView([]int64{-1}).MustFloat64Value([]int64{0}), linear.Ws.MustGrad().MustFloat64Value([]int64{0}))
+
+	}
+
+}
+
+func main() {
+	testOptimizer()
+}
diff --git a/example/sgd/main.go b/example/sgd/main.go
new file mode 100644
index 0000000..b7a965f
--- /dev/null
+++ b/example/sgd/main.go
@@ -0,0 +1,45 @@
+package main
+
+import (
+	"fmt"
+	"log"
+
+	"github.com/sugarme/gotch"
+	"github.com/sugarme/gotch/nn"
+	ts "github.com/sugarme/gotch/tensor"
+)
+
+func myModule(p nn.Path, dim int64) ts.Module {
+	x1 := p.Zeros("x1", []int64{dim})
+	x2 := p.Zeros("x2", []int64{dim})
+
+	return nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
+		return xs.MustMul(x1).MustAdd(xs.MustExp().MustMul(x2))
+	})
+
+}
+
+func main() {
+
+	vs := nn.NewVarStore(gotch.CPU)
+
+	m := myModule(vs.Root(), 7)
+
+	opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	for i := 0; i < 50; i++ {
+		xs := ts.MustZeros([]int64{7}, gotch.Float.CInt(), gotch.CPU.CInt())
+		ys := ts.MustZeros([]int64{7}, gotch.Float.CInt(), gotch.CPU.CInt())
+
+		loss := m.Forward(xs).MustSub(ys).MustPow(ts.IntScalar(2)).MustSum(gotch.Float.CInt())
+
+		opt.BackwardStep(loss)
+
+		fmt.Printf("Loss: %v\n", loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}))
+
+	}
+
+}
diff --git a/libtch/c-generated-sample.go b/libtch/c-generated-sample.go
index 5ff16f1..31ad23e 100644
--- a/libtch/c-generated-sample.go
+++ b/libtch/c-generated-sample.go
@@ -77,6 +77,11 @@ func AtgAdd_(ptr *Ctensor, self Ctensor, other Ctensor) {
 	C.atg_add_(ptr, self, other)
 }
 
+// void atg_add1(tensor *, tensor self, scalar other);
+func AtgAdd1(ptr *Ctensor, self Ctensor, other Cscalar) {
+	C.atg_add1(ptr, self, other)
+}
+
 // void atg_totype(tensor *, tensor self, int scalar_type);
 func AtgTotype(ptr *Ctensor, self Ctensor, scalar_type int32) {
 	cscalar_type := *(*C.int)(unsafe.Pointer(&scalar_type))
@@ -267,3 +272,47 @@ func AtgT(ptr *Ctensor, self Ctensor) {
 func AtgT_(ptr *Ctensor, self Ctensor) {
 	C.atg_t_(ptr, self)
 }
+
+// void atg_mse_loss(tensor *, tensor self, tensor target, int64_t reduction);
+func AtgMseLoss(ptr *Ctensor, self Ctensor, target Ctensor, reduction int) {
+	creduction := *(*C.int64_t)(unsafe.Pointer(&reduction))
+
+	C.atg_mse_loss(ptr, self, target, creduction)
+}
+
+// void atg_exp(tensor *, tensor self);
+func AtgExp(ptr *Ctensor, self Ctensor) {
+	C.atg_exp(ptr, self)
+}
+
+// void atg_exp_(tensor *, tensor self);
+func AtgExp_(ptr *Ctensor, self Ctensor) {
+	C.atg_exp_(ptr, self)
+}
+
+// void atg_pow(tensor *, tensor self, scalar exponent);
+func AtgPow(ptr *Ctensor, self Ctensor, exponent Cscalar) {
+	C.atg_pow(ptr, self, exponent)
+}
+
+// void atg_sum(tensor *, tensor self, int dtype);
+func AtgSum(ptr *Ctensor, self Ctensor, dtype int32) {
+	cdtype := *(*C.int)(unsafe.Pointer(&dtype))
+
+	C.atg_sum(ptr, self, cdtype)
+}
+
+// void atg_sub(tensor *, tensor self, tensor other);
+func AtgSub(ptr *Ctensor, self Ctensor, other Ctensor) {
+	C.atg_sub(ptr, self, other)
+}
+
+// void atg_sub1(tensor *, tensor self, scalar other);
+func AtgSub1(ptr *Ctensor, self Ctensor, other Cscalar) {
+	C.atg_sub1(ptr, self, other)
+}
+
+// void atg_sub_(tensor *, tensor self, tensor other);
+func AtgSub_(ptr *Ctensor, self Ctensor, other Ctensor) {
+	C.atg_sub_(ptr, self, other)
+}
diff --git a/nn/func.go b/nn/func.go
new file mode 100644
index 0000000..27489cf
--- /dev/null
+++ b/nn/func.go
@@ -0,0 +1,35 @@
+package nn
+
+// Layers defined by closure
+
+import (
+	ts "github.com/sugarme/gotch/tensor"
+)
+
+type Func struct {
+	f func(ts.Tensor) ts.Tensor
+}
+
+func NewFunc(fn func(ts.Tensor) ts.Tensor) (retVal Func) {
+	return Func{f: fn}
+}
+
+// Implement Module interface for Func:
+// ====================================
+func (fn Func) Forward(xs ts.Tensor) (retVal ts.Tensor) {
+	return fn.f(xs)
+}
+
+type FuncT struct {
+	f func(ts.Tensor, bool) ts.Tensor
+}
+
+func NewFuncT(fn func(ts.Tensor, bool) ts.Tensor) (retVal FuncT) {
+	return FuncT{f: fn}
+}
+
+// Implement ModuleT interface for FuncT:
+// ======================================
+func (fn FuncT) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
+	return fn.f(xs, train)
+}
diff --git a/nn/optimizer_test.go b/nn/optimizer_test.go
new file mode 100644
index 0000000..ee62522
--- /dev/null
+++ b/nn/optimizer_test.go
@@ -0,0 +1,66 @@
+package nn_test
+
+import (
+	// "reflect"
+	"fmt"
+	"log"
+	"testing"
+
+	"github.com/sugarme/gotch"
+	"github.com/sugarme/gotch/nn"
+	ts "github.com/sugarme/gotch/tensor"
+)
+
+func TestOptimizer(t *testing.T) {
+
+	var data []float32
+	for i := 0; i < 15; i++ {
+		data = append(data, float32(i))
+	}
+	xs, err := ts.NewTensorFromData(data, []int64{int64(len(data)), 1})
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	ys := xs.MustMul1(ts.FloatScalar(0.42)).MustAdd1(ts.FloatScalar(1.337))
+
+	vs := nn.NewVarStore(gotch.CPU)
+
+	opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
+	if err != nil {
+		t.Errorf("Failed building SGD optimizer")
+	}
+
+	cfg := nn.LinearConfig{
+		WsInit: nn.NewConstInit(float64(0.0)),
+		BsInit: nn.NewConstInit(float64(0.0)),
+		Bias:   true,
+	}
+
+	linear := nn.NewLinear(vs.Root(), 1, 1, cfg)
+
+	loss := xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
+
+	initialLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
+
+	wantLoss := float64(1.0)
+
+	if initialLoss < wantLoss {
+		t.Errorf("Expect initial loss > %v, got %v", wantLoss, initialLoss)
+	}
+
+	for i := 0; i < 50; i++ {
+		loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
+
+		opt.BackwardStep(loss)
+		fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}))
+	}
+
+	loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
+	finalLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
+	fmt.Printf("Final loss: %v\n", finalLoss)
+
+	if finalLoss > 0.25 {
+		t.Errorf("Expect final loss < 0.25, got %v", finalLoss)
+	}
+}
diff --git a/tensor/macro.go b/tensor/macro.go
new file mode 100644
index 0000000..97461f6
--- /dev/null
+++ b/tensor/macro.go
@@ -0,0 +1,68 @@
+package tensor
+
+// TODO: implement tensor.From macro
+/*
+ * macro_rules! from_tensor {
+ *     ($typ:ident, $zero:expr, $kind:ident) => {
+ *         impl From<&Tensor> for Vec<$typ> {
+ *             fn from(tensor: &Tensor) -> Vec<$typ> {
+ *                 let numel = tensor.numel();
+ *                 let mut vec = vec![$zero; numel as usize];
+ *                 tensor.to_kind(Kind::$kind).copy_data(&mut vec, numel);
+ *                 vec
+ *             }
+ *         }
+ *
+ *         impl From<&Tensor> for Vec<Vec<$typ>> {
+ *             fn from(tensor: &Tensor) -> Vec<Vec<$typ>> {
+ *                 let first_dim = tensor.size()[0];
+ *                 (0..first_dim)
+ *                     .map(|i| Vec::<$typ>::from(tensor.get(i)))
+ *                     .collect()
+ *             }
+ *         }
+ *
+ *         impl From<&Tensor> for Vec<Vec<Vec<$typ>>> {
+ *             fn from(tensor: &Tensor) -> Vec<Vec<Vec<$typ>>> {
+ *                 let first_dim = tensor.size()[0];
+ *                 (0..first_dim)
+ *                     .map(|i| Vec::<Vec<$typ>>::from(tensor.get(i)))
+ *                     .collect()
+ *             }
+ *         }
+ *
+ *         impl From<&Tensor> for $typ {
+ *             fn from(tensor: &Tensor) -> $typ {
+ *                 let numel = tensor.numel();
+ *                 if numel != 1 {
+ *                     panic!("expected exactly one element, got {}", numel)
+ *                 }
+ *                 Vec::from(tensor)[0]
+ *             }
+ *         }
+ *
+ *         impl From<Tensor> for Vec<$typ> {
+ *             fn from(tensor: Tensor) -> Vec<$typ> {
+ *                 Vec::<$typ>::from(&tensor)
+ *             }
+ *         }
+ *
+ *         impl From<Tensor> for Vec<Vec<$typ>> {
+ *             fn from(tensor: Tensor) -> Vec<Vec<$typ>> {
+ *                 Vec::<Vec<$typ>>::from(&tensor)
+ *             }
+ *         }
+ *
+ *         impl From<Tensor> for Vec<Vec<Vec<$typ>>> {
+ *             fn from(tensor: Tensor) -> Vec<Vec<Vec<$typ>>> {
+ *                 Vec::<Vec<Vec<$typ>>>::from(&tensor)
+ *             }
+ *         }
+ *
+ *         impl From<Tensor> for $typ {
+ *             fn from(tensor: Tensor) -> $typ {
+ *                 $typ::from(&tensor)
+ *             }
+ *         }
+ *     };
+ * } */
diff --git a/tensor/tensor-generated-sample.go b/tensor/tensor-generated-sample.go
index a381739..d3a7daa 100644
--- a/tensor/tensor-generated-sample.go
+++ b/tensor/tensor-generated-sample.go
@@ -229,24 +229,36 @@ func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) {
 	return retVal
 }
 
-func (ts Tensor) Add_(other Tensor) (err error) {
+func (ts Tensor) Add_(other Tensor) {
 	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
 	defer C.free(unsafe.Pointer(ptr))
 
 	lib.AtgAdd_(ptr, ts.ctensor, other.ctensor)
-	if err = TorchErr(); err != nil {
-		return err
+	if err := TorchErr(); err != nil {
+		log.Fatal(err)
 	}
-
-	return nil
-
 }
 
-func (ts Tensor) MustAdd_(other Tensor) {
-	err := ts.Add_(other)
+func (ts Tensor) Add1(other Scalar) (retVal Tensor, err error) {
+	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	defer C.free(unsafe.Pointer(ptr))
+	lib.AtgAdd1(ptr, ts.ctensor, other.cscalar)
+
+	if err = TorchErr(); err != nil {
+		return retVal, err
+	}
+
+	return Tensor{ctensor: *ptr}, nil
+}
+
+func (ts Tensor) MustAdd1(other Scalar) (retVal Tensor) {
+	retVal, err := ts.Add1(other)
+
 	if err != nil {
 		log.Fatal(err)
 	}
+
+	return retVal
 }
 
 func (ts Tensor) AddG(other Tensor) (err error) {
@@ -809,3 +821,176 @@ func (ts Tensor) T_() {
 		log.Fatal(err)
 	}
 }
+
+func (ts Tensor) MseLoss(target Tensor, reduction int) (retVal Tensor, err error) {
+
+	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	defer C.free(unsafe.Pointer(ptr))
+
+	lib.AtgMseLoss(ptr, ts.ctensor, target.ctensor, reduction)
+	err = TorchErr()
+	if err != nil {
+		return retVal, err
+	}
+
+	retVal = Tensor{ctensor: *ptr}
+
+	return retVal, nil
+}
+
+func (ts Tensor) MustMseLoss(target Tensor, reduction int) (retVal Tensor) {
+	retVal, err := ts.MseLoss(target, reduction)
+
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
+
+func (ts Tensor) Exp() (retVal Tensor, err error) {
+	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	defer C.free(unsafe.Pointer(ptr))
+
+	lib.AtgExp(ptr, ts.ctensor)
+	err = TorchErr()
+	if err != nil {
+		return retVal, err
+	}
+
+	retVal = Tensor{ctensor: *ptr}
+
+	return retVal, nil
+}
+
+func (ts Tensor) MustExp() (retVal Tensor) {
+	retVal, err := ts.Exp()
+
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
+
+func (ts Tensor) Exp_() {
+	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	defer C.free(unsafe.Pointer(ptr))
+
+	lib.AtgExp_(ptr, ts.ctensor)
+	err := TorchErr()
+	if err != nil {
+		log.Fatal(err)
+	}
+}
+
+func (ts Tensor) Pow(exponent Scalar) (retVal Tensor, err error) {
+	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	defer C.free(unsafe.Pointer(ptr))
+
+	lib.AtgPow(ptr, ts.ctensor, exponent.cscalar)
+	err = TorchErr()
+	if err != nil {
+		return retVal, err
+	}
+
+	retVal = Tensor{ctensor: *ptr}
+
+	return retVal, nil
+}
+
+func (ts Tensor) MustPow(exponent Scalar) (retVal Tensor) {
+	retVal, err := ts.Pow(exponent)
+
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
+
+func (ts Tensor) Sum(dtype int32) (retVal Tensor, err error) {
+	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	defer C.free(unsafe.Pointer(ptr))
+
+	lib.AtgSum(ptr, ts.ctensor, dtype)
+	err = TorchErr()
+	if err != nil {
+		return retVal, err
+	}
+
+	retVal = Tensor{ctensor: *ptr}
+
+	return retVal, nil
+}
+
+func (ts Tensor) MustSum(dtype int32) (retVal Tensor) {
+	retVal, err := ts.Sum(dtype)
+
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
+
+func (ts Tensor) Sub(other Tensor) (retVal Tensor, err error) {
+	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	defer C.free(unsafe.Pointer(ptr))
+
+	lib.AtgSub(ptr, ts.ctensor, other.ctensor)
+	err = TorchErr()
+	if err != nil {
+		return retVal, err
+	}
+
+	retVal = Tensor{ctensor: *ptr}
+
+	return retVal, nil
+}
+
+func (ts Tensor) MustSub(other Tensor) (retVal Tensor) {
+	retVal, err := ts.Sub(other)
+
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
+
+func (ts Tensor) Sub1(other Scalar) (retVal Tensor, err error) {
+	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	defer C.free(unsafe.Pointer(ptr))
+
+	lib.AtgSub1(ptr, ts.ctensor, other.cscalar)
+	err = TorchErr()
+	if err != nil {
+		return retVal, err
+	}
+
+	retVal = Tensor{ctensor: *ptr}
+
+	return retVal, nil
+}
+
+func (ts Tensor) MustSub1(other Scalar) (retVal Tensor) {
+	retVal, err := ts.Sub1(other)
+
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
+
+func (ts Tensor) Sub_(other Tensor) {
+	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	defer C.free(unsafe.Pointer(ptr))
+
+	lib.AtgSub_(ptr, ts.ctensor, other.ctensor)
+	err := TorchErr()
+	if err != nil {
+		log.Fatal(err)
+	}
+}
diff --git a/tensor/tensor.go b/tensor/tensor.go
index 7e9ed4e..a64d7a9 100644
--- a/tensor/tensor.go
+++ b/tensor/tensor.go
@@ -967,24 +967,24 @@ type Reduction int
 
 const (
 	// Do not reduce
-	ReduceNone Reduction = iota
+	ReductionNone Reduction = iota
 	// Mean of losses
-	ReduceMean
+	ReductionMean
 	// Sum of losses
-	ReduceSum
+	ReductionSum
 	// Escape hatch in case new options become available
-	Other
+	ReductionOther
 )
 
 func (r Reduction) ToInt() (retVal int) {
 	switch r {
-	case ReduceNone:
+	case ReductionNone:
 		return 0
-	case ReduceMean:
+	case ReductionMean:
 		return 1
-	case ReduceSum:
+	case ReductionSum:
 		return 2
-	case Other:
+	case ReductionOther:
 		return 3
 	}
 	return
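
Below is a minimal usage sketch, not part of the patch above, showing how the bindings introduced in this commit compose. It assumes only the API added or exercised here (MustZeros, MustAdd1, MustSub, MustPow, MustSum, MustMul1, MustMseLoss, MustView, and the renamed Reduction constants); with ReductionMean, mse_loss should agree with a hand-built mean((pred - target)^2).

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	ts "github.com/sugarme/gotch/tensor"
)

func main() {
	// Two small 1-D tensors, created the same way as in example/sgd/main.go.
	pred := ts.MustZeros([]int64{4}, gotch.Float.CInt(), gotch.CPU.CInt()).MustAdd1(ts.FloatScalar(0.5))
	target := ts.MustZeros([]int64{4}, gotch.Float.CInt(), gotch.CPU.CInt())

	// MSE via the dedicated binding added in this patch (ReductionMean maps to 1).
	mse := pred.MustMseLoss(target, ts.ReductionMean.ToInt())

	// The same value composed from the element-wise ops added in this patch:
	// mean((pred - target)^2) = sum((pred - target)^2) / numel, with numel = 4 here.
	manual := pred.MustSub(target).MustPow(ts.IntScalar(2)).MustSum(gotch.Float.CInt()).MustMul1(ts.FloatScalar(1.0 / 4.0))

	// Both lines should print 0.25.
	fmt.Printf("mse loss:   %v\n", mse.MustView([]int64{-1}).MustFloat64Value([]int64{0}))
	fmt.Printf("manual mse: %v\n", manual.MustView([]int64{-1}).MustFloat64Value([]int64{0}))
}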