WIP(example/translation)

sugarme 2020-08-01 16:33:30 +10:00
parent a71479e0f5
commit b77fa54eb0
7 changed files with 619 additions and 10 deletions


@@ -0,0 +1,182 @@
package main
import (
"bufio"
"fmt"
"log"
"os"
"strings"
"unicode"
)
var Prefixes []string = []string{
"i am ",
"i m ",
"you are ",
"you re",
"he is ",
"he s ",
"she is ",
"she s ",
"we are ",
"we re ",
"they are ",
"they re ",
}
type Pair struct {
val1 string
val2 string
}
type Dataset struct {
inputLang Lang
outputLang Lang
pairs []Pair
}
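// normalize lowercases s, keeps letters and digits, puts a space before
// "!", "." and "?", and maps every other rune to a space,
// e.g. "Go, run!" -> "go  run !".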
func normalize(s string) (retVal string) {
lower := strings.ToLower(s)
/*
* // strip all spaces
* noSpaceStr := strings.Map(func(r rune) rune {
* if unicode.IsSpace(r) {
* return -1
* }
* return r
* }, lower)
*/
// add single space before "!", ".", "?"
var res []rune
for _, r := range lower {
switch {
case r == '!' || r == '.' || r == '?':
res = append(res, ' ', r)
case unicode.IsLetter(r), unicode.IsNumber(r):
res = append(res, r)
default:
res = append(res, ' ')
}
}
return string(res)
}
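// toIndexes looks up each space-separated token of s in lang, skipping
// unknown tokens, and appends the EOS index.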
func toIndexes(s string, lang Lang) (retVal []int) {
res := strings.Split(s, " ")
for _, l := range res {
idx := lang.GetIndex(l)
if idx >= 0 {
retVal = append(retVal, idx)
}
}
retVal = append(retVal, lang.EosToken())
return retVal
}
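// filterPrefix reports whether s starts with one of the Prefixes above.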
func filterPrefix(s string) (retVal bool) {
for _, prefix := range Prefixes {
if strings.HasPrefix(s, prefix) {
return true
}
}
return false
}
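// readPairs loads tab-separated sentence pairs from
// ../../data/translation/<ilang>-<olang>.txt, normalizes both sides, and keeps
// pairs where both sides are shorter than maxLength words and at least one
// side starts with a known prefix.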
func readPairs(ilang, olang string, maxLength int) (retVal []Pair) {
file, err := os.Open(fmt.Sprintf("../../data/translation/%v-%v.txt", ilang, olang))
if err != nil {
log.Fatal(err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if strings.Contains(line, "\t") {
// NOTE: assuming there's only 1 '\t'
pair := strings.Split(line, "\t")
lhs := normalize(pair[0])
rhs := normalize(pair[1])
if (len(strings.Split(lhs, " ")) < maxLength) && (len(strings.Split(rhs, " ")) < maxLength) && (filterPrefix(lhs) || filterPrefix(rhs)) {
retVal = append(retVal, Pair{lhs, rhs})
}
} else {
log.Fatalf("A line does not contain a single tab: %v\n", line)
}
}
if err := scanner.Err(); err != nil {
log.Fatal(err)
}
return retVal
}
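// newDataset reads the filtered pairs and builds the input and output
// vocabularies from them.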
func newDataset(ilang, olang string, maxLength int) (retVal Dataset) {
pairs := readPairs(ilang, olang, maxLength)
inputLang := NewLang(ilang)
outputLang := NewLang(olang)
for _, p := range pairs {
inputLang.AddSentence(p.val1)
outputLang.AddSentence(p.val2)
}
return Dataset{
inputLang: inputLang,
outputLang: outputLang,
pairs: pairs,
}
}
func (ds Dataset) InputLang() (retVal Lang) {
return ds.inputLang
}
func (ds Dataset) OutputLang() (retVal Lang) {
return ds.outputLang
}
func (ds Dataset) Reverse() (retVal Dataset) {
var rpairs []Pair
for _, p := range ds.pairs {
rpairs = append(rpairs, Pair{p.val2, p.val1})
}
return Dataset{
inputLang: ds.outputLang,
outputLang: ds.inputLang,
pairs: rpairs,
}
}
type Pairs struct {
Val1 []int
Val2 []int
}
func (ds Dataset) Pairs() (retVal []Pairs) {
for _, p := range ds.pairs {
val1 := toIndexes(p.val1, ds.inputLang)
val2 := toIndexes(p.val2, ds.outputLang)
retVal = append(retVal, Pairs{val1, val2})
}
return retVal
}
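For reference, a minimal sketch of how the helpers above compose (a hypothetical snippet, not part of this commit; it assumes the eng-fra data file is in place):

// normalize a raw line, then index it against the built vocabulary
s := normalize("I am cold!") // "i am cold !"
fmt.Println(filterPrefix(s)) // true: matches the "i am " prefix
ds := newDataset("eng", "fra", 10)
fmt.Println(toIndexes(s, ds.InputLang())) // word indexes, ending with EOS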


@@ -0,0 +1,92 @@
package main
import (
"strings"
)
const (
SosToken = "SOS"
EosToken = "EOS"
)
type IndexCount struct {
Index int
Count int
}
type Lang struct {
Name string
WordToIndexAndCount map[string]IndexCount
IndexToWord map[int]string
}
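// NewLang creates an empty vocabulary seeded with the SOS and EOS tokens
// (indexes 0 and 1).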
func NewLang(name string) (retVal Lang) {
lang := Lang{
Name: name,
WordToIndexAndCount: make(map[string]IndexCount),
IndexToWord: make(map[int]string),
}
lang.AddWord(SosToken)
lang.AddWord(EosToken)
return lang
}
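// AddWord registers word with the next free index, or increments its count if
// it is already known.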
func (l *Lang) AddWord(word string) {
if len(word) > 0 {
idxCount, ok := l.WordToIndexAndCount[word]
if !ok {
length := len(l.WordToIndexAndCount)
l.WordToIndexAndCount[word] = IndexCount{length, 1}
l.IndexToWord[length] = word
} else {
idxCount.Count += 1
l.WordToIndexAndCount[word] = idxCount
}
}
}
func (l *Lang) AddSentence(sentence string) {
words := strings.Split(sentence, " ")
for _, word := range words {
l.AddWord(word)
}
}
func (l *Lang) Len() (retVal int) {
return len(l.IndexToWord)
}
func (l *Lang) SosToken() (retVal int) {
return l.WordToIndexAndCount[SosToken].Index
}
func (l *Lang) EosToken() (retVal int) {
return l.WordToIndexAndCount[EosToken].Index
}
func (l *Lang) GetName() (retVal string) {
return l.Name
}
func (l *Lang) GetIndex(word string) (retVal int) {
idxCount, ok := l.WordToIndexAndCount[word]
if ok {
return idxCount.Index
} else {
return -1 // word does not exist in Lang
}
}
func (l *Lang) SeqToString(seq []int) (retVal string) {
words := make([]string, 0, len(seq))
for _, idx := range seq {
words = append(words, l.IndexToWord[idx])
}
return strings.Join(words, " ")
}
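A minimal sketch of the Lang round trip (a hypothetical snippet, not part of this commit):

l := NewLang("eng")
l.AddSentence("i am cold .")
fmt.Println(l.Len())                       // 6: SOS, EOS, "i", "am", "cold", "."
fmt.Println(l.GetIndex("cold"))            // 4
fmt.Println(l.SeqToString([]int{2, 3, 4})) // "i am cold"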

example/translation/main.go

@@ -0,0 +1,315 @@
/* Translation with a Sequence to Sequence Model and Attention.
This follows the lines of the PyTorch tutorial:
https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
And trains a Sequence to Sequence (seq2seq) model using attention to
perform translation between French and English.
The dataset can be downloaded from the following link:
https://download.pytorch.org/tutorial/data.zip
The eng-fra.txt file should be moved into the data directory.
*/
package main
import (
"fmt"
"log"
"math/rand"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
)
var (
MaxLength int64 = 10
LearningRate float64 = 0.001
Samples int64 = 100000
HiddenSize int64 = 256
)
type Encoder struct {
embedding nn.Embedding
gru nn.GRU
}
func newEncoder(vs nn.Path, inDim, hiddenDim int64) (retVal Encoder) {
gru := nn.NewGRU(vs, hiddenDim, hiddenDim, nn.DefaultRNNConfig())
embedding := nn.NewEmbedding(vs, inDim, hiddenDim, nn.DefaultEmbeddingConfig())
return Encoder{embedding, gru}
}
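// forward embeds a single token, reshapes it to [1, hiddenDim], and steps the
// GRU; it returns the reshaped embedding together with the new hidden state.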
func (e Encoder) forward(xs ts.Tensor, state nn.GRUState) (retTs ts.Tensor, retState nn.GRUState) {
retTs = e.embedding.Forward(xs).MustView([]int64{1, -1}, false)
retState = e.gru.Step(retTs, state).(nn.GRUState)
return retTs, retState
}
type Decoder struct {
device gotch.Device
embedding nn.Embedding
gru nn.GRU
attn nn.Linear
attnCombine nn.Linear
linear nn.Linear
}
func newDecoder(vs nn.Path, hiddenDim, outDim int64) (retVal Decoder) {
return Decoder{
device: vs.Device(),
embedding: nn.NewEmbedding(vs, outDim, hiddenDim, nn.DefaultEmbeddingConfig()),
gru: nn.NewGRU(vs, hiddenDim, hiddenDim, nn.DefaultRNNConfig()),
attn: nn.NewLinear(vs, hiddenDim*2, MaxLength, nn.DefaultLinearConfig()),
attnCombine: nn.NewLinear(vs, hiddenDim*2, hiddenDim, nn.DefaultLinearConfig()),
linear: nn.NewLinear(vs, hiddenDim, outDim, nn.DefaultLinearConfig()),
}
}
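// forward runs one attention-decoder step: embed the previous token, compute
// attention weights over the (zero-padded) encoder outputs, combine, and
// return log-probabilities over the output vocabulary plus the new GRU state.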
func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor, isTraining bool) (retTs ts.Tensor, retState nn.GRUState) {
forwardTsTmp := d.embedding.Forward(xs)
forwardTsTmp.MustDropout_(0.1, isTraining)
forwardTs := forwardTsTmp.MustView([]int64{1, -1}, true)
// NOTE. forwardTs shape: [1, 256] state [1, 1, 256]
// hence, just get state[0] of 3D tensor state
stateTs := state.Value()
state0 := stateTs.Idx([]ts.TensorIndexer{ts.NewSelect(0)})
catTs := ts.MustCat([]ts.Tensor{forwardTs, state0}, 1)
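// catTs shape: [1, hiddenDim*2]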
// NOTE. d.attn Ws shape : [512, 10]
appliedTs := catTs.Apply(d.attn)
attnWeights := appliedTs.MustUnsqueeze(0, true)
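// attnWeights shape: [1, 1, MaxLength]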
size3, err := encOutputs.Size3()
if err != nil {
log.Fatal(err)
}
sz1 := size3[0]
sz2 := size3[1]
sz3 := size3[2]
var encOutputsTs ts.Tensor
if sz2 == MaxLength {
encOutputsTs = encOutputs.MustShallowClone()
} else {
shape := []int64{sz1, MaxLength - sz2, sz3}
zerosTs := ts.MustZeros(shape, gotch.Float, d.device)
encOutputsTs = ts.MustCat([]ts.Tensor{encOutputs, zerosTs}, 1)
}
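// bmm: [1, 1, MaxLength] x [1, MaxLength, hiddenDim] -> [1, 1, hiddenDim],
// then squeeze dim 1 -> [1, hiddenDim]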
attnApplied := attnWeights.MustBmm(encOutputsTs, true).MustSqueeze1(1, true)
fmt.Printf("attnApplied shape: %v\n", attnApplied.MustSize())
fmt.Printf("xs shape: %v\n", forwardTs.MustSize())
xsTs := ts.MustCat([]ts.Tensor{forwardTs, attnApplied}, 1).Apply(d.attnCombine).MustRelu(true)
retState = d.gru.Step(xsTs, state).(nn.GRUState)
retTs = d.linear.Forward(retState.Value()).MustLogSoftmax(-1, gotch.Float, true)
return retTs, retState
}
type Model struct {
encoder Encoder
decoder Decoder
decoderStart ts.Tensor
decoderEos int64
device gotch.Device
}
func newModel(vs nn.Path, ilang Lang, olang Lang, hiddenDim int64) (retVal Model) {
return Model{
encoder: newEncoder(vs.Sub("enc"), int64(ilang.Len()), hiddenDim),
decoder: newDecoder(vs.Sub("dec"), hiddenDim, int64(olang.Len())),
decoderStart: ts.MustOfSlice([]int64{int64(olang.SosToken())}).MustTo(vs.Device(), true),
decoderEos: int64(olang.EosToken()),
device: vs.Device(),
}
}
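// trainLoss encodes the input sequence one token at a time, then decodes
// against the target, feeding each prediction back as the next decoder input
// and accumulating the NLL loss.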
func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
state := m.encoder.gru.ZeroState(1)
fmt.Printf("state shape: %v\n", state.(nn.GRUState).Value().MustSize())
var encOutputs []ts.Tensor
for _, v := range input {
s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
encOutputs = append(encOutputs, outTs)
state.(nn.GRUState).Tensor.MustDrop()
state = outState
}
stackTs := ts.MustStack(encOutputs, 1)
for _, t := range encOutputs {
t.MustDrop()
}
// TODO: should we implement random teacher forcing here, as in the PyTorch tutorial?
loss := ts.TensorFrom([]float32{0.0}).MustTo(m.device, true)
prev := m.decoderStart.MustShallowClone()
for _, s := range target {
outTs, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true)
state.(nn.GRUState).Tensor.MustDrop()
state = outState
targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true)
fmt.Printf("targetTs shape: %v\n", targetTs.MustSize())
fmt.Printf("outTs shape: %v\n", outTs.MustSize())
// TODO: fix the error: input tensor should be 1D or 2D at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:35
ws := ts.NewTensor()
currLoss := outTs.MustNllLoss(targetTs, ws, int64(1), -100, false)
loss.MustAdd_(currLoss)
currLoss.MustDrop()
_, output := outTs.MustTopK(1, -1, true, true)
if m.decoderEos == output.Int64Values()[0] {
break
}
prev.MustDrop()
prev = output
}
return loss
}
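// predict greedily decodes up to MaxLength tokens, stopping early at EOS.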
func (m *Model) predict(input []int) (retVal []int) {
state := m.encoder.gru.ZeroState(1)
var encOutputs []ts.Tensor
for _, v := range input {
s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
encOutputs = append(encOutputs, outTs)
state.(nn.GRUState).Tensor.MustDrop()
state = outState
}
stackTs := ts.MustStack(encOutputs, 1)
for _, t := range encOutputs {
t.MustDrop()
}
prev := m.decoderStart.MustShallowClone()
var outputSeq []int
for i := 0; i < int(MaxLength); i++ {
outTs, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true)
_, output := outTs.MustTopK(1, -1, true, true)
outputVal := output.Int64Values()[0]
outputSeq = append(outputSeq, int(outputVal))
if m.decoderEos == outputVal {
break
}
state.(nn.GRUState).Tensor.MustDrop()
state = outState
prev.MustDrop()
prev = output
}
return outputSeq
}
type LossStats struct {
totalLoss float64
samples int
}
func newLossStats() (retVal LossStats) {
return LossStats{
totalLoss: 0.0,
samples: 0,
}
}
func (ls *LossStats) update(loss float64) {
ls.totalLoss += loss
ls.samples += 1
}
func (ls *LossStats) avgAndReset() (retVal float64) {
avg := ls.totalLoss / float64(ls.samples)
ls.totalLoss = 0.0
ls.samples = 0
return avg
}
func main() {
dataset := newDataset("eng", "fra", int(MaxLength)).Reverse()
ilang := dataset.InputLang()
olang := dataset.OutputLang()
pairs := dataset.Pairs()
fmt.Printf("Input: %v %v words\n", ilang.GetName(), ilang.Len())
fmt.Printf("Output: %v %v words\n", olang.GetName(), olang.Len())
fmt.Printf("Pairs: %v\n", len(pairs))
// TODO: should we seed the random pair sampling here?
cuda := gotch.NewCuda()
device := cuda.CudaIfAvailable()
vs := nn.NewVarStore(device)
model := newModel(vs.Root(), ilang, olang, HiddenSize)
optConfig := nn.DefaultAdamConfig()
opt, err := optConfig.Build(vs, LearningRate)
if err != nil {
log.Fatal(err)
}
lossStats := newLossStats()
for i := 1; i < int(Samples); i++ {
// randomly choose a pair
idx := rand.Intn(len(pairs))
pair := pairs[idx]
input := pair.Val1
target := pair.Val2
loss := model.trainLoss(input, target)
panic("reached")
opt.BackwardStep(loss)
lossStats.update(loss.Float64Values()[0] / float64(len(target)))
loss.MustDrop()
if i%1000 == 0 {
fmt.Printf("%v %v\n", i, lossStats.avgAndReset())
for predIdx := 1; predIdx <= 5; predIdx++ {
idx := rand.Intn(len(pairs))
in := pairs[idx].Val1
tgt := pairs[idx].Val2
predict := model.predict(in)
fmt.Printf("input: %v\n", ilang.SeqToString(in))
fmt.Printf("target: %v\n", olang.SeqToString(tgt))
fmt.Printf("ouput: %v\n", olang.SeqToString(predict))
}
}
}
}


@@ -75,16 +75,19 @@ func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
var err error
rand.Seed(86)
- data := make([]float64, ts.FlattenDim(dims))
+ data := make([]float32, ts.FlattenDim(dims))
for i := range data {
- data[i] = rand.NormFloat64()*r.mean + r.stdev
+ // NOTE. tensor will have DType = Float (float32); scale by stdev, then shift by mean
+ data[i] = float32(rand.NormFloat64()*r.stdev + r.mean)
}
- retVal, err = ts.NewTensorFromData(data, dims)
+ newTs, err := ts.NewTensorFromData(data, dims)
if err != nil {
log.Fatalf("randInit - InitTensor method call error: %v\n", err)
}
+ retVal = newTs.MustTo(device, true)
return retVal
}


@@ -1,6 +1,8 @@
package nn
import (
+ "fmt"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
@@ -199,7 +201,7 @@ type GRU struct {
}
// NewGRU creates a new GRU layer
- func NewGRU(vs *Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
+ func NewGRU(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
var numDirections int64 = 1
if cfg.Bidirectional {
numDirections = 2
@@ -210,11 +212,14 @@ func NewGRU(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
for i := 0; i < int(cfg.NumLayers); i++ {
for n := 0; n < int(numDirections); n++ {
- if i != 0 {
- inDim = hiddenDim * numDirections
+ var inputDim int64
+ if i == 0 {
+ inputDim = inDim
+ } else {
+ inputDim = hiddenDim * numDirections
}
- wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inDim})
+ wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inputDim})
wHh := vs.KaimingUniform("w_hh", []int64{gateDim, hiddenDim})
bIh := vs.Zeros("b_ih", []int64{gateDim})
bHh := vs.Zeros("b_hh", []int64{gateDim})
@@ -249,9 +254,12 @@ func (g GRU) ZeroState(batchDim int64) (retVal State) {
}
func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
- ip := input.MustUnsqueeze(1, false)
+ unsqueezedInput := input.MustUnsqueeze(1, false)
- output, state := g.SeqInit(ip, inState)
+ fmt.Printf("unsqueezed input shape: %v\n", unsqueezedInput.MustSize())
+ fmt.Printf("inState Ts size: %v\n", inState.(GRUState).Tensor.MustSize())
+ output, state := g.SeqInit(unsqueezedInput, inState)
// NOTE: though we won't use `output`, it is a Ctensor created in C land, so
// it should be cleaned up here to prevent memory hold-up.


@@ -22,7 +22,7 @@ func gruTest(rnnConfig nn.RNNConfig, t *testing.T) {
vs := nn.NewVarStore(gotch.CPU)
path := vs.Root()
- gru := nn.NewGRU(&path, inputDim, outputDim, rnnConfig)
+ gru := nn.NewGRU(path, inputDim, outputDim, rnnConfig)
numDirections := int64(1)
if rnnConfig.Bidirectional {


@@ -297,6 +297,15 @@ func (ts Tensor) Device() (retVal gotch.Device, err error) {
return device.OfCInt(int32(cInt)), nil
}
+ func (ts Tensor) MustDevice() (retVal gotch.Device) {
+ retVal, err := ts.Device()
+ if err != nil {
+ log.Fatal(err)
+ }
+ return retVal
+ }
/*
* func (ts Tensor) Eq1(other Tensor, del bool) (retVal Tensor, err error) {
*