WIP(example/char-rnn)

This commit is contained in:
sugarme 2020-07-25 19:02:48 +10:00
parent cddf4c8a77
commit 38b8139bd5
4 changed files with 94 additions and 5 deletions

View File

@@ -2,10 +2,11 @@ package main
import (
"fmt"
"log"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
"github.com/sugarme/nn"
)
const (
@@ -13,11 +14,86 @@ const (
HiddenSize int64 = 256
SeqLen int64 = 180
BatchSize int64 = 256
Epochs int64 = 100
Epochs int = 100
SamplingLen int64 = 1024
)
// sample generates SamplingLen characters from the trained model, one step
// at a time: one-hot encode the previously sampled label, feed it through the
// LSTM, project the hidden state to logits, softmax, and sample the next
// label with MustMultinomial.
//
// BUG FIX: a stray `return` at the top of the function made the whole body
// unreachable, so sample always returned the empty string.
//
// NOTE(review): the result concatenates integer label indices rather than the
// characters they stand for — presumably a TODO in this WIP; confirm against
// the TextData API before relying on the output.
func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Device) (retVal string) {
	labels := data.Labels()
	state := lstm.ZeroState(1)
	lastLabel := int64(0)
	var result string

	for i := 0; i < int(SamplingLen); i++ {
		// Build a [1, labels] one-hot row for the previous label.
		input := ts.MustZeros([]int64{1, labels}, gotch.Float, device)
		input.MustNarrow(1, lastLabel, 1, false).MustFill_(ts.FloatScalar(1.0))

		state = lstm.Step(input, state)

		// Hidden state -> logits -> probabilities -> sampled label.
		forwardTs := linear.Forward(state.(nn.LSTMState).H())
		squeeze1Ts := forwardTs.MustSqueeze1(0, false)
		softmaxTs := squeeze1Ts.MustSoftmax(-1, gotch.Float, false)
		sampledY := softmaxTs.MustMultinomial(1, false, false)
		lastLabel = sampledY.Int64Values()[0]
		result += fmt.Sprintf("%v", lastLabel)
	}

	return result
}
// main trains a character-level LSTM language model on
// ../../data/char-rnn/input.txt and prints a sampled string after each epoch.
func main() {
	cuda := gotch.NewCuda()
	device := cuda.CudaIfAvailable()

	vs := nn.NewVarStore(device)

	data, err := ts.NewTextData("../../data/char-rnn/input.txt")
	if err != nil {
		log.Fatal(err)
	}

	labels := data.Labels()
	fmt.Printf("Dataset loaded, %v labels\n", labels)

	lstm := nn.NewLSTM(vs.Root(), labels, HiddenSize, nn.DefaultRNNConfig())
	linear := nn.NewLinear(vs.Root(), HiddenSize, labels, nn.DefaultLinearConfig())

	optConfig := nn.DefaultAdamConfig()
	opt, err := optConfig.Build(vs, LearningRate)
	if err != nil {
		log.Fatal(err)
	}

	for epoch := 1; epoch <= Epochs; epoch++ {
		sumLoss := 0.0
		cntLoss := 0.0

		// Batches are SeqLen+1 characters wide so that inputs and
		// one-step-shifted targets can both be sliced from the same batch.
		dataIter := data.IterShuffle(SeqLen+1, BatchSize)
		for {
			batchTs, ok := dataIter.Next()
			if !ok {
				break
			}

			// BUG FIX: inputs are characters [0, SeqLen) and targets are
			// characters [1, SeqLen]. Both were previously narrowed at
			// offset 1, which trains the network to copy its input.
			xsOnehot := batchTs.MustNarrow(1, 0, SeqLen, false).MustOneHot(labels, false)
			ys := batchTs.MustNarrow(1, 1, SeqLen, false).MustTotype(gotch.Int64, false)

			lstmOut, _ := lstm.Seq(xsOnehot.MustTo(device, false))
			logits := linear.Forward(lstmOut)

			// Flatten (batch, seq) into one row per timestep for the loss.
			// BUG FIX: the flattening view belongs on the targets; it was
			// applied to the scalar returned by CrossEntropyForLogits.
			lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, false)
			flatYs := ys.MustTo(device, false).MustView([]int64{BatchSize * SeqLen}, false)
			loss := lossView.CrossEntropyForLogits(flatYs)

			opt.BackwardStepClip(loss, 0.5)

			sumLoss += loss.Float64Values()[0]
			cntLoss += 1.0
		}

		// BUG FIX: newlines so successive epoch reports don't run together.
		fmt.Printf("Epoch %v - Loss: %v\n", epoch, sumLoss/cntLoss)
		fmt.Printf("Sample: %v\n", sample(data, lstm, linear, device))
	}
}

View File

@@ -86,7 +86,7 @@ type LSTM struct {
}
// NewLSTM creates a LSTM layer.
func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal LSTM) {
func NewLSTM(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal LSTM) {
var numDirections int64 = 1
if cfg.Bidirectional {

View File

@@ -87,7 +87,7 @@ func lstmTest(rnnConfig nn.RNNConfig, t *testing.T) {
vs := nn.NewVarStore(gotch.CPU)
path := vs.Root()
lstm := nn.NewLSTM(&path, inputDim, outputDim, rnnConfig)
lstm := nn.NewLSTM(path, inputDim, outputDim, rnnConfig)
numDirections := int64(1)
if rnnConfig.Bidirectional {

View File

@@ -1038,6 +1038,19 @@ func (ts Tensor) Float64Values() []float64 {
return vec
}
// Int64Values returns values of tensor in a slice of int64.
// Int64Values returns values of tensor in a slice of int64.
func (ts Tensor) Int64Values() []int64 {
	n := ts.Numel()
	out := make([]int64, n)

	// Copy the data out through an int64-typed clone, freeing the clone
	// once the values have been read back into the Go slice.
	converted := ts.MustTotype(gotch.Int64, false)
	defer converted.MustDrop()
	converted.MustCopyData(out, n)

	return out
}
// Vals returns tensor values in a slice
// NOTE: need a type assertion to get runtime type
// E.g. res := xs.Vals().([]int64)