feat(tensor): add more tensor APIs; WIP(example/nn, mnist/nn)

sugarme 2020-06-19 15:26:42 +10:00
parent a636372144
commit 3f569fdcee
9 changed files with 527 additions and 18 deletions


@@ -84,8 +84,8 @@ func runLinear() {
 		loss.MustBackward()
 
 		ts.NoGrad(func() {
-			ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
-			bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
+			ws.Add_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
+			bs.Add_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
 		})
 
 		testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
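With this change, Add_ checks the Torch error itself and exits via log.Fatal on failure, so callers no longer need the removed MustAdd_ wrapper. A minimal sketch of the bare gradient step this hunk performs, assuming w is a trainable tensor whose gradient has been populated by a backward pass (w and lr are illustrative names):

	// one manual SGD step, w -= lr * grad(w), with gradient tracking off
	ts.NoGrad(func() {
		w.Add_(w.MustGrad().MustMul1(ts.FloatScalar(-lr)))
	})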

example/nn/main.go (new file, 61 lines)

@@ -0,0 +1,61 @@
package main

import (
	"fmt"
	"log"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
)

func testOptimizer() {
	var data []float64
	for i := 0; i < 15; i++ {
		data = append(data, float64(i))
	}

	xs, err := ts.NewTensorFromData(data, []int64{int64(len(data)), 1})
	if err != nil {
		log.Fatal(err)
	}

	// ys = 0.42*xs + 1.337, the line the optimizer should recover
	ys := xs.MustMul1(ts.FloatScalar(0.42)).MustAdd1(ts.FloatScalar(1.337))

	vs := nn.NewVarStore(gotch.CPU)
	opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
	if err != nil {
		log.Fatal("Failed building SGD optimizer")
	}

	cfg := nn.LinearConfig{
		WsInit: nn.NewConstInit(0.001),
		BsInit: nn.NewConstInit(0.001),
		Bias:   true,
	}
	linear := nn.NewLinear(vs.Root(), 1, 1, cfg)

	loss := xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
	initialLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
	fmt.Printf("Initial Loss: %.3f\n", initialLoss)

	for i := 0; i < 50; i++ {
		loss = xs.Apply(linear)
		// loss = linear.Forward(xs)
		loss = loss.MustMseLoss(ys, ts.ReductionMean.ToInt())
		fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}))

		opt.BackwardStep(loss)

		fmt.Printf("Bs: %.3f - Bs Grad: %.3f\n", linear.Bs.MustView([]int64{-1}).MustFloat64Value([]int64{0}), linear.Bs.MustGrad().MustFloat64Value([]int64{0}))
		fmt.Printf("Ws: %.3f - Ws Grad: %.3f\n", linear.Ws.MustView([]int64{-1}).MustFloat64Value([]int64{0}), linear.Ws.MustGrad().MustFloat64Value([]int64{0}))
	}
}

func main() {
	testOptimizer()
}
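Because ys is synthesized as 0.42*xs + 1.337 and the model is a 1-in/1-out linear layer, SGD on the MSE loss should drive Ws toward 0.42 and Bs toward 1.337. A sketch of a final check that could be appended after the loop, using only calls already shown above:

	// inspect the fitted line against the generating one
	finalWs := linear.Ws.MustView([]int64{-1}).MustFloat64Value([]int64{0})
	finalBs := linear.Bs.MustView([]int64{-1}).MustFloat64Value([]int64{0})
	fmt.Printf("fitted: y = %.3f*x + %.3f (target: y = 0.42*x + 1.337)\n", finalWs, finalBs)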

example/sgd/main.go (new file, 45 lines)

@@ -0,0 +1,45 @@
package main

import (
	"fmt"
	"log"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
)

func myModule(p nn.Path, dim int64) ts.Module {
	x1 := p.Zeros("x1", []int64{dim})
	x2 := p.Zeros("x2", []int64{dim})

	return nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
		return xs.MustMul(x1).MustAdd(xs.MustExp().MustMul(x2))
	})
}

func main() {
	vs := nn.NewVarStore(gotch.CPU)
	m := myModule(vs.Root(), 7)

	opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
	if err != nil {
		log.Fatal(err)
	}

	for i := 0; i < 50; i++ {
		// Dummy mini-batches made of zeros.
		xs := ts.MustZeros([]int64{7}, gotch.Float.CInt(), gotch.CPU.CInt())
		ys := ts.MustZeros([]int64{7}, gotch.Float.CInt(), gotch.CPU.CInt())

		loss := m.Forward(xs).MustSub(ys).MustPow(ts.IntScalar(2)).MustSum(gotch.Float.CInt())
		opt.BackwardStep(loss)

		fmt.Printf("Loss: %v\n", loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}))
	}
}
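Since the dummy batches are all zeros, the forward pass reduces to x2 (exp(0) = 1) and the loss and every gradient stay at zero; the example exercises the VarStore/optimizer plumbing rather than actual fitting. A sketch of a batch that does move the parameters, reusing only constructors from this commit (the data values are illustrative):

	// non-trivial batch: fit y = 2x with the same module and optimizer
	in := []float32{0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5}
	xs, err := ts.NewTensorFromData(in, []int64{7})
	if err != nil {
		log.Fatal(err)
	}
	ys := xs.MustMul1(ts.FloatScalar(2.0))
	loss := m.Forward(xs).MustSub(ys).MustPow(ts.IntScalar(2)).MustSum(gotch.Float.CInt())
	opt.BackwardStep(loss)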


@@ -77,6 +77,11 @@ func AtgAdd_(ptr *Ctensor, self Ctensor, other Ctensor) {
	C.atg_add_(ptr, self, other)
}

// void atg_add1(tensor *, tensor self, scalar other);
func AtgAdd1(ptr *Ctensor, self Ctensor, other Cscalar) {
	C.atg_add1(ptr, self, other)
}

// void atg_totype(tensor *, tensor self, int scalar_type);
func AtgTotype(ptr *Ctensor, self Ctensor, scalar_type int32) {
	cscalar_type := *(*C.int)(unsafe.Pointer(&scalar_type))

@@ -267,3 +272,47 @@ func AtgT(ptr *Ctensor, self Ctensor) {

func AtgT_(ptr *Ctensor, self Ctensor) {
	C.atg_t_(ptr, self)
}

// void atg_mse_loss(tensor *, tensor self, tensor target, int64_t reduction);
func AtgMseLoss(ptr *Ctensor, self Ctensor, target Ctensor, reduction int) {
	creduction := *(*C.int64_t)(unsafe.Pointer(&reduction))
	C.atg_mse_loss(ptr, self, target, creduction)
}

// void atg_exp(tensor *, tensor self);
func AtgExp(ptr *Ctensor, self Ctensor) {
	C.atg_exp(ptr, self)
}

// void atg_exp_(tensor *, tensor self);
func AtgExp_(ptr *Ctensor, self Ctensor) {
	C.atg_exp_(ptr, self)
}

// void atg_pow(tensor *, tensor self, scalar exponent);
func AtgPow(ptr *Ctensor, self Ctensor, exponent Cscalar) {
	C.atg_pow(ptr, self, exponent)
}

// void atg_sum(tensor *, tensor self, int dtype);
func AtgSum(ptr *Ctensor, self Ctensor, dtype int32) {
	cdtype := *(*C.int)(unsafe.Pointer(&dtype))
	C.atg_sum(ptr, self, cdtype)
}

// void atg_sub(tensor *, tensor self, tensor other);
func AtgSub(ptr *Ctensor, self Ctensor, other Ctensor) {
	C.atg_sub(ptr, self, other)
}

// void atg_sub1(tensor *, tensor self, scalar other);
func AtgSub1(ptr *Ctensor, self Ctensor, other Cscalar) {
	C.atg_sub1(ptr, self, other)
}

// void atg_sub_(tensor *, tensor self, tensor other);
func AtgSub_(ptr *Ctensor, self Ctensor, other Ctensor) {
	C.atg_sub_(ptr, self, other)
}
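All of these wrappers follow the same recipe: the first argument is an out-pointer that libtorch fills with the handle of the result tensor, and scalar/int arguments are reinterpreted to the exact C width the generated API expects. As a sketch, one more unary op would be added like this (atg_abs is part of the same generated C API, though this binding is not in this commit):

	// void atg_abs(tensor *, tensor self);
	func AtgAbs(ptr *Ctensor, self Ctensor) {
		C.atg_abs(ptr, self)
	}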

nn/func.go (new file, 35 lines)

@@ -0,0 +1,35 @@
package nn

// Layers defined by closures

import (
	ts "github.com/sugarme/gotch/tensor"
)

type Func struct {
	f func(ts.Tensor) ts.Tensor
}

func NewFunc(fn func(ts.Tensor) ts.Tensor) (retVal Func) {
	return Func{f: fn}
}

// Implement Module interface for Func:
// ====================================

func (fn Func) Forward(xs ts.Tensor) (retVal ts.Tensor) {
	return fn.f(xs)
}

type FuncT struct {
	f func(ts.Tensor, bool) ts.Tensor
}

func NewFuncT(fn func(ts.Tensor, bool) ts.Tensor) (retVal FuncT) {
	return FuncT{f: fn}
}

// Implement ModuleT interface for FuncT:
// ======================================

func (fn FuncT) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
	return fn.f(xs, train)
}
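FuncT differs from Func only by threading a train flag through the closure, which is what layers like dropout or batch norm need. A usage sketch built solely on ops from this commit (toy behavior, not a real layer; input is any ts.Tensor):

	// a train-aware closure layer: exponentiate during training, pass through at eval
	layer := nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
		if train {
			return xs.MustExp()
		}
		return xs
	})
	out := layer.ForwardT(input, true)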

nn/optimizer_test.go (new file, 66 lines)

@@ -0,0 +1,66 @@
package nn_test

import (
	"fmt"
	"testing"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
)

func TestOptimizer(t *testing.T) {
	var data []float32
	for i := 0; i < 15; i++ {
		data = append(data, float32(i))
	}

	xs, err := ts.NewTensorFromData(data, []int64{int64(len(data)), 1})
	if err != nil {
		t.Fatal(err)
	}

	ys := xs.MustMul1(ts.FloatScalar(0.42)).MustAdd1(ts.FloatScalar(1.337))

	vs := nn.NewVarStore(gotch.CPU)
	opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
	if err != nil {
		t.Fatalf("Failed building SGD optimizer: %v", err)
	}

	cfg := nn.LinearConfig{
		WsInit: nn.NewConstInit(0.0),
		BsInit: nn.NewConstInit(0.0),
		Bias:   true,
	}
	linear := nn.NewLinear(vs.Root(), 1, 1, cfg)

	loss := xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
	initialLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})

	wantLoss := float64(1.0)
	if initialLoss < wantLoss {
		t.Errorf("Expect initial loss > %v, got %v", wantLoss, initialLoss)
	}

	for i := 0; i < 50; i++ {
		loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
		opt.BackwardStep(loss)
		fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}))
	}

	loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
	finalLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
	fmt.Printf("Final loss: %v\n", finalLoss)

	if finalLoss > 0.25 {
		t.Errorf("Expect final loss < 0.25, got %v", finalLoss)
	}
}
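The thresholds are loose on purpose: with both parameters initialized to zero, the initial MSE against ys = 0.42*xs + 1.337 for xs = 0..14 is well above 1, and the test expects 50 SGD steps at lr = 1e-2 to bring it under 0.25. The test runs with the standard tooling:

	go test ./nn/ -run TestOptimizer -v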

tensor/macro.go (new file, 68 lines)

@@ -0,0 +1,68 @@
package tensor
// TODO: implement tensor.From macro
/*
* macro_rules! from_tensor {
* ($typ:ident, $zero:expr, $kind:ident) => {
* impl From<&Tensor> for Vec<$typ> {
* fn from(tensor: &Tensor) -> Vec<$typ> {
* let numel = tensor.numel();
* let mut vec = vec![$zero; numel as usize];
* tensor.to_kind(Kind::$kind).copy_data(&mut vec, numel);
* vec
* }
* }
*
* impl From<&Tensor> for Vec<Vec<$typ>> {
* fn from(tensor: &Tensor) -> Vec<Vec<$typ>> {
* let first_dim = tensor.size()[0];
* (0..first_dim)
* .map(|i| Vec::<$typ>::from(tensor.get(i)))
* .collect()
* }
* }
*
* impl From<&Tensor> for Vec<Vec<Vec<$typ>>> {
* fn from(tensor: &Tensor) -> Vec<Vec<Vec<$typ>>> {
* let first_dim = tensor.size()[0];
* (0..first_dim)
* .map(|i| Vec::<Vec<$typ>>::from(tensor.get(i)))
* .collect()
* }
* }
*
* impl From<&Tensor> for $typ {
* fn from(tensor: &Tensor) -> $typ {
* let numel = tensor.numel();
* if numel != 1 {
* panic!("expected exactly one element, got {}", numel)
* }
* Vec::from(tensor)[0]
* }
* }
*
* impl From<Tensor> for Vec<$typ> {
* fn from(tensor: Tensor) -> Vec<$typ> {
* Vec::<$typ>::from(&tensor)
* }
* }
*
* impl From<Tensor> for Vec<Vec<$typ>> {
* fn from(tensor: Tensor) -> Vec<Vec<$typ>> {
* Vec::<Vec<$typ>>::from(&tensor)
* }
* }
*
* impl From<Tensor> for Vec<Vec<Vec<$typ>>> {
* fn from(tensor: Tensor) -> Vec<Vec<Vec<$typ>>> {
* Vec::<Vec<Vec<$typ>>>::from(&tensor)
* }
* }
*
* impl From<Tensor> for $typ {
* fn from(tensor: Tensor) -> $typ {
* $typ::from(&tensor)
* }
* }
* };
* }
*/
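The commented macro generates From<&Tensor> conversions for each scalar type and its nested Vecs. A hedged sketch of one Go shape the TODO could take (every name below is hypothetical; nothing here exists in this commit), assuming a copy-out primitive akin to the macro's copy_data:

	// Float64Values would flatten a tensor into a Go slice, mirroring
	// Vec::<f64>::from(&tensor) above; a [][]float64 variant would then
	// iterate over the first dimension the way the nested impls do.
	func (ts Tensor) Float64Values() ([]float64, error) {
		n := ts.Numel()                              // hypothetical element-count helper
		vals := make([]float64, n)
		if err := ts.CopyData(vals, n); err != nil { // hypothetical copy-out
			return nil, err
		}
		return vals, nil
	}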


@@ -229,24 +229,36 @@ func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) {
 	return retVal
 }
 
-func (ts Tensor) Add_(other Tensor) (err error) {
+func (ts Tensor) Add_(other Tensor) {
 	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
 	defer C.free(unsafe.Pointer(ptr))
 
 	lib.AtgAdd_(ptr, ts.ctensor, other.ctensor)
-	if err = TorchErr(); err != nil {
-		return err
+	if err := TorchErr(); err != nil {
+		log.Fatal(err)
 	}
-
-	return nil
 }
 
-func (ts Tensor) MustAdd_(other Tensor) {
-	err := ts.Add_(other)
-	if err != nil {
-		log.Fatal(err)
-	}
-}
+func (ts Tensor) Add1(other Scalar) (retVal Tensor, err error) {
+	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	defer C.free(unsafe.Pointer(ptr))
+
+	lib.AtgAdd1(ptr, ts.ctensor, other.cscalar)
+	if err = TorchErr(); err != nil {
+		return retVal, err
+	}
+
+	return Tensor{ctensor: *ptr}, nil
+}
+
+func (ts Tensor) MustAdd1(other Scalar) (retVal Tensor) {
+	retVal, err := ts.Add1(other)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
 
 func (ts Tensor) AddG(other Tensor) (err error) {
@@ -809,3 +821,176 @@ func (ts Tensor) T_() {
		log.Fatal(err)
	}
}

func (ts Tensor) MseLoss(target Tensor, reduction int) (retVal Tensor, err error) {
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgMseLoss(ptr, ts.ctensor, target.ctensor, reduction)
	err = TorchErr()
	if err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}
	return retVal, nil
}

func (ts Tensor) MustMseLoss(target Tensor, reduction int) (retVal Tensor) {
	retVal, err := ts.MseLoss(target, reduction)
	if err != nil {
		log.Fatal(err)
	}

	return retVal
}

func (ts Tensor) Exp() (retVal Tensor, err error) {
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgExp(ptr, ts.ctensor)
	err = TorchErr()
	if err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}
	return retVal, nil
}

func (ts Tensor) MustExp() (retVal Tensor) {
	retVal, err := ts.Exp()
	if err != nil {
		log.Fatal(err)
	}

	return retVal
}

func (ts Tensor) Exp_() {
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	defer C.free(unsafe.Pointer(ptr))

	// in-place variant must call the in-place binding
	lib.AtgExp_(ptr, ts.ctensor)
	err := TorchErr()
	if err != nil {
		log.Fatal(err)
	}
}

func (ts Tensor) Pow(exponent Scalar) (retVal Tensor, err error) {
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgPow(ptr, ts.ctensor, exponent.cscalar)
	err = TorchErr()
	if err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}
	return retVal, nil
}

func (ts Tensor) MustPow(exponent Scalar) (retVal Tensor) {
	retVal, err := ts.Pow(exponent)
	if err != nil {
		log.Fatal(err)
	}

	return retVal
}

func (ts Tensor) Sum(dtype int32) (retVal Tensor, err error) {
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgSum(ptr, ts.ctensor, dtype)
	err = TorchErr()
	if err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}
	return retVal, nil
}

func (ts Tensor) MustSum(dtype int32) (retVal Tensor) {
	retVal, err := ts.Sum(dtype)
	if err != nil {
		log.Fatal(err)
	}

	return retVal
}

func (ts Tensor) Sub(other Tensor) (retVal Tensor, err error) {
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgSub(ptr, ts.ctensor, other.ctensor)
	err = TorchErr()
	if err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}
	return retVal, nil
}

func (ts Tensor) MustSub(other Tensor) (retVal Tensor) {
	retVal, err := ts.Sub(other)
	if err != nil {
		log.Fatal(err)
	}

	return retVal
}

func (ts Tensor) Sub1(other Scalar) (retVal Tensor, err error) {
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgSub1(ptr, ts.ctensor, other.cscalar)
	err = TorchErr()
	if err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}
	return retVal, nil
}

func (ts Tensor) MustSub1(other Scalar) (retVal Tensor) {
	retVal, err := ts.Sub1(other)
	if err != nil {
		log.Fatal(err)
	}

	return retVal
}

func (ts Tensor) Sub_(other Tensor) {
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgSub_(ptr, ts.ctensor, other.ctensor)
	err := TorchErr()
	if err != nil {
		log.Fatal(err)
	}
}
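The new methods keep the convention the file already uses, carried over from tch-rs: Foo returns (Tensor, error), MustFoo wraps it with log.Fatal for ergonomic chaining, and Foo_ mutates in place. From a caller's perspective the ops compose as in this sketch (x is any tensor built with the constructors shown in the examples):

	// sum((x - 1)^2) over all elements, using only ops added here
	total := x.MustSub1(ts.FloatScalar(1.0)).
		MustPow(ts.IntScalar(2)).
		MustSum(gotch.Float.CInt())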


@@ -967,24 +967,24 @@ type Reduction int
 const (
 	// Do not reduce
-	ReduceNone Reduction = iota
+	ReductionNone Reduction = iota
 	// Mean of losses
-	ReduceMean
+	ReductionMean
 	// Sum of losses
-	ReduceSum
+	ReductionSum
 	// Escape hatch in case new options become available
-	Other
+	ReductionOther
 )
 
 func (r Reduction) ToInt() (retVal int) {
 	switch r {
-	case ReduceNone:
+	case ReductionNone:
 		return 0
-	case ReduceMean:
+	case ReductionMean:
 		return 1
-	case ReduceSum:
+	case ReductionSum:
 		return 2
-	case Other:
+	case ReductionOther:
 		return 3
 	}
 	return
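The renamed constants now carry the name of the type they belong to and read unambiguously at call sites, as the new examples already do:

	// mean-reduced MSE between a prediction and its target
	loss := pred.MustMseLoss(target, ts.ReductionMean.ToInt())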