WIP(example/translation)
This commit is contained in:
parent
a71479e0f5
commit
b77fa54eb0
182
example/translation/dataset.go
Normal file
182
example/translation/dataset.go
Normal file
|
@ -0,0 +1,182 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
var Prefixes []string = []string{
|
||||||
|
"i am ",
|
||||||
|
"i m ",
|
||||||
|
"you are ",
|
||||||
|
"you re",
|
||||||
|
"he is ",
|
||||||
|
"he s ",
|
||||||
|
"she is ",
|
||||||
|
"she s ",
|
||||||
|
"we are ",
|
||||||
|
"we re ",
|
||||||
|
"they are ",
|
||||||
|
"they re ",
|
||||||
|
}
|
||||||
|
|
||||||
|
type Pair struct {
|
||||||
|
val1 string
|
||||||
|
val2 string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Dataset struct {
|
||||||
|
inputLang Lang
|
||||||
|
outputLang Lang
|
||||||
|
pairs []Pair
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalize(s string) (retVal string) {
|
||||||
|
|
||||||
|
lower := strings.ToLower(s)
|
||||||
|
/*
|
||||||
|
* // strip all spaces
|
||||||
|
* noSpaceStr := strings.Map(func(r rune) rune {
|
||||||
|
* if unicode.IsSpace(r) {
|
||||||
|
* return -1
|
||||||
|
* }
|
||||||
|
* return 1
|
||||||
|
* }, lower)
|
||||||
|
* */
|
||||||
|
// add single space before "!", ".", "?"
|
||||||
|
var res []rune
|
||||||
|
for _, r := range []rune(lower) {
|
||||||
|
char := fmt.Sprintf("%c", r)
|
||||||
|
switch {
|
||||||
|
case char == "!":
|
||||||
|
res = append(res, []rune(" !")...)
|
||||||
|
case char == ".":
|
||||||
|
res = append(res, []rune(" .")...)
|
||||||
|
case char == "?":
|
||||||
|
res = append(res, []rune(" ?")...)
|
||||||
|
case unicode.IsLetter(r), unicode.IsNumber(r):
|
||||||
|
res = append(res, r)
|
||||||
|
default:
|
||||||
|
res = append(res, []rune(" ")...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func toIndexes(s string, lang Lang) (retVal []int) {
|
||||||
|
res := strings.Split(s, " ")
|
||||||
|
|
||||||
|
for _, l := range res {
|
||||||
|
idx := lang.GetIndex(l)
|
||||||
|
if idx >= 0 {
|
||||||
|
retVal = append(retVal, idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = append(retVal, lang.EosToken())
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterPrefix(s string) (retVal bool) {
|
||||||
|
|
||||||
|
for _, prefix := range Prefixes {
|
||||||
|
if strings.HasPrefix(s, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func readPairs(ilang, olang string, maxLength int) (retVal []Pair) {
|
||||||
|
file, err := os.Open(fmt.Sprintf("../../data/translation/%v-%v.txt", ilang, olang))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
scanner := bufio.NewScanner(file)
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
|
||||||
|
if strings.Contains(line, "\t") {
|
||||||
|
// NOTE: assuming there's only 1 '\t'
|
||||||
|
pair := strings.Split(line, "\t")
|
||||||
|
lhs := normalize(pair[0])
|
||||||
|
rhs := normalize(pair[1])
|
||||||
|
|
||||||
|
if (len(strings.Split(lhs, " ")) < maxLength) && (len(strings.Split(rhs, " ")) < maxLength) && (filterPrefix(lhs) || filterPrefix(rhs)) {
|
||||||
|
retVal = append(retVal, Pair{lhs, rhs})
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
log.Fatalf("A line does not contain a single tab: %v\n", line)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDataset(ilang, olang string, maxLength int) (retVal Dataset) {
|
||||||
|
pairs := readPairs(ilang, olang, maxLength)
|
||||||
|
inputLang := NewLang(ilang)
|
||||||
|
outputLang := NewLang(olang)
|
||||||
|
|
||||||
|
for _, p := range pairs {
|
||||||
|
inputLang.AddSentence(p.val1)
|
||||||
|
outputLang.AddSentence(p.val2)
|
||||||
|
}
|
||||||
|
|
||||||
|
return Dataset{
|
||||||
|
inputLang: inputLang,
|
||||||
|
outputLang: outputLang,
|
||||||
|
pairs: pairs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ds Dataset) InputLang() (retVal Lang) {
|
||||||
|
return ds.inputLang
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ds Dataset) OutputLang() (retVal Lang) {
|
||||||
|
return ds.outputLang
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ds Dataset) Reverse() (retVal Dataset) {
|
||||||
|
var rpairs []Pair
|
||||||
|
for _, p := range ds.pairs {
|
||||||
|
rpairs = append(rpairs, Pair{p.val2, p.val1})
|
||||||
|
}
|
||||||
|
return Dataset{
|
||||||
|
inputLang: ds.outputLang,
|
||||||
|
outputLang: ds.inputLang,
|
||||||
|
pairs: rpairs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Pairs struct {
|
||||||
|
Val1 []int
|
||||||
|
Val2 []int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ds Dataset) Pairs() (retVal []Pairs) {
|
||||||
|
for _, p := range ds.pairs {
|
||||||
|
val1 := toIndexes(p.val1, ds.inputLang)
|
||||||
|
val2 := toIndexes(p.val2, ds.outputLang)
|
||||||
|
|
||||||
|
retVal = append(retVal, Pairs{val1, val2})
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
92
example/translation/lang.go
Normal file
92
example/translation/lang.go
Normal file
|
@ -0,0 +1,92 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
SosToken = "SOS"
|
||||||
|
EosToken = "EOS"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IndexCount struct {
|
||||||
|
Index int
|
||||||
|
Count int
|
||||||
|
}
|
||||||
|
|
||||||
|
type Lang struct {
|
||||||
|
Name string
|
||||||
|
WordToIndexAndCount map[string]IndexCount
|
||||||
|
IndexToWord map[int]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLang(name string) (retVal Lang) {
|
||||||
|
|
||||||
|
lang := Lang{
|
||||||
|
Name: name,
|
||||||
|
WordToIndexAndCount: make(map[string]IndexCount, 0),
|
||||||
|
IndexToWord: make(map[int]string, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
lang.AddWord(SosToken)
|
||||||
|
lang.AddWord(EosToken)
|
||||||
|
|
||||||
|
return lang
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lang) AddWord(word string) {
|
||||||
|
if len(word) > 0 {
|
||||||
|
idxCount, ok := l.WordToIndexAndCount[word]
|
||||||
|
if !ok {
|
||||||
|
length := len(l.WordToIndexAndCount)
|
||||||
|
l.WordToIndexAndCount[word] = IndexCount{length, 1}
|
||||||
|
l.IndexToWord[length] = word
|
||||||
|
} else {
|
||||||
|
idxCount.Count += 1
|
||||||
|
l.WordToIndexAndCount[word] = idxCount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lang) AddSentence(sentence string) {
|
||||||
|
words := strings.Split(sentence, " ")
|
||||||
|
for _, word := range words {
|
||||||
|
l.AddWord(word)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lang) Len() (retVal int) {
|
||||||
|
return len(l.IndexToWord)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lang) SosToken() (retVal int) {
|
||||||
|
return l.WordToIndexAndCount[SosToken].Index
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lang) EosToken() (retVal int) {
|
||||||
|
return l.WordToIndexAndCount[EosToken].Index
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lang) GetName() (retVal string) {
|
||||||
|
return l.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lang) GetIndex(word string) (retVal int) {
|
||||||
|
idxCount, ok := l.WordToIndexAndCount[word]
|
||||||
|
if ok {
|
||||||
|
return idxCount.Index
|
||||||
|
} else {
|
||||||
|
return -1 // word does not exist in Lang
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lang) SeqToString(seq []int) (retVal string) {
|
||||||
|
var words []string = make([]string, 0)
|
||||||
|
|
||||||
|
for _, idx := range seq {
|
||||||
|
w := l.IndexToWord[idx]
|
||||||
|
words = append(words, w)
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(words, "")
|
||||||
|
}
|
315
example/translation/main.go
Normal file
315
example/translation/main.go
Normal file
|
@ -0,0 +1,315 @@
|
||||||
|
/* Translation with a Sequence to Sequence Model and Attention.
|
||||||
|
|
||||||
|
This follows the line of the PyTorch tutorial:
|
||||||
|
https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
|
||||||
|
And trains a Sequence to Sequence (seq2seq) model using attention to
|
||||||
|
perform translation between French and English.
|
||||||
|
|
||||||
|
The dataset can be downloaded from the following link:
|
||||||
|
https://download.pytorch.org/tutorial/data.zip
|
||||||
|
The eng-fra.txt file should be moved in the data directory.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"math/rand"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
"github.com/sugarme/gotch/nn"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
MaxLength int64 = 10
|
||||||
|
LearningRate float64 = 0.001
|
||||||
|
Samples int64 = 100000
|
||||||
|
HiddenSize int64 = 256
|
||||||
|
)
|
||||||
|
|
||||||
|
type Encoder struct {
|
||||||
|
embedding nn.Embedding
|
||||||
|
gru nn.GRU
|
||||||
|
}
|
||||||
|
|
||||||
|
func newEncoder(vs nn.Path, inDim, hiddenDim int64) (retVal Encoder) {
|
||||||
|
|
||||||
|
gru := nn.NewGRU(vs, hiddenDim, hiddenDim, nn.DefaultRNNConfig())
|
||||||
|
|
||||||
|
embedding := nn.NewEmbedding(vs, inDim, hiddenDim, nn.DefaultEmbeddingConfig())
|
||||||
|
|
||||||
|
return Encoder{embedding, gru}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Encoder) forward(xs ts.Tensor, state nn.GRUState) (retTs ts.Tensor, retState nn.GRUState) {
|
||||||
|
|
||||||
|
retTs = e.embedding.Forward(xs).MustView([]int64{1, -1}, false)
|
||||||
|
retState = e.gru.Step(retTs, state).(nn.GRUState)
|
||||||
|
|
||||||
|
return retTs, retState
|
||||||
|
}
|
||||||
|
|
||||||
|
type Decoder struct {
|
||||||
|
device gotch.Device
|
||||||
|
embedding nn.Embedding
|
||||||
|
gru nn.GRU
|
||||||
|
attn nn.Linear
|
||||||
|
attnCombine nn.Linear
|
||||||
|
linear nn.Linear
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDecoder(vs nn.Path, hiddenDim, outDim int64) (retVal Decoder) {
|
||||||
|
|
||||||
|
return Decoder{
|
||||||
|
device: vs.Device(),
|
||||||
|
embedding: nn.NewEmbedding(vs, outDim, hiddenDim, nn.DefaultEmbeddingConfig()),
|
||||||
|
gru: nn.NewGRU(vs, hiddenDim, hiddenDim, nn.DefaultRNNConfig()),
|
||||||
|
attn: nn.NewLinear(vs, hiddenDim*2, MaxLength, nn.DefaultLinearConfig()),
|
||||||
|
attnCombine: nn.NewLinear(vs, hiddenDim*2, hiddenDim, nn.DefaultLinearConfig()),
|
||||||
|
linear: nn.NewLinear(vs, hiddenDim, outDim, nn.DefaultLinearConfig()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor, isTraining bool) (retTs ts.Tensor, retState nn.GRUState) {
|
||||||
|
forwardTsTmp := d.embedding.Forward(xs)
|
||||||
|
forwardTsTmp.MustDropout_(0.1, isTraining)
|
||||||
|
forwardTs := forwardTsTmp.MustView([]int64{1, -1}, true)
|
||||||
|
|
||||||
|
// NOTE. forwardTs shape: [1, 256] state [1, 1, 256]
|
||||||
|
// hence, just get state[0] of 3D tensor state
|
||||||
|
stateTs := state.Value()
|
||||||
|
state0 := stateTs.Idx([]ts.TensorIndexer{ts.NewSelect(0)})
|
||||||
|
|
||||||
|
catTs := ts.MustCat([]ts.Tensor{forwardTs, state0}, 1)
|
||||||
|
|
||||||
|
// NOTE. d.attn Ws shape : [512, 10]
|
||||||
|
appliedTs := catTs.Apply(d.attn)
|
||||||
|
attnWeights := appliedTs.MustUnsqueeze(0, true)
|
||||||
|
|
||||||
|
size3, err := encOutputs.Size3()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sz1 := size3[0]
|
||||||
|
sz2 := size3[1]
|
||||||
|
sz3 := size3[2]
|
||||||
|
|
||||||
|
var encOutputsTs ts.Tensor
|
||||||
|
if sz2 == MaxLength {
|
||||||
|
encOutputsTs = encOutputs.MustShallowClone()
|
||||||
|
} else {
|
||||||
|
shape := []int64{sz1, MaxLength - sz2, sz3}
|
||||||
|
zerosTs := ts.MustZeros(shape, gotch.Float, d.device)
|
||||||
|
encOutputsTs = ts.MustCat([]ts.Tensor{encOutputs, zerosTs}, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
attnApplied := attnWeights.MustBmm(encOutputsTs, true).MustSqueeze1(1, true)
|
||||||
|
fmt.Printf("attnApplied shape: %v\n", attnApplied.MustSize())
|
||||||
|
fmt.Printf("xs shape: %v\n", forwardTs.MustSize())
|
||||||
|
|
||||||
|
xsTs := ts.MustCat([]ts.Tensor{forwardTs, attnApplied}, 1).Apply(d.attnCombine).MustRelu(true)
|
||||||
|
|
||||||
|
retState = d.gru.Step(xsTs, state).(nn.GRUState)
|
||||||
|
|
||||||
|
retTs = d.linear.Forward(retState.Value()).MustLogSoftmax(-1, gotch.Float, true)
|
||||||
|
|
||||||
|
return retTs, retState
|
||||||
|
}
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
encoder Encoder
|
||||||
|
decoder Decoder
|
||||||
|
decoderStart ts.Tensor
|
||||||
|
decoderEos int64
|
||||||
|
device gotch.Device
|
||||||
|
}
|
||||||
|
|
||||||
|
func newModel(vs nn.Path, ilang Lang, olang Lang, hiddenDim int64) (retVal Model) {
|
||||||
|
return Model{
|
||||||
|
encoder: newEncoder(vs.Sub("enc"), int64(ilang.Len()), hiddenDim),
|
||||||
|
decoder: newDecoder(vs.Sub("dec"), hiddenDim, int64(olang.Len())),
|
||||||
|
decoderStart: ts.MustOfSlice([]int64{int64(olang.SosToken())}).MustTo(vs.Device(), true),
|
||||||
|
decoderEos: int64(olang.EosToken()),
|
||||||
|
device: vs.Device(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
|
||||||
|
state := m.encoder.gru.ZeroState(1)
|
||||||
|
fmt.Printf("state shape: %v\n", state.(nn.GRUState).Value().MustSize())
|
||||||
|
var encOutputs []ts.Tensor
|
||||||
|
|
||||||
|
for _, v := range input {
|
||||||
|
s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
|
||||||
|
|
||||||
|
outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
|
||||||
|
|
||||||
|
encOutputs = append(encOutputs, outTs)
|
||||||
|
state.(nn.GRUState).Tensor.MustDrop()
|
||||||
|
state = outState
|
||||||
|
}
|
||||||
|
|
||||||
|
stackTs := ts.MustStack(encOutputs, 1)
|
||||||
|
for _, t := range encOutputs {
|
||||||
|
t.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: should we implement random here???
|
||||||
|
loss := ts.TensorFrom([]float32{0.0}).MustTo(m.device, true)
|
||||||
|
prev := m.decoderStart.MustShallowClone()
|
||||||
|
|
||||||
|
for _, s := range target {
|
||||||
|
outTs, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true)
|
||||||
|
|
||||||
|
state.(nn.GRUState).Tensor.MustDrop()
|
||||||
|
state = outState
|
||||||
|
|
||||||
|
targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true)
|
||||||
|
fmt.Printf("targetTs shape: %v\n", targetTs.MustSize())
|
||||||
|
fmt.Printf("outTs shape: %v\n", outTs.MustSize())
|
||||||
|
|
||||||
|
// TODO: fix the error: input tensor should be 1D or 2D at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:35
|
||||||
|
ws := ts.NewTensor()
|
||||||
|
currLoss := outTs.MustNllLoss(targetTs, ws, int64(1), -100, false)
|
||||||
|
|
||||||
|
loss.MustAdd_(currLoss)
|
||||||
|
currLoss.MustDrop()
|
||||||
|
|
||||||
|
_, output := outTs.MustTopK(1, -1, true, true)
|
||||||
|
|
||||||
|
if m.decoderEos == outTs.Int64Values()[0] {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
prev.MustDrop()
|
||||||
|
prev = output
|
||||||
|
}
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) predict(input []int) (retVal []int) {
|
||||||
|
state := m.encoder.gru.ZeroState(1)
|
||||||
|
var encOutputs []ts.Tensor
|
||||||
|
|
||||||
|
for _, v := range input {
|
||||||
|
s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
|
||||||
|
outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
|
||||||
|
|
||||||
|
encOutputs = append(encOutputs, outTs)
|
||||||
|
state.(nn.GRUState).Tensor.MustDrop()
|
||||||
|
state = outState
|
||||||
|
}
|
||||||
|
|
||||||
|
stackTs := ts.MustStack(encOutputs, 1)
|
||||||
|
for _, t := range encOutputs {
|
||||||
|
t.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
|
prev := m.decoderStart.MustShallowClone()
|
||||||
|
var outputSeq []int
|
||||||
|
|
||||||
|
for i := 0; i < int(MaxLength); i++ {
|
||||||
|
outTs, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true)
|
||||||
|
_, output := outTs.MustTopK(1, -1, true, true)
|
||||||
|
outputVal := output.Int64Values()[0]
|
||||||
|
outputSeq = append(outputSeq, int(outputVal))
|
||||||
|
|
||||||
|
if m.decoderEos == outTs.Int64Values()[0] {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
state.(nn.GRUState).Tensor.MustDrop()
|
||||||
|
state = outState
|
||||||
|
prev.MustDrop()
|
||||||
|
prev = output
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputSeq
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
type LossStats struct {
|
||||||
|
totalLoss float64
|
||||||
|
samples int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLossStats() (retVal LossStats) {
|
||||||
|
return LossStats{
|
||||||
|
totalLoss: 0.0,
|
||||||
|
samples: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ls *LossStats) update(loss float64) {
|
||||||
|
ls.totalLoss += loss
|
||||||
|
ls.samples += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ls *LossStats) avgAndReset() (retVal float64) {
|
||||||
|
avg := ls.totalLoss / float64(ls.samples)
|
||||||
|
ls.totalLoss = 0.0
|
||||||
|
ls.samples = 0
|
||||||
|
return avg
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
|
||||||
|
dataset := newDataset("eng", "fra", int(MaxLength)).Reverse()
|
||||||
|
|
||||||
|
ilang := dataset.InputLang()
|
||||||
|
olang := dataset.OutputLang()
|
||||||
|
pairs := dataset.Pairs()
|
||||||
|
|
||||||
|
fmt.Printf("Input: %v %v words\n", ilang.GetName(), ilang.Len())
|
||||||
|
fmt.Printf("Output: %v %v words\n", olang.GetName(), olang.Len())
|
||||||
|
fmt.Printf("Pairs: %v\n", len(pairs))
|
||||||
|
|
||||||
|
// TODO: should we implement random here??
|
||||||
|
|
||||||
|
cuda := gotch.NewCuda()
|
||||||
|
device := cuda.CudaIfAvailable()
|
||||||
|
|
||||||
|
vs := nn.NewVarStore(device)
|
||||||
|
|
||||||
|
model := newModel(vs.Root(), ilang, olang, HiddenSize)
|
||||||
|
|
||||||
|
optConfig := nn.DefaultAdamConfig()
|
||||||
|
opt, err := optConfig.Build(vs, LearningRate)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lossStats := newLossStats()
|
||||||
|
|
||||||
|
for i := 1; i < int(Samples); i++ {
|
||||||
|
// randomly choose a pair
|
||||||
|
idx := rand.Intn(len(pairs))
|
||||||
|
pair := pairs[idx]
|
||||||
|
input := pair.Val1
|
||||||
|
target := pair.Val2
|
||||||
|
loss := model.trainLoss(input, target)
|
||||||
|
panic("reached")
|
||||||
|
opt.BackwardStep(loss)
|
||||||
|
lossStats.update(loss.Float64Values()[0] / float64(len(target)))
|
||||||
|
|
||||||
|
if i%1000 == 0 {
|
||||||
|
fmt.Printf("%v %v\n", i, lossStats.avgAndReset())
|
||||||
|
for predIdx := 1; predIdx <= 5; predIdx++ {
|
||||||
|
idx := rand.Intn(len(pairs))
|
||||||
|
in := pairs[idx].Val1
|
||||||
|
tgt := pairs[idx].Val2
|
||||||
|
predict := model.predict(in)
|
||||||
|
|
||||||
|
fmt.Printf("input: %v\n", ilang.SeqToString(in))
|
||||||
|
fmt.Printf("target: %v\n", olang.SeqToString(tgt))
|
||||||
|
fmt.Printf("ouput: %v\n", olang.SeqToString(predict))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -75,16 +75,19 @@ func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tens
|
||||||
var err error
|
var err error
|
||||||
rand.Seed(86)
|
rand.Seed(86)
|
||||||
|
|
||||||
data := make([]float64, ts.FlattenDim(dims))
|
data := make([]float32, ts.FlattenDim(dims))
|
||||||
for i := range data {
|
for i := range data {
|
||||||
data[i] = rand.NormFloat64()*r.mean + r.stdev
|
// NOTE. tensor will have DType = Float (float32)
|
||||||
|
data[i] = float32(rand.NormFloat64()*r.mean + r.stdev)
|
||||||
}
|
}
|
||||||
|
|
||||||
retVal, err = ts.NewTensorFromData(data, dims)
|
newTs, err := ts.NewTensorFromData(data, dims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("randInit - InitTensor method call error: %v\n", err)
|
log.Fatalf("randInit - InitTensor method call error: %v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
retVal = newTs.MustTo(device, true)
|
||||||
|
|
||||||
return retVal
|
return retVal
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
20
nn/rnn.go
20
nn/rnn.go
|
@ -1,6 +1,8 @@
|
||||||
package nn
|
package nn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
"github.com/sugarme/gotch"
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
)
|
)
|
||||||
|
@ -199,7 +201,7 @@ type GRU struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGRU create a new GRU layer
|
// NewGRU create a new GRU layer
|
||||||
func NewGRU(vs *Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
|
func NewGRU(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
|
||||||
var numDirections int64 = 1
|
var numDirections int64 = 1
|
||||||
if cfg.Bidirectional {
|
if cfg.Bidirectional {
|
||||||
numDirections = 2
|
numDirections = 2
|
||||||
|
@ -210,11 +212,14 @@ func NewGRU(vs *Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
|
||||||
|
|
||||||
for i := 0; i < int(cfg.NumLayers); i++ {
|
for i := 0; i < int(cfg.NumLayers); i++ {
|
||||||
for n := 0; n < int(numDirections); n++ {
|
for n := 0; n < int(numDirections); n++ {
|
||||||
if i != 0 {
|
var inputDim int64
|
||||||
inDim = hiddenDim * numDirections
|
if i == 0 {
|
||||||
|
inputDim = inDim
|
||||||
|
} else {
|
||||||
|
inputDim = hiddenDim * numDirections
|
||||||
}
|
}
|
||||||
|
|
||||||
wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inDim})
|
wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inputDim})
|
||||||
wHh := vs.KaimingUniform("w_hh", []int64{gateDim, hiddenDim})
|
wHh := vs.KaimingUniform("w_hh", []int64{gateDim, hiddenDim})
|
||||||
bIh := vs.Zeros("b_ih", []int64{gateDim})
|
bIh := vs.Zeros("b_ih", []int64{gateDim})
|
||||||
bHh := vs.Zeros("b_hh", []int64{gateDim})
|
bHh := vs.Zeros("b_hh", []int64{gateDim})
|
||||||
|
@ -249,9 +254,12 @@ func (g GRU) ZeroState(batchDim int64) (retVal State) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
|
func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
|
||||||
ip := input.MustUnsqueeze(1, false)
|
unsqueezedInput := input.MustUnsqueeze(1, false)
|
||||||
|
|
||||||
output, state := g.SeqInit(ip, inState)
|
fmt.Printf("unsqueezed input shape: %v\n", unsqueezedInput.MustSize())
|
||||||
|
fmt.Printf("inState Ts size: %v\n", inState.(GRUState).Tensor.MustSize())
|
||||||
|
|
||||||
|
output, state := g.SeqInit(unsqueezedInput, inState)
|
||||||
|
|
||||||
// NOTE: though we won't use `output`, it is a Ctensor created in C land, so
|
// NOTE: though we won't use `output`, it is a Ctensor created in C land, so
|
||||||
// it should be cleaned up here to prevent memory hold-up.
|
// it should be cleaned up here to prevent memory hold-up.
|
||||||
|
|
|
@ -22,7 +22,7 @@ func gruTest(rnnConfig nn.RNNConfig, t *testing.T) {
|
||||||
vs := nn.NewVarStore(gotch.CPU)
|
vs := nn.NewVarStore(gotch.CPU)
|
||||||
path := vs.Root()
|
path := vs.Root()
|
||||||
|
|
||||||
gru := nn.NewGRU(&path, inputDim, outputDim, rnnConfig)
|
gru := nn.NewGRU(path, inputDim, outputDim, rnnConfig)
|
||||||
|
|
||||||
numDirections := int64(1)
|
numDirections := int64(1)
|
||||||
if rnnConfig.Bidirectional {
|
if rnnConfig.Bidirectional {
|
||||||
|
|
|
@ -297,6 +297,15 @@ func (ts Tensor) Device() (retVal gotch.Device, err error) {
|
||||||
return device.OfCInt(int32(cInt)), nil
|
return device.OfCInt(int32(cInt)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustDevice() (retVal gotch.Device) {
|
||||||
|
retVal, err := ts.Device()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* func (ts Tensor) Eq1(other Tensor, del bool) (retVal Tensor, err error) {
|
* func (ts Tensor) Eq1(other Tensor, del bool) (retVal Tensor, err error) {
|
||||||
*
|
*
|
||||||
|
|
Loading…
Reference in New Issue
Block a user