gotch/example/translation/dataset.go
2020-08-01 16:33:30 +10:00

183 lines
3.3 KiB
Go

package main
import (
"bufio"
"fmt"
"log"
"os"
"strings"
"unicode"
)
var Prefixes []string = []string{
"i am ",
"i m ",
"you are ",
"you re",
"he is ",
"he s ",
"she is ",
"she s ",
"we are ",
"we re ",
"they are ",
"they re ",
}
type Pair struct {
val1 string
val2 string
}
type Dataset struct {
inputLang Lang
outputLang Lang
pairs []Pair
}
func normalize(s string) (retVal string) {
lower := strings.ToLower(s)
/*
* // strip all spaces
* noSpaceStr := strings.Map(func(r rune) rune {
* if unicode.IsSpace(r) {
* return -1
* }
* return 1
* }, lower)
* */
// add single space before "!", ".", "?"
var res []rune
for _, r := range []rune(lower) {
char := fmt.Sprintf("%c", r)
switch {
case char == "!":
res = append(res, []rune(" !")...)
case char == ".":
res = append(res, []rune(" .")...)
case char == "?":
res = append(res, []rune(" ?")...)
case unicode.IsLetter(r), unicode.IsNumber(r):
res = append(res, r)
default:
res = append(res, []rune(" ")...)
}
}
return string(res)
}
func toIndexes(s string, lang Lang) (retVal []int) {
res := strings.Split(s, " ")
for _, l := range res {
idx := lang.GetIndex(l)
if idx >= 0 {
retVal = append(retVal, idx)
}
}
retVal = append(retVal, lang.EosToken())
return retVal
}
func filterPrefix(s string) (retVal bool) {
for _, prefix := range Prefixes {
if strings.HasPrefix(s, prefix) {
return true
}
}
return false
}
func readPairs(ilang, olang string, maxLength int) (retVal []Pair) {
file, err := os.Open(fmt.Sprintf("../../data/translation/%v-%v.txt", ilang, olang))
if err != nil {
log.Fatal(err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if strings.Contains(line, "\t") {
// NOTE: assuming there's only 1 '\t'
pair := strings.Split(line, "\t")
lhs := normalize(pair[0])
rhs := normalize(pair[1])
if (len(strings.Split(lhs, " ")) < maxLength) && (len(strings.Split(rhs, " ")) < maxLength) && (filterPrefix(lhs) || filterPrefix(rhs)) {
retVal = append(retVal, Pair{lhs, rhs})
}
} else {
log.Fatalf("A line does not contain a single tab: %v\n", line)
}
}
if err := scanner.Err(); err != nil {
log.Fatal(err)
}
return retVal
}
func newDataset(ilang, olang string, maxLength int) (retVal Dataset) {
pairs := readPairs(ilang, olang, maxLength)
inputLang := NewLang(ilang)
outputLang := NewLang(olang)
for _, p := range pairs {
inputLang.AddSentence(p.val1)
outputLang.AddSentence(p.val2)
}
return Dataset{
inputLang: inputLang,
outputLang: outputLang,
pairs: pairs,
}
}
func (ds Dataset) InputLang() (retVal Lang) {
return ds.inputLang
}
func (ds Dataset) OutputLang() (retVal Lang) {
return ds.outputLang
}
func (ds Dataset) Reverse() (retVal Dataset) {
var rpairs []Pair
for _, p := range ds.pairs {
rpairs = append(rpairs, Pair{p.val2, p.val1})
}
return Dataset{
inputLang: ds.outputLang,
outputLang: ds.inputLang,
pairs: rpairs,
}
}
type Pairs struct {
Val1 []int
Val2 []int
}
func (ds Dataset) Pairs() (retVal []Pairs) {
for _, p := range ds.pairs {
val1 := toIndexes(p.val1, ds.inputLang)
val2 := toIndexes(p.val2, ds.outputLang)
retVal = append(retVal, Pairs{val1, val2})
}
return retVal
}