diff --git a/example/translation/dataset.go b/example/translation/dataset.go
new file mode 100644
index 0000000..5873f28
--- /dev/null
+++ b/example/translation/dataset.go
@@ -0,0 +1,182 @@
+package main
+
+import (
+    "bufio"
+    "fmt"
+    "log"
+    "os"
+    "strings"
+    "unicode"
+)
+
+// Prefixes filters the corpus down to sentences starting with these
+// phrases, as in the PyTorch seq2seq tutorial.
+var Prefixes []string = []string{
+    "i am ",
+    "i m ",
+    "you are ",
+    "you re ",
+    "he is ",
+    "he s ",
+    "she is ",
+    "she s ",
+    "we are ",
+    "we re ",
+    "they are ",
+    "they re ",
+}
+
+type Pair struct {
+    val1 string
+    val2 string
+}
+
+type Dataset struct {
+    inputLang  Lang
+    outputLang Lang
+    pairs      []Pair
+}
+
+// normalize lower-cases s, adds a single space before "!", "." and "?",
+// keeps letters and digits, and replaces every other rune with a space.
+func normalize(s string) (retVal string) {
+    lower := strings.ToLower(s)
+
+    var res []rune
+    for _, r := range lower {
+        switch {
+        case r == '!', r == '.', r == '?':
+            res = append(res, ' ', r)
+        case unicode.IsLetter(r), unicode.IsNumber(r):
+            res = append(res, r)
+        default:
+            res = append(res, ' ')
+        }
+    }
+
+    return string(res)
+}
+
+func toIndexes(s string, lang Lang) (retVal []int) {
+    res := strings.Split(s, " ")
+
+    for _, l := range res {
+        idx := lang.GetIndex(l)
+        if idx >= 0 {
+            retVal = append(retVal, idx)
+        }
+    }
+
+    retVal = append(retVal, lang.EosToken())
+
+    return retVal
+}
+
+func filterPrefix(s string) (retVal bool) {
+    for _, prefix := range Prefixes {
+        if strings.HasPrefix(s, prefix) {
+            return true
+        }
+    }
+
+    return false
+}
+
+func readPairs(ilang, olang string, maxLength int) (retVal []Pair) {
+    file, err := os.Open(fmt.Sprintf("../../data/translation/%v-%v.txt", ilang, olang))
+    if err != nil {
+        log.Fatal(err)
+    }
+    defer file.Close()
+    scanner := bufio.NewScanner(file)
+
+    for scanner.Scan() {
+        line := scanner.Text()
+
+        if strings.Contains(line, "\t") {
+            // NOTE: assumes each line holds exactly one tab-separated pair
+            pair := strings.Split(line, "\t")
+            lhs := normalize(pair[0])
+            rhs := normalize(pair[1])
+
+            if (len(strings.Split(lhs, " ")) < maxLength) && (len(strings.Split(rhs, " ")) < maxLength) && (filterPrefix(lhs) || filterPrefix(rhs)) {
+                retVal = append(retVal, Pair{lhs, rhs})
+            }
+        } else {
+            log.Fatalf("A line does not contain a single tab: %v\n", line)
+        }
+    }
+
+    if err := scanner.Err(); err != nil {
+        log.Fatal(err)
+    }
+
+    return retVal
+}
+
+func newDataset(ilang, olang string, maxLength int) (retVal Dataset) {
+    pairs := readPairs(ilang, olang, maxLength)
+    inputLang := NewLang(ilang)
+    outputLang := NewLang(olang)
+
+    for _, p := range pairs {
+        inputLang.AddSentence(p.val1)
+        outputLang.AddSentence(p.val2)
+    }
+
+    return Dataset{
+        inputLang:  inputLang,
+        outputLang: outputLang,
+        pairs:      pairs,
+    }
+}
+
+func (ds Dataset) InputLang() (retVal Lang) {
+    return ds.inputLang
+}
+
+func (ds Dataset) OutputLang() (retVal Lang) {
+    return ds.outputLang
+}
+
+func (ds Dataset) Reverse() (retVal Dataset) {
+    var rpairs []Pair
+    for _, p := range ds.pairs {
+        rpairs = append(rpairs, Pair{p.val2, p.val1})
+    }
+    return Dataset{
+        inputLang:  ds.outputLang,
+        outputLang: ds.inputLang,
+        pairs:      rpairs,
+    }
+}
+
+type Pairs struct {
+    Val1 []int
+    Val2 []int
+}
+
+func (ds Dataset) Pairs() (retVal []Pairs) {
+    for _, p := range ds.pairs {
+        val1 := toIndexes(p.val1, ds.inputLang)
+        val2 := toIndexes(p.val2, ds.outputLang)
+
+        retVal = append(retVal, Pairs{val1, val2})
+    }
+
+    return retVal
+}
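To make the preprocessing concrete, here is a quick sketch of what the helpers above produce (illustrative only, assuming normalize and filterPrefix as defined in dataset.go):

    s := normalize("They're cooking!")
    fmt.Println(s)               // "they re cooking !" (the apostrophe becomes a space)
    fmt.Println(filterPrefix(s)) // true: s starts with the "they re " prefix

The stripped apostrophes are why the Prefixes table is written as "they re ", "i m ", and so on.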
diff --git a/example/translation/lang.go b/example/translation/lang.go
new file mode 100644
index 0000000..b427ef1
--- /dev/null
+++ b/example/translation/lang.go
@@ -0,0 +1,92 @@
+package main
+
+import (
+    "strings"
+)
+
+const (
+    SosToken = "SOS"
+    EosToken = "EOS"
+)
+
+type IndexCount struct {
+    Index int
+    Count int
+}
+
+type Lang struct {
+    Name                string
+    WordToIndexAndCount map[string]IndexCount
+    IndexToWord         map[int]string
+}
+
+func NewLang(name string) (retVal Lang) {
+    lang := Lang{
+        Name:                name,
+        WordToIndexAndCount: make(map[string]IndexCount),
+        IndexToWord:         make(map[int]string),
+    }
+
+    lang.AddWord(SosToken)
+    lang.AddWord(EosToken)
+
+    return lang
+}
+
+func (l *Lang) AddWord(word string) {
+    if len(word) > 0 {
+        idxCount, ok := l.WordToIndexAndCount[word]
+        if !ok {
+            // A new word takes the next free index with a count of 1.
+            length := len(l.WordToIndexAndCount)
+            l.WordToIndexAndCount[word] = IndexCount{length, 1}
+            l.IndexToWord[length] = word
+        } else {
+            idxCount.Count++
+            l.WordToIndexAndCount[word] = idxCount
+        }
+    }
+}
+
+func (l *Lang) AddSentence(sentence string) {
+    words := strings.Split(sentence, " ")
+    for _, word := range words {
+        l.AddWord(word)
+    }
+}
+
+func (l *Lang) Len() (retVal int) {
+    return len(l.IndexToWord)
+}
+
+func (l *Lang) SosToken() (retVal int) {
+    return l.WordToIndexAndCount[SosToken].Index
+}
+
+func (l *Lang) EosToken() (retVal int) {
+    return l.WordToIndexAndCount[EosToken].Index
+}
+
+func (l *Lang) GetName() (retVal string) {
+    return l.Name
+}
+
+func (l *Lang) GetIndex(word string) (retVal int) {
+    idxCount, ok := l.WordToIndexAndCount[word]
+    if ok {
+        return idxCount.Index
+    }
+
+    return -1 // word does not exist in Lang
+}
+
+func (l *Lang) SeqToString(seq []int) (retVal string) {
+    words := make([]string, 0, len(seq))
+
+    for _, idx := range seq {
+        w := l.IndexToWord[idx]
+        words = append(words, w)
+    }
+
+    // Join with spaces so the decoded sequence reads as a sentence.
+    return strings.Join(words, " ")
+}
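A short sketch of how Lang hands out indexes, assuming the type defined above; SOS and EOS always occupy indexes 0 and 1 because NewLang registers them first:

    l := NewLang("eng")
    l.AddSentence("they re cooking !")

    fmt.Println(l.Len())                    // 6: SOS, EOS, "they", "re", "cooking", "!"
    fmt.Println(l.GetIndex("cooking"))      // 4
    fmt.Println(l.GetIndex("baking"))       // -1: unknown word
    fmt.Println(l.SeqToString([]int{2, 3})) // "they re"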
diff --git a/example/translation/main.go b/example/translation/main.go
new file mode 100644
index 0000000..2930d92
--- /dev/null
+++ b/example/translation/main.go
@@ -0,0 +1,315 @@
+/* Translation with a Sequence to Sequence Model and Attention.
+
+   This follows the PyTorch tutorial:
+   https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
+   and trains a sequence-to-sequence (seq2seq) model using attention to
+   perform translation between French and English.
+
+   The dataset can be downloaded from the following link:
+   https://download.pytorch.org/tutorial/data.zip
+   The eng-fra.txt file should be moved into the data directory.
+*/
+
+package main
+
+import (
+    "fmt"
+    "log"
+    "math/rand"
+
+    "github.com/sugarme/gotch"
+    "github.com/sugarme/gotch/nn"
+    ts "github.com/sugarme/gotch/tensor"
+)
+
+var (
+    MaxLength    int64   = 10
+    LearningRate float64 = 0.001
+    Samples      int64   = 100000
+    HiddenSize   int64   = 256
+)
+
+type Encoder struct {
+    embedding nn.Embedding
+    gru       nn.GRU
+}
+
+func newEncoder(vs nn.Path, inDim, hiddenDim int64) (retVal Encoder) {
+    // The embedding maps token indexes to hiddenDim-sized vectors, so the
+    // GRU input dimension is hiddenDim as well.
+    gru := nn.NewGRU(vs, hiddenDim, hiddenDim, nn.DefaultRNNConfig())
+
+    embedding := nn.NewEmbedding(vs, inDim, hiddenDim, nn.DefaultEmbeddingConfig())
+
+    return Encoder{embedding, gru}
+}
+
+func (e Encoder) forward(xs ts.Tensor, state nn.GRUState) (retTs ts.Tensor, retState nn.GRUState) {
+    retTs = e.embedding.Forward(xs).MustView([]int64{1, -1}, false)
+    retState = e.gru.Step(retTs, state).(nn.GRUState)
+
+    return retTs, retState
+}
+
+type Decoder struct {
+    device      gotch.Device
+    embedding   nn.Embedding
+    gru         nn.GRU
+    attn        nn.Linear
+    attnCombine nn.Linear
+    linear      nn.Linear
+}
+
+func newDecoder(vs nn.Path, hiddenDim, outDim int64) (retVal Decoder) {
+    return Decoder{
+        device:      vs.Device(),
+        embedding:   nn.NewEmbedding(vs, outDim, hiddenDim, nn.DefaultEmbeddingConfig()),
+        gru:         nn.NewGRU(vs, hiddenDim, hiddenDim, nn.DefaultRNNConfig()),
+        attn:        nn.NewLinear(vs, hiddenDim*2, MaxLength, nn.DefaultLinearConfig()),
+        attnCombine: nn.NewLinear(vs, hiddenDim*2, hiddenDim, nn.DefaultLinearConfig()),
+        linear:      nn.NewLinear(vs, hiddenDim, outDim, nn.DefaultLinearConfig()),
+    }
+}
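Before reading Decoder.forward below, it may help to trace the attention shapes; a sketch with HiddenSize = 256 and MaxLength = 10, using the names from the function body. Note that the PyTorch tutorial softmax-normalizes the attention weights before the bmm, whereas this port feeds d.attn's raw output straight in:

    // forwardTs (embedded input)            : [1, 256]
    // state0 (first slice of the GRU state) : [1, 256]
    // catTs = cat(forwardTs, state0)        : [1, 512]
    // attnWeights = attn(catTs).unsqueeze(0): [1, 1, 10]
    // encOutputsTs (zero-padded)            : [1, 10, 256]
    // attnApplied = bmm(...).squeeze(1)     : [1, 256], the attention-weighted context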
+func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor, isTraining bool) (retTs ts.Tensor, retState nn.GRUState) {
+    forwardTsTmp := d.embedding.Forward(xs)
+    forwardTsTmp.MustDropout_(0.1, isTraining)
+    forwardTs := forwardTsTmp.MustView([]int64{1, -1}, true)
+
+    // NOTE: forwardTs shape is [1, 256] and state shape is [1, 1, 256],
+    // hence just take state[0] of the 3D state tensor.
+    stateTs := state.Value()
+    state0 := stateTs.Idx([]ts.TensorIndexer{ts.NewSelect(0)})
+
+    catTs := ts.MustCat([]ts.Tensor{forwardTs, state0}, 1)
+
+    // NOTE: d.attn weight shape is [512, 10]
+    appliedTs := catTs.Apply(d.attn)
+    attnWeights := appliedTs.MustUnsqueeze(0, true)
+
+    size3, err := encOutputs.Size3()
+    if err != nil {
+        log.Fatal(err)
+    }
+
+    sz1 := size3[0]
+    sz2 := size3[1]
+    sz3 := size3[2]
+
+    // Zero-pad the encoder outputs to MaxLength steps so the
+    // [1, 1, MaxLength] attention weights can be bmm-ed against them.
+    var encOutputsTs ts.Tensor
+    if sz2 == MaxLength {
+        encOutputsTs = encOutputs.MustShallowClone()
+    } else {
+        shape := []int64{sz1, MaxLength - sz2, sz3}
+        zerosTs := ts.MustZeros(shape, gotch.Float, d.device)
+        encOutputsTs = ts.MustCat([]ts.Tensor{encOutputs, zerosTs}, 1)
+    }
+
+    attnApplied := attnWeights.MustBmm(encOutputsTs, true).MustSqueeze1(1, true)
+
+    xsTs := ts.MustCat([]ts.Tensor{forwardTs, attnApplied}, 1).Apply(d.attnCombine).MustRelu(true)
+
+    retState = d.gru.Step(xsTs, state).(nn.GRUState)
+
+    // Squeeze the leading dimension so the log-probabilities are 2D
+    // ([1, outDim]), which is what NllLoss expects.
+    retTs = d.linear.Forward(retState.Value()).MustSqueeze1(0, true).MustLogSoftmax(-1, gotch.Float, true)
+
+    return retTs, retState
+}
+
+type Model struct {
+    encoder      Encoder
+    decoder      Decoder
+    decoderStart ts.Tensor
+    decoderEos   int64
+    device       gotch.Device
+}
+
+func newModel(vs nn.Path, ilang Lang, olang Lang, hiddenDim int64) (retVal Model) {
+    return Model{
+        encoder:      newEncoder(vs.Sub("enc"), int64(ilang.Len()), hiddenDim),
+        decoder:      newDecoder(vs.Sub("dec"), hiddenDim, int64(olang.Len())),
+        decoderStart: ts.MustOfSlice([]int64{int64(olang.SosToken())}).MustTo(vs.Device(), true),
+        decoderEos:   int64(olang.EosToken()),
+        device:       vs.Device(),
+    }
+}
+
+func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
+    state := m.encoder.gru.ZeroState(1)
+    var encOutputs []ts.Tensor
+
+    for _, v := range input {
+        s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
+
+        outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
+
+        encOutputs = append(encOutputs, outTs)
+        state.(nn.GRUState).Tensor.MustDrop()
+        state = outState
+    }
+
+    stackTs := ts.MustStack(encOutputs, 1)
+    for _, t := range encOutputs {
+        t.MustDrop()
+    }
+
+    // TODO: add random teacher forcing here, as in the reference tutorial?
+    loss := ts.TensorFrom([]float32{0.0}).MustTo(m.device, true)
+    prev := m.decoderStart.MustShallowClone()
+
+    for _, s := range target {
+        outTs, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true)
+
+        state.(nn.GRUState).Tensor.MustDrop()
+        state = outState
+
+        targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true)
+
+        ws := ts.NewTensor()
+        currLoss := outTs.MustNllLoss(targetTs, ws, int64(1), -100, false)
+
+        loss.MustAdd_(currLoss)
+        currLoss.MustDrop()
+
+        _, output := outTs.MustTopK(1, -1, true, true)
+        outTs.MustDrop()
+
+        prev.MustDrop()
+        prev = output
+
+        // Stop once the decoder emits the EOS token.
+        if m.decoderEos == output.Int64Values()[0] {
+            break
+        }
+    }
+
+    return loss
+}
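The TODO in trainLoss above presumably refers to the tutorial's random teacher forcing, where the ground-truth token sometimes replaces the model's own prediction as the next decoder input. A minimal sketch of what that could look like here; nextDecoderInput is a hypothetical helper, and 0.5 is the ratio the PyTorch tutorial uses:

    // nextDecoderInput returns the next decoder input: with probability
    // teacherForcingRatio the ground-truth token s, otherwise the decoder's
    // own argmax prediction output.
    func nextDecoderInput(output ts.Tensor, s int, device gotch.Device) ts.Tensor {
        const teacherForcingRatio = 0.5
        if rand.Float64() < teacherForcingRatio {
            output.MustDrop()
            return ts.MustOfSlice([]int64{int64(s)}).MustTo(device, true)
        }
        return output
    }

In the loop body, prev = output would then become prev = nextDecoderInput(output, s, m.device).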
+func (m *Model) predict(input []int) (retVal []int) {
+    state := m.encoder.gru.ZeroState(1)
+    var encOutputs []ts.Tensor
+
+    for _, v := range input {
+        s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
+        outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
+
+        encOutputs = append(encOutputs, outTs)
+        state.(nn.GRUState).Tensor.MustDrop()
+        state = outState
+    }
+
+    stackTs := ts.MustStack(encOutputs, 1)
+    for _, t := range encOutputs {
+        t.MustDrop()
+    }
+
+    prev := m.decoderStart.MustShallowClone()
+    var outputSeq []int
+
+    for i := 0; i < int(MaxLength); i++ {
+        // NOTE: isTraining=false at inference time, so dropout is disabled.
+        outTs, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, false)
+        _, output := outTs.MustTopK(1, -1, true, true)
+        outTs.MustDrop()
+        outputVal := output.Int64Values()[0]
+        outputSeq = append(outputSeq, int(outputVal))
+
+        state.(nn.GRUState).Tensor.MustDrop()
+        state = outState
+        prev.MustDrop()
+        prev = output
+
+        // Stop once the decoder emits the EOS token.
+        if m.decoderEos == outputVal {
+            break
+        }
+    }
+
+    return outputSeq
+}
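A note on the explicit drops in these loops: gotch tensors wrap C-allocated memory, so intermediates are freed by hand rather than left to the Go garbage collector. The trailing bool (del) on most Must* calls drops the receiver once the op has run, and MustDrop frees a tensor directly. A tiny illustration, using only calls that already appear in this example:

    a := ts.TensorFrom([]float32{1, 2, 3})
    b := a.MustView([]int64{3, 1}, true) // del=true: a is freed by this call
    fmt.Printf("%v\n", b.MustSize())     // [3 1]
    b.MustDrop()                         // free b explicitly when done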
+type LossStats struct {
+    totalLoss float64
+    samples   int
+}
+
+func newLossStats() (retVal LossStats) {
+    return LossStats{
+        totalLoss: 0.0,
+        samples:   0,
+    }
+}
+
+func (ls *LossStats) update(loss float64) {
+    ls.totalLoss += loss
+    ls.samples++
+}
+
+func (ls *LossStats) avgAndReset() (retVal float64) {
+    avg := ls.totalLoss / float64(ls.samples)
+    ls.totalLoss = 0.0
+    ls.samples = 0
+    return avg
+}
+
+func main() {
+    dataset := newDataset("eng", "fra", int(MaxLength)).Reverse()
+
+    ilang := dataset.InputLang()
+    olang := dataset.OutputLang()
+    pairs := dataset.Pairs()
+
+    fmt.Printf("Input: %v %v words\n", ilang.GetName(), ilang.Len())
+    fmt.Printf("Output: %v %v words\n", olang.GetName(), olang.Len())
+    fmt.Printf("Pairs: %v\n", len(pairs))
+
+    // TODO: should we set up a seeded RNG here?
+
+    cuda := gotch.NewCuda()
+    device := cuda.CudaIfAvailable()
+
+    vs := nn.NewVarStore(device)
+
+    model := newModel(vs.Root(), ilang, olang, HiddenSize)
+
+    optConfig := nn.DefaultAdamConfig()
+    opt, err := optConfig.Build(vs, LearningRate)
+    if err != nil {
+        log.Fatal(err)
+    }
+
+    lossStats := newLossStats()
+
+    for i := 1; i < int(Samples); i++ {
+        // randomly choose a training pair
+        idx := rand.Intn(len(pairs))
+        pair := pairs[idx]
+        input := pair.Val1
+        target := pair.Val2
+        loss := model.trainLoss(input, target)
+        opt.BackwardStep(loss)
+        lossStats.update(loss.Float64Values()[0] / float64(len(target)))
+        loss.MustDrop()
+
+        if i%1000 == 0 {
+            fmt.Printf("%v %v\n", i, lossStats.avgAndReset())
+            for predIdx := 1; predIdx <= 5; predIdx++ {
+                idx := rand.Intn(len(pairs))
+                in := pairs[idx].Val1
+                tgt := pairs[idx].Val2
+                predict := model.predict(in)
+
+                fmt.Printf("input: %v\n", ilang.SeqToString(in))
+                fmt.Printf("target: %v\n", olang.SeqToString(tgt))
+                fmt.Printf("output: %v\n", olang.SeqToString(predict))
+            }
+        }
+    }
+}
diff --git a/nn/init.go b/nn/init.go
index dc796a2..b718b62 100644
--- a/nn/init.go
+++ b/nn/init.go
@@ -75,16 +75,19 @@ func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tens
     var err error
     rand.Seed(86)
 
-    data := make([]float64, ts.FlattenDim(dims))
+    data := make([]float32, ts.FlattenDim(dims))
     for i := range data {
-        data[i] = rand.NormFloat64()*r.mean + r.stdev
+        // NOTE: the tensor will have DType = Float (float32). Scale by the
+        // stdev and shift by the mean: sample = NormFloat64()*stdev + mean.
+        data[i] = float32(rand.NormFloat64()*r.stdev + r.mean)
     }
 
-    retVal, err = ts.NewTensorFromData(data, dims)
+    newTs, err := ts.NewTensorFromData(data, dims)
     if err != nil {
         log.Fatalf("randInit - InitTensor method call error: %v\n", err)
     }
 
+    retVal = newTs.MustTo(device, true)
+
     return retVal
 }
diff --git a/nn/rnn.go b/nn/rnn.go
index 358bf84..fb2873b 100644
--- a/nn/rnn.go
+++ b/nn/rnn.go
@@ -199,7 +199,7 @@ type GRU struct {
 }
 
 // NewGRU create a new GRU layer
-func NewGRU(vs *Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
+func NewGRU(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
     var numDirections int64 = 1
     if cfg.Bidirectional {
         numDirections = 2
@@ -210,11 +210,16 @@ func NewGRU(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
 
     for i := 0; i < int(cfg.NumLayers); i++ {
         for n := 0; n < int(numDirections); n++ {
-            if i != 0 {
-                inDim = hiddenDim * numDirections
+            // The first layer takes inDim features; deeper layers take the
+            // direction-stacked hidden outputs of the previous layer.
+            var inputDim int64
+            if i == 0 {
+                inputDim = inDim
+            } else {
+                inputDim = hiddenDim * numDirections
             }
 
-            wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inDim})
+            wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inputDim})
             wHh := vs.KaimingUniform("w_hh", []int64{gateDim, hiddenDim})
             bIh := vs.Zeros("b_ih", []int64{gateDim})
             bHh := vs.Zeros("b_hh", []int64{gateDim})
@@ -249,9 +254,9 @@ func (g GRU) ZeroState(batchDim int64) (retVal State) {
 }
 
 func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
-    ip := input.MustUnsqueeze(1, false)
+    unsqueezedInput := input.MustUnsqueeze(1, false)
 
-    output, state := g.SeqInit(ip, inState)
+    output, state := g.SeqInit(unsqueezedInput, inState)
 
     // NOTE: though we won't use `output`, it is a Ctensor created in C land, so
     // it should be cleaned up here to prevent memory hold-up.
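Since NewGRU now takes a Path by value, call sites can pass vs.Root() directly. A minimal sketch of single-step use, mirroring how the encoder above drives the GRU (shapes per the Step implementation):

    vs := nn.NewVarStore(gotch.CPU)
    gru := nn.NewGRU(vs.Root(), 256, 256, nn.DefaultRNNConfig())

    state := gru.ZeroState(1)                                  // [numLayers=1, batch=1, 256]
    x := ts.MustZeros([]int64{1, 256}, gotch.Float, gotch.CPU) // one step of input
    state = gru.Step(x, state)                                 // Step unsqueezes x to [1, 1, 256]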
diff --git a/nn/rnn_test.go b/nn/rnn_test.go index 9296b2e..79323b1 100644 --- a/nn/rnn_test.go +++ b/nn/rnn_test.go @@ -22,7 +22,7 @@ func gruTest(rnnConfig nn.RNNConfig, t *testing.T) { vs := nn.NewVarStore(gotch.CPU) path := vs.Root() - gru := nn.NewGRU(&path, inputDim, outputDim, rnnConfig) + gru := nn.NewGRU(path, inputDim, outputDim, rnnConfig) numDirections := int64(1) if rnnConfig.Bidirectional { diff --git a/tensor/tensor.go b/tensor/tensor.go index e0d755c..3ae63bf 100644 --- a/tensor/tensor.go +++ b/tensor/tensor.go @@ -297,6 +297,15 @@ func (ts Tensor) Device() (retVal gotch.Device, err error) { return device.OfCInt(int32(cInt)), nil } +func (ts Tensor) MustDevice() (retVal gotch.Device) { + retVal, err := ts.Device() + if err != nil { + log.Fatal(err) + } + + return retVal +} + /* * func (ts Tensor) Eq1(other Tensor, del bool) (retVal Tensor, err error) { *