WIP(example/char-rnn)
This commit is contained in:
parent
cddf4c8a77
commit
38b8139bd5
|
@ -2,10 +2,11 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/nn"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -13,11 +14,86 @@ const (
|
|||
HiddenSize int64 = 256
|
||||
SeqLen int64 = 180
|
||||
BatchSize int64 = 256
|
||||
Epochs int64 = 100
|
||||
Epochs int = 100
|
||||
SamplingLen int64 = 1024
|
||||
)
|
||||
|
||||
func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Device) (retVal string) {
|
||||
|
||||
return
|
||||
labels := data.Labels()
|
||||
state := lstm.ZeroState(1)
|
||||
lastLabel := int64(0)
|
||||
var result string
|
||||
|
||||
for i := 0; i < int(SamplingLen); i++ {
|
||||
|
||||
input := ts.MustZeros([]int64{1, labels}, gotch.Float, device)
|
||||
input.MustNarrow(1, lastLabel, 1, false).MustFill_(ts.FloatScalar(1.0))
|
||||
state = lstm.Step(input, state)
|
||||
|
||||
forwardTs := linear.Forward(state.(nn.LSTMState).H())
|
||||
squeeze1Ts := forwardTs.MustSqueeze1(0, false)
|
||||
softmaxTs := squeeze1Ts.MustSoftmax(-1, gotch.Float, false)
|
||||
sampledY := softmaxTs.MustMultinomial(1, false, false)
|
||||
|
||||
lastLabel = sampledY.Int64Values()[0]
|
||||
|
||||
result += fmt.Sprintf("%v", lastLabel)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func main() {
|
||||
cuda := gotch.NewCuda()
|
||||
device := cuda.CudaIfAvailable()
|
||||
|
||||
vs := nn.NewVarStore(device)
|
||||
data, err := ts.NewTextData("../../data/char-rnn/input.txt")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
labels := data.Labels()
|
||||
fmt.Printf("Dataset loaded, %v labels\n", labels)
|
||||
|
||||
lstm := nn.NewLSTM(vs.Root(), labels, HiddenSize, nn.DefaultRNNConfig())
|
||||
linear := nn.NewLinear(vs.Root(), HiddenSize, labels, nn.DefaultLinearConfig())
|
||||
optConfig := nn.DefaultAdamConfig()
|
||||
opt, err := optConfig.Build(vs, LearningRate)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for epoch := 1; epoch <= Epochs; epoch++ {
|
||||
sumLoss := 0.0
|
||||
cntLoss := 0.0
|
||||
|
||||
dataIter := data.IterShuffle(SeqLen+1, BatchSize)
|
||||
|
||||
for {
|
||||
batchTs, ok := dataIter.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
batchNarrow := batchTs.MustNarrow(1, 1, SeqLen, false)
|
||||
xsOnehot := batchNarrow.MustOneHot(labels, false)
|
||||
ys := batchTs.MustNarrow(1, 1, SeqLen, false).MustTotype(gotch.Int64, false)
|
||||
lstmOut, _ := lstm.Seq(xsOnehot.MustTo(device, false))
|
||||
logits := linear.Forward(lstmOut)
|
||||
|
||||
lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, false)
|
||||
loss := lossView.CrossEntropyForLogits(ys.MustTo(device, false)).MustView([]int64{BatchSize * SeqLen}, false)
|
||||
|
||||
opt.BackwardStepClip(loss, 0.5)
|
||||
sumLoss += loss.Float64Values()[0]
|
||||
cntLoss += 1.0
|
||||
}
|
||||
|
||||
fmt.Printf("Epoch %v - Loss: %v", epoch, sumLoss/cntLoss)
|
||||
fmt.Printf("Sample: %v", sample(data, lstm, linear, device))
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -86,7 +86,7 @@ type LSTM struct {
|
|||
}
|
||||
|
||||
// NewLSTM creates a LSTM layer.
|
||||
func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal LSTM) {
|
||||
func NewLSTM(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal LSTM) {
|
||||
|
||||
var numDirections int64 = 1
|
||||
if cfg.Bidirectional {
|
||||
|
|
|
@ -87,7 +87,7 @@ func lstmTest(rnnConfig nn.RNNConfig, t *testing.T) {
|
|||
vs := nn.NewVarStore(gotch.CPU)
|
||||
path := vs.Root()
|
||||
|
||||
lstm := nn.NewLSTM(&path, inputDim, outputDim, rnnConfig)
|
||||
lstm := nn.NewLSTM(path, inputDim, outputDim, rnnConfig)
|
||||
|
||||
numDirections := int64(1)
|
||||
if rnnConfig.Bidirectional {
|
||||
|
|
|
@ -1038,6 +1038,19 @@ func (ts Tensor) Float64Values() []float64 {
|
|||
return vec
|
||||
}
|
||||
|
||||
// Int64Values returns values of tensor in a slice of int64.
|
||||
func (ts Tensor) Int64Values() []int64 {
|
||||
numel := ts.Numel()
|
||||
vec := make([]int64, numel)
|
||||
|
||||
int64Ts := ts.MustTotype(gotch.Int64, false)
|
||||
|
||||
int64Ts.MustCopyData(vec, numel)
|
||||
int64Ts.MustDrop()
|
||||
|
||||
return vec
|
||||
}
|
||||
|
||||
// Vals returns tensor values in a slice
|
||||
// NOTE: need a type insersion to get runtime type
|
||||
// E.g. res := xs.Vals().([]int64)
|
||||
|
|
Loading…
Reference in New Issue
Block a user