feat(nn/rnn_test): added unit tests for rnn
This commit is contained in:
parent
42c02b0f65
commit
92f6e9da15
|
@ -1,43 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func rnnTest(rnnConfig nn.RNNConfig) {
|
||||
|
||||
var (
|
||||
batchDim int64 = 5
|
||||
// seqLen int64 = 3
|
||||
inputDim int64 = 2
|
||||
outputDim int64 = 4
|
||||
)
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
path := vs.Root()
|
||||
|
||||
gru := nn.NewGRU(&path, inputDim, outputDim, rnnConfig)
|
||||
|
||||
numDirections := int64(1)
|
||||
if rnnConfig.Bidirectional {
|
||||
numDirections = 2
|
||||
}
|
||||
layerDim := rnnConfig.NumLayers * numDirections
|
||||
|
||||
// Step test
|
||||
input := ts.MustRandn([]int64{batchDim, inputDim}, gotch.Float, gotch.CPU)
|
||||
output := gru.Step(input, gru.ZeroState(batchDim).(nn.GRUState))
|
||||
|
||||
fmt.Printf("Expected ouput shape: %v\n", []int64{layerDim, batchDim, outputDim})
|
||||
fmt.Printf("Got output shape: %v\n", output.(nn.GRUState).Tensor.MustSize())
|
||||
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
rnnTest(nn.DefaultRNNConfig())
|
||||
}
|
|
@ -46,8 +46,10 @@ func defaultBuild(config OptimizerConfig, vs VarStore, lr float64) (retVal Optim
|
|||
parameters = append(parameters, param)
|
||||
}
|
||||
|
||||
if err = opt.AddParameters(vs.Vars.TrainableVariables); err != nil {
|
||||
return retVal, err
|
||||
if len(vs.Vars.TrainableVariables) > 0 {
|
||||
if err = opt.AddParameters(vs.Vars.TrainableVariables); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: should we clone or copy?
|
||||
|
|
|
@ -1,66 +1,69 @@
|
|||
package nn_test
|
||||
|
||||
import (
|
||||
// "reflect"
|
||||
"fmt"
|
||||
"log"
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func TestOptimizer(t *testing.T) {
|
||||
|
||||
var data []float32
|
||||
for i := 0; i < 15; i++ {
|
||||
data = append(data, float32(i))
|
||||
}
|
||||
xs, err := ts.NewTensorFromData(data, []int64{int64(len(data)), 1})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ys := xs.MustMul1(ts.FloatScalar(0.42), true).MustAdd1(ts.FloatScalar(1.337), true)
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
|
||||
opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
|
||||
if err != nil {
|
||||
t.Errorf("Failed building SGD optimizer")
|
||||
}
|
||||
|
||||
cfg := nn.LinearConfig{
|
||||
WsInit: nn.NewConstInit(float64(0.0)),
|
||||
BsInit: nn.NewConstInit(float64(0.0)),
|
||||
Bias: true,
|
||||
}
|
||||
|
||||
linear := nn.NewLinear(vs.Root(), 1, 1, cfg)
|
||||
|
||||
loss := xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
|
||||
|
||||
initialLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
|
||||
wantLoss := float64(1.0)
|
||||
|
||||
if initialLoss < wantLoss {
|
||||
t.Errorf("Expect initial loss > %v, got %v", wantLoss, initialLoss)
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
|
||||
|
||||
opt.BackwardStep(loss)
|
||||
fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}))
|
||||
}
|
||||
|
||||
loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
|
||||
finalLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
fmt.Printf("Final loss: %v\n", finalLoss)
|
||||
|
||||
if finalLoss > 0.25 {
|
||||
t.Errorf("Expect initial loss < 0.25, got %v", finalLoss)
|
||||
}
|
||||
}
|
||||
/*
|
||||
* import (
|
||||
* // "reflect"
|
||||
* "fmt"
|
||||
* "log"
|
||||
* "testing"
|
||||
*
|
||||
* "github.com/sugarme/gotch"
|
||||
* "github.com/sugarme/gotch/nn"
|
||||
* ts "github.com/sugarme/gotch/tensor"
|
||||
* )
|
||||
*
|
||||
* func TestOptimizer(t *testing.T) {
|
||||
*
|
||||
* var data []float32
|
||||
* for i := 0; i < 15; i++ {
|
||||
* data = append(data, float32(i))
|
||||
* }
|
||||
* xs, err := ts.NewTensorFromData(data, []int64{int64(len(data)), 1})
|
||||
* if err != nil {
|
||||
* log.Fatal(err)
|
||||
* }
|
||||
*
|
||||
* ys := xs.MustMul1(ts.FloatScalar(0.42), false).MustAdd1(ts.FloatScalar(1.337), false)
|
||||
*
|
||||
* vs := nn.NewVarStore(gotch.CPU)
|
||||
*
|
||||
* optCfg := nn.DefaultSGDConfig()
|
||||
* opt, err := optCfg.Build(vs, 1e-2)
|
||||
* if err != nil {
|
||||
* t.Errorf("Failed building SGD optimizer")
|
||||
* }
|
||||
*
|
||||
* cfg := nn.LinearConfig{
|
||||
* WsInit: nn.NewConstInit(0.0),
|
||||
* BsInit: nn.NewConstInit(0.0),
|
||||
* Bias: true,
|
||||
* }
|
||||
*
|
||||
* linear := nn.NewLinear(vs.Root(), 1, 1, cfg)
|
||||
*
|
||||
* logits := xs.Apply(linear)
|
||||
* loss := logits.MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
|
||||
*
|
||||
* initialLoss := loss.MustView([]int64{-1}, false).MustFloat64Value([]int64{0})
|
||||
*
|
||||
* wantLoss := float64(1.0)
|
||||
*
|
||||
* if initialLoss < wantLoss {
|
||||
* t.Errorf("Expect initial loss > %v, got %v", wantLoss, initialLoss)
|
||||
* }
|
||||
*
|
||||
* for i := 0; i < 50; i++ {
|
||||
* loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
|
||||
*
|
||||
* opt.BackwardStep(loss)
|
||||
* fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}, false).MustFloat64Value([]int64{0}))
|
||||
* }
|
||||
*
|
||||
* loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
|
||||
* finalLoss := loss.Values()[0]
|
||||
* fmt.Printf("Final loss: %v\n", finalLoss)
|
||||
*
|
||||
* if finalLoss > 0.25 {
|
||||
* t.Errorf("Expect initial loss < 0.25, got %v", finalLoss)
|
||||
* }
|
||||
* } */
|
||||
|
|
12
nn/rnn.go
12
nn/rnn.go
|
@ -139,7 +139,7 @@ func (l LSTM) ZeroState(batchDim int64) (retVal State) {
|
|||
}
|
||||
}
|
||||
|
||||
func (l LSTM) Step(input ts.Tensor, inState LSTMState) (retVal State) {
|
||||
func (l LSTM) Step(input ts.Tensor, inState State) (retVal State) {
|
||||
ip := input.MustUnsqueeze(1, false)
|
||||
|
||||
output, state := l.SeqInit(ip, inState)
|
||||
|
@ -155,9 +155,9 @@ func (l LSTM) Seq(input ts.Tensor) (ts.Tensor, State) {
|
|||
return defaultSeq(l, input)
|
||||
}
|
||||
|
||||
func (l LSTM) SeqInit(input ts.Tensor, inState LSTMState) (ts.Tensor, State) {
|
||||
func (l LSTM) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) {
|
||||
|
||||
output, h, c := input.MustLSTM([]ts.Tensor{inState.Tensor1, inState.Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
|
||||
output, h, c := input.MustLSTM([]ts.Tensor{inState.(LSTMState).Tensor1, inState.(LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
|
||||
|
||||
return output, LSTMState{
|
||||
Tensor1: h,
|
||||
|
@ -234,7 +234,7 @@ func (g GRU) ZeroState(batchDim int64) (retVal State) {
|
|||
return GRUState{Tensor: tensor}
|
||||
}
|
||||
|
||||
func (g GRU) Step(input ts.Tensor, inState GRUState) (retVal State) {
|
||||
func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
|
||||
ip := input.MustUnsqueeze(1, false)
|
||||
|
||||
output, state := g.SeqInit(ip, inState)
|
||||
|
@ -250,9 +250,9 @@ func (g GRU) Seq(input ts.Tensor) (ts.Tensor, State) {
|
|||
return defaultSeq(g, input)
|
||||
}
|
||||
|
||||
func (g GRU) SeqInit(input ts.Tensor, inState GRUState) (ts.Tensor, State) {
|
||||
func (g GRU) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) {
|
||||
|
||||
output, h := input.MustGRU(inState.Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst)
|
||||
output, h := input.MustGRU(inState.(GRUState).Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst)
|
||||
|
||||
return output, GRUState{Tensor: h}
|
||||
}
|
||||
|
|
149
nn/rnn_test.go
Normal file
149
nn/rnn_test.go
Normal file
|
@ -0,0 +1,149 @@
|
|||
package nn_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func gruTest(rnnConfig nn.RNNConfig, t *testing.T) {
|
||||
|
||||
var (
|
||||
batchDim int64 = 5
|
||||
seqLen int64 = 3
|
||||
inputDim int64 = 2
|
||||
outputDim int64 = 4
|
||||
)
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
path := vs.Root()
|
||||
|
||||
gru := nn.NewGRU(&path, inputDim, outputDim, rnnConfig)
|
||||
|
||||
numDirections := int64(1)
|
||||
if rnnConfig.Bidirectional {
|
||||
numDirections = 2
|
||||
}
|
||||
layerDim := rnnConfig.NumLayers * numDirections
|
||||
|
||||
// Step test
|
||||
input := ts.MustRandn([]int64{batchDim, inputDim}, gotch.Float, gotch.CPU)
|
||||
output := gru.Step(input, gru.ZeroState(batchDim).(nn.GRUState))
|
||||
|
||||
want := []int64{layerDim, batchDim, outputDim}
|
||||
got := output.(nn.GRUState).Tensor.MustSize()
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
fmt.Println("Step test:")
|
||||
t.Errorf("Expected ouput shape: %v\n", want)
|
||||
t.Errorf("Got output shape: %v\n", got)
|
||||
}
|
||||
|
||||
// seq test
|
||||
input = ts.MustRandn([]int64{batchDim, seqLen, inputDim}, gotch.Float, gotch.CPU)
|
||||
output, _ = gru.Seq(input)
|
||||
wantSeq := []int64{batchDim, seqLen, outputDim * numDirections}
|
||||
gotSeq := output.(ts.Tensor).MustSize()
|
||||
|
||||
if !reflect.DeepEqual(wantSeq, gotSeq) {
|
||||
fmt.Println("Seq test:")
|
||||
t.Errorf("Expected ouput shape: %v\n", wantSeq)
|
||||
t.Errorf("Got output shape: %v\n", gotSeq)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestGRU(t *testing.T) {
|
||||
|
||||
cfg := nn.DefaultRNNConfig()
|
||||
|
||||
gruTest(cfg, t)
|
||||
|
||||
cfg.Bidirectional = true
|
||||
gruTest(cfg, t)
|
||||
|
||||
cfg.NumLayers = 2
|
||||
cfg.Bidirectional = false
|
||||
gruTest(cfg, t)
|
||||
|
||||
cfg.NumLayers = 2
|
||||
cfg.Bidirectional = true
|
||||
gruTest(cfg, t)
|
||||
}
|
||||
|
||||
func lstmTest(rnnConfig nn.RNNConfig, t *testing.T) {
|
||||
|
||||
var (
|
||||
batchDim int64 = 5
|
||||
seqLen int64 = 3
|
||||
inputDim int64 = 2
|
||||
outputDim int64 = 4
|
||||
)
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
path := vs.Root()
|
||||
|
||||
lstm := nn.NewLSTM(&path, inputDim, outputDim, rnnConfig)
|
||||
|
||||
numDirections := int64(1)
|
||||
if rnnConfig.Bidirectional {
|
||||
numDirections = 2
|
||||
}
|
||||
layerDim := rnnConfig.NumLayers * numDirections
|
||||
|
||||
// Step test
|
||||
input := ts.MustRandn([]int64{batchDim, inputDim}, gotch.Float, gotch.CPU)
|
||||
output := lstm.Step(input, lstm.ZeroState(batchDim).(nn.LSTMState))
|
||||
|
||||
wantH := []int64{layerDim, batchDim, outputDim}
|
||||
gotH := output.(nn.LSTMState).Tensor1.MustSize()
|
||||
wantC := []int64{layerDim, batchDim, outputDim}
|
||||
gotC := output.(nn.LSTMState).Tensor2.MustSize()
|
||||
|
||||
if !reflect.DeepEqual(wantH, gotH) {
|
||||
fmt.Println("Step test:")
|
||||
t.Errorf("Expected ouput H shape: %v\n", wantH)
|
||||
t.Errorf("Got output H shape: %v\n", gotH)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(wantC, gotC) {
|
||||
fmt.Println("Step test:")
|
||||
t.Errorf("Expected ouput C shape: %v\n", wantC)
|
||||
t.Errorf("Got output C shape: %v\n", gotC)
|
||||
}
|
||||
|
||||
// seq test
|
||||
input = ts.MustRandn([]int64{batchDim, seqLen, inputDim}, gotch.Float, gotch.CPU)
|
||||
output, _ = lstm.Seq(input)
|
||||
|
||||
wantSeq := []int64{batchDim, seqLen, outputDim * numDirections}
|
||||
gotSeq := output.(ts.Tensor).MustSize()
|
||||
|
||||
if !reflect.DeepEqual(wantSeq, gotSeq) {
|
||||
fmt.Println("Seq test:")
|
||||
t.Errorf("Expected ouput shape: %v\n", wantSeq)
|
||||
t.Errorf("Got output shape: %v\n", gotSeq)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLSTM(t *testing.T) {
|
||||
|
||||
cfg := nn.DefaultRNNConfig()
|
||||
|
||||
lstmTest(cfg, t)
|
||||
|
||||
cfg.Bidirectional = true
|
||||
lstmTest(cfg, t)
|
||||
|
||||
cfg.NumLayers = 2
|
||||
cfg.Bidirectional = false
|
||||
lstmTest(cfg, t)
|
||||
|
||||
cfg.NumLayers = 2
|
||||
cfg.Bidirectional = true
|
||||
lstmTest(cfg, t)
|
||||
}
|
Loading…
Reference in New Issue
Block a user