183 lines
3.3 KiB
Go
183 lines
3.3 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"strings"
|
|
"unicode"
|
|
)
|
|
|
|
// Prefixes lists the lower-case sentence openers accepted by filterPrefix.
//
// The entries are compared against normalized text, in which an apostrophe
// has already been replaced by a space (so "i'm" shows up as "i m"). Every
// entry ends with a space so that it only matches whole leading words:
// "you re " matches "you re late" but not "you remember".
var Prefixes = []string{
	"i am ",
	"i m ",
	"you are ",
	"you re ",
	"he is ",
	"he s ",
	"she is ",
	"she s ",
	"we are ",
	"we re ",
	"they are ",
	"they re ",
}
|
|
|
|
// Pair holds two related sentences as plain strings. readPairs fills val1
// from the text before the tab of a data line and val2 from the text after
// it; Dataset.Reverse produces pairs with the two swapped.
type Pair struct {
	val1, val2 string
}
|
|
|
|
type Dataset struct {
|
|
inputLang Lang
|
|
outputLang Lang
|
|
pairs []Pair
|
|
}
|
|
|
|
// normalize lower-cases s and rewrites it rune by rune so that the result can
// be tokenized by splitting on single spaces:
//
//   - '!', '.' and '?' are kept, each preceded by one inserted space;
//   - letters and numbers (as classified by package unicode) are kept as is;
//   - every other rune, including existing whitespace and apostrophes, is
//     replaced by a single space.
//
// Runs of spaces are not collapsed and the result is not trimmed.
func normalize(s string) (retVal string) {
	var b strings.Builder
	b.Grow(len(s))

	// Compare the rune directly: formatting each rune into a string just to
	// test it against "!", "." or "?" costs an allocation per character.
	for _, r := range strings.ToLower(s) {
		switch {
		case r == '!' || r == '.' || r == '?':
			b.WriteByte(' ')
			b.WriteRune(r)
		case unicode.IsLetter(r) || unicode.IsNumber(r):
			b.WriteRune(r)
		default:
			b.WriteByte(' ')
		}
	}

	return b.String()
}
|
|
|
|
func toIndexes(s string, lang Lang) (retVal []int) {
|
|
res := strings.Split(s, " ")
|
|
|
|
for _, l := range res {
|
|
idx := lang.GetIndex(l)
|
|
if idx >= 0 {
|
|
retVal = append(retVal, idx)
|
|
}
|
|
}
|
|
|
|
retVal = append(retVal, lang.EosToken())
|
|
|
|
return retVal
|
|
}
|
|
|
|
func filterPrefix(s string) (retVal bool) {
|
|
|
|
for _, prefix := range Prefixes {
|
|
if strings.HasPrefix(s, prefix) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func readPairs(ilang, olang string, maxLength int) (retVal []Pair) {
|
|
file, err := os.Open(fmt.Sprintf("../../data/translation/%v-%v.txt", ilang, olang))
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
defer file.Close()
|
|
scanner := bufio.NewScanner(file)
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
|
|
if strings.Contains(line, "\t") {
|
|
// NOTE: assuming there's only 1 '\t'
|
|
pair := strings.Split(line, "\t")
|
|
lhs := normalize(pair[0])
|
|
rhs := normalize(pair[1])
|
|
|
|
if (len(strings.Split(lhs, " ")) < maxLength) && (len(strings.Split(rhs, " ")) < maxLength) && (filterPrefix(lhs) || filterPrefix(rhs)) {
|
|
retVal = append(retVal, Pair{lhs, rhs})
|
|
}
|
|
|
|
} else {
|
|
log.Fatalf("A line does not contain a single tab: %v\n", line)
|
|
}
|
|
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return retVal
|
|
}
|
|
|
|
func newDataset(ilang, olang string, maxLength int) (retVal Dataset) {
|
|
pairs := readPairs(ilang, olang, maxLength)
|
|
inputLang := NewLang(ilang)
|
|
outputLang := NewLang(olang)
|
|
|
|
for _, p := range pairs {
|
|
inputLang.AddSentence(p.val1)
|
|
outputLang.AddSentence(p.val2)
|
|
}
|
|
|
|
return Dataset{
|
|
inputLang: inputLang,
|
|
outputLang: outputLang,
|
|
pairs: pairs,
|
|
}
|
|
}
|
|
|
|
func (ds Dataset) InputLang() (retVal Lang) {
|
|
return ds.inputLang
|
|
}
|
|
|
|
func (ds Dataset) OutputLang() (retVal Lang) {
|
|
return ds.outputLang
|
|
}
|
|
|
|
func (ds Dataset) Reverse() (retVal Dataset) {
|
|
var rpairs []Pair
|
|
for _, p := range ds.pairs {
|
|
rpairs = append(rpairs, Pair{p.val2, p.val1})
|
|
}
|
|
return Dataset{
|
|
inputLang: ds.outputLang,
|
|
outputLang: ds.inputLang,
|
|
pairs: rpairs,
|
|
}
|
|
}
|
|
|
|
// Pairs is the index form of one sentence pair: Val1 and Val2 are the
// integer sequences produced for the two sides by Dataset.Pairs.
type Pairs struct {
	Val1, Val2 []int
}
|
|
|
|
func (ds Dataset) Pairs() (retVal []Pairs) {
|
|
for _, p := range ds.pairs {
|
|
val1 := toIndexes(p.val1, ds.inputLang)
|
|
val2 := toIndexes(p.val2, ds.outputLang)
|
|
|
|
retVal = append(retVal, Pairs{val1, val2})
|
|
}
|
|
|
|
return retVal
|
|
}
|