WIP(example/translation)
This commit is contained in:
parent
a71479e0f5
commit
b77fa54eb0
182
example/translation/dataset.go
Normal file
182
example/translation/dataset.go
Normal file
|
@ -0,0 +1,182 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// Prefixes lists the sentence openings that filterPrefix accepts. readPairs
// keeps a sentence pair only when at least one of its two normalized
// sentences starts with one of these strings.
//
// NOTE(review): "you re" is the only entry without a trailing space, so it
// also matches longer words such as "you remember" - confirm this is intended.
var Prefixes = []string{
	"i am ", "i m ",
	"you are ", "you re",
	"he is ", "he s ",
	"she is ", "she s ",
	"we are ", "we re ",
	"they are ", "they re ",
}
|
||||
|
||||
// Pair holds the two normalized sentences read from one tab-separated line:
// val1 is the text before the first tab and val2 the text after it.
type Pair struct {
	val1, val2 string
}
|
||||
|
||||
type Dataset struct {
|
||||
inputLang Lang
|
||||
outputLang Lang
|
||||
pairs []Pair
|
||||
}
|
||||
|
||||
// normalize lower-cases s and rewrites it rune by rune so that the result can
// be tokenized by splitting on single spaces:
//   - '!', '.' and '?' are kept and preceded by one space;
//   - letters and numbers (unicode.IsLetter / unicode.IsNumber) are kept;
//   - every other rune, whitespace included, is replaced by one space.
//
// Runs of separators are not collapsed, so the result may contain consecutive
// spaces.
func normalize(s string) (retVal string) {
	lower := strings.ToLower(s)

	var sb strings.Builder
	sb.Grow(len(lower))

	// Compare runes directly; formatting each rune into a string just to
	// test it against a literal allocates once per character.
	for _, r := range lower {
		switch {
		case r == '!' || r == '.' || r == '?':
			sb.WriteByte(' ')
			sb.WriteRune(r)
		case unicode.IsLetter(r), unicode.IsNumber(r):
			sb.WriteRune(r)
		default:
			sb.WriteByte(' ')
		}
	}

	return sb.String()
}
|
||||
|
||||
func toIndexes(s string, lang Lang) (retVal []int) {
|
||||
res := strings.Split(s, " ")
|
||||
|
||||
for _, l := range res {
|
||||
idx := lang.GetIndex(l)
|
||||
if idx >= 0 {
|
||||
retVal = append(retVal, idx)
|
||||
}
|
||||
}
|
||||
|
||||
retVal = append(retVal, lang.EosToken())
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func filterPrefix(s string) (retVal bool) {
|
||||
|
||||
for _, prefix := range Prefixes {
|
||||
if strings.HasPrefix(s, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func readPairs(ilang, olang string, maxLength int) (retVal []Pair) {
|
||||
file, err := os.Open(fmt.Sprintf("../../data/translation/%v-%v.txt", ilang, olang))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer file.Close()
|
||||
scanner := bufio.NewScanner(file)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
if strings.Contains(line, "\t") {
|
||||
// NOTE: assuming there's only 1 '\t'
|
||||
pair := strings.Split(line, "\t")
|
||||
lhs := normalize(pair[0])
|
||||
rhs := normalize(pair[1])
|
||||
|
||||
if (len(strings.Split(lhs, " ")) < maxLength) && (len(strings.Split(rhs, " ")) < maxLength) && (filterPrefix(lhs) || filterPrefix(rhs)) {
|
||||
retVal = append(retVal, Pair{lhs, rhs})
|
||||
}
|
||||
|
||||
} else {
|
||||
log.Fatalf("A line does not contain a single tab: %v\n", line)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func newDataset(ilang, olang string, maxLength int) (retVal Dataset) {
|
||||
pairs := readPairs(ilang, olang, maxLength)
|
||||
inputLang := NewLang(ilang)
|
||||
outputLang := NewLang(olang)
|
||||
|
||||
for _, p := range pairs {
|
||||
inputLang.AddSentence(p.val1)
|
||||
outputLang.AddSentence(p.val2)
|
||||
}
|
||||
|
||||
return Dataset{
|
||||
inputLang: inputLang,
|
||||
outputLang: outputLang,
|
||||
pairs: pairs,
|
||||
}
|
||||
}
|
||||
|
||||
func (ds Dataset) InputLang() (retVal Lang) {
|
||||
return ds.inputLang
|
||||
}
|
||||
|
||||
func (ds Dataset) OutputLang() (retVal Lang) {
|
||||
return ds.outputLang
|
||||
}
|
||||
|
||||
func (ds Dataset) Reverse() (retVal Dataset) {
|
||||
var rpairs []Pair
|
||||
for _, p := range ds.pairs {
|
||||
rpairs = append(rpairs, Pair{p.val2, p.val1})
|
||||
}
|
||||
return Dataset{
|
||||
inputLang: ds.outputLang,
|
||||
outputLang: ds.inputLang,
|
||||
pairs: rpairs,
|
||||
}
|
||||
}
|
||||
|
||||
// Pairs is one sentence pair encoded as word indexes: Val1 is built from
// Pair.val1 and Val2 from Pair.val2 (see Dataset.Pairs).
type Pairs struct {
	Val1, Val2 []int
}
|
||||
|
||||
func (ds Dataset) Pairs() (retVal []Pairs) {
|
||||
for _, p := range ds.pairs {
|
||||
val1 := toIndexes(p.val1, ds.inputLang)
|
||||
val2 := toIndexes(p.val2, ds.outputLang)
|
||||
|
||||
retVal = append(retVal, Pairs{val1, val2})
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
92
example/translation/lang.go
Normal file
92
example/translation/lang.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
	// SosToken is the word NewLang registers first; Lang.SosToken returns
	// its index.
	SosToken = "SOS"
	// EosToken is the word NewLang registers second; Lang.EosToken returns
	// its index.
	EosToken = "EOS"
)

// IndexCount is the per-word bookkeeping kept by Lang.
type IndexCount struct {
	// Index is the position assigned to the word when it was first added.
	Index int
	// Count is the number of times the word has been added.
	Count int
}

// Lang is a vocabulary: a two-way mapping between words and integer indexes
// together with how many times each word has been added.
type Lang struct {
	Name                string
	WordToIndexAndCount map[string]IndexCount
	IndexToWord         map[int]string
}

// NewLang creates a vocabulary called name that already contains SosToken
// (index 0) and EosToken (index 1).
func NewLang(name string) (retVal Lang) {
	lang := Lang{
		Name:                name,
		WordToIndexAndCount: make(map[string]IndexCount),
		IndexToWord:         make(map[int]string),
	}

	lang.AddWord(SosToken)
	lang.AddWord(EosToken)

	return lang
}

// AddWord registers word. A new word receives the next free index (the
// current vocabulary size) and a count of 1; a known word only has its count
// incremented. Empty words are ignored.
func (l *Lang) AddWord(word string) {
	if len(word) == 0 {
		return
	}

	if entry, known := l.WordToIndexAndCount[word]; known {
		entry.Count++
		l.WordToIndexAndCount[word] = entry
		return
	}

	idx := len(l.WordToIndexAndCount)
	l.WordToIndexAndCount[word] = IndexCount{Index: idx, Count: 1}
	l.IndexToWord[idx] = word
}

// AddSentence splits sentence on single spaces and adds every resulting word
// (empty fields are dropped by AddWord).
func (l *Lang) AddSentence(sentence string) {
	for _, word := range strings.Split(sentence, " ") {
		l.AddWord(word)
	}
}

// Len returns the number of distinct words in the vocabulary.
func (l *Lang) Len() (retVal int) {
	return len(l.IndexToWord)
}

// SosToken returns the index of the SosToken word.
func (l *Lang) SosToken() (retVal int) {
	return l.WordToIndexAndCount[SosToken].Index
}

// EosToken returns the index of the EosToken word.
func (l *Lang) EosToken() (retVal int) {
	return l.WordToIndexAndCount[EosToken].Index
}

// GetName returns the name given to NewLang.
func (l *Lang) GetName() (retVal string) {
	return l.Name
}

// GetIndex returns the index of word, or -1 when word is not in the
// vocabulary.
func (l *Lang) GetIndex(word string) (retVal int) {
	if entry, known := l.WordToIndexAndCount[word]; known {
		return entry.Index
	}

	return -1 // word does not exist in Lang
}

// SeqToString converts a sequence of word indexes back into a sentence, with
// the words separated by single spaces. An index that is not in the
// vocabulary contributes an empty word.
func (l *Lang) SeqToString(seq []int) (retVal string) {
	words := make([]string, 0, len(seq))
	for _, idx := range seq {
		words = append(words, l.IndexToWord[idx])
	}

	// Words must be joined with a space: joining with "" fused the whole
	// sentence into one unreadable string.
	return strings.Join(words, " ")
}
|
315
example/translation/main.go
Normal file
315
example/translation/main.go
Normal file
|
@ -0,0 +1,315 @@
|
|||
/* Translation with a Sequence to Sequence Model and Attention.
|
||||
|
||||
This follows the line of the PyTorch tutorial:
|
||||
https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
|
||||
And trains a Sequence to Sequence (seq2seq) model using attention to
|
||||
perform translation between French and English.
|
||||
|
||||
The dataset can be downloaded from the following link:
|
||||
https://download.pytorch.org/tutorial/data.zip
|
||||
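The eng-fra.txt file should be placed in the data/translation directory
(the dataset loader opens ../../data/translation/eng-fra.txt relative to
the directory this example is run from).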
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// Hyper-parameters of the translation example.
var (
	// MaxLength is the sentence-length limit passed to newDataset, the
	// output width of the decoder's attention layer and the maximum number
	// of decoding steps in predict.
	MaxLength = int64(10)
	// LearningRate is the learning rate handed to the optimizer.
	LearningRate = float64(0.001)
	// Samples bounds the number of training iterations in main.
	Samples = int64(100000)
	// HiddenSize is the hidden dimension passed to newModel.
	HiddenSize = int64(256)
)
|
||||
|
||||
type Encoder struct {
|
||||
embedding nn.Embedding
|
||||
gru nn.GRU
|
||||
}
|
||||
|
||||
func newEncoder(vs nn.Path, inDim, hiddenDim int64) (retVal Encoder) {
|
||||
|
||||
gru := nn.NewGRU(vs, hiddenDim, hiddenDim, nn.DefaultRNNConfig())
|
||||
|
||||
embedding := nn.NewEmbedding(vs, inDim, hiddenDim, nn.DefaultEmbeddingConfig())
|
||||
|
||||
return Encoder{embedding, gru}
|
||||
}
|
||||
|
||||
func (e Encoder) forward(xs ts.Tensor, state nn.GRUState) (retTs ts.Tensor, retState nn.GRUState) {
|
||||
|
||||
retTs = e.embedding.Forward(xs).MustView([]int64{1, -1}, false)
|
||||
retState = e.gru.Step(retTs, state).(nn.GRUState)
|
||||
|
||||
return retTs, retState
|
||||
}
|
||||
|
||||
type Decoder struct {
|
||||
device gotch.Device
|
||||
embedding nn.Embedding
|
||||
gru nn.GRU
|
||||
attn nn.Linear
|
||||
attnCombine nn.Linear
|
||||
linear nn.Linear
|
||||
}
|
||||
|
||||
func newDecoder(vs nn.Path, hiddenDim, outDim int64) (retVal Decoder) {
|
||||
|
||||
return Decoder{
|
||||
device: vs.Device(),
|
||||
embedding: nn.NewEmbedding(vs, outDim, hiddenDim, nn.DefaultEmbeddingConfig()),
|
||||
gru: nn.NewGRU(vs, hiddenDim, hiddenDim, nn.DefaultRNNConfig()),
|
||||
attn: nn.NewLinear(vs, hiddenDim*2, MaxLength, nn.DefaultLinearConfig()),
|
||||
attnCombine: nn.NewLinear(vs, hiddenDim*2, hiddenDim, nn.DefaultLinearConfig()),
|
||||
linear: nn.NewLinear(vs, hiddenDim, outDim, nn.DefaultLinearConfig()),
|
||||
}
|
||||
}
|
||||
|
||||
func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor, isTraining bool) (retTs ts.Tensor, retState nn.GRUState) {
|
||||
forwardTsTmp := d.embedding.Forward(xs)
|
||||
forwardTsTmp.MustDropout_(0.1, isTraining)
|
||||
forwardTs := forwardTsTmp.MustView([]int64{1, -1}, true)
|
||||
|
||||
// NOTE. forwardTs shape: [1, 256] state [1, 1, 256]
|
||||
// hence, just get state[0] of 3D tensor state
|
||||
stateTs := state.Value()
|
||||
state0 := stateTs.Idx([]ts.TensorIndexer{ts.NewSelect(0)})
|
||||
|
||||
catTs := ts.MustCat([]ts.Tensor{forwardTs, state0}, 1)
|
||||
|
||||
// NOTE. d.attn Ws shape : [512, 10]
|
||||
appliedTs := catTs.Apply(d.attn)
|
||||
attnWeights := appliedTs.MustUnsqueeze(0, true)
|
||||
|
||||
size3, err := encOutputs.Size3()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
sz1 := size3[0]
|
||||
sz2 := size3[1]
|
||||
sz3 := size3[2]
|
||||
|
||||
var encOutputsTs ts.Tensor
|
||||
if sz2 == MaxLength {
|
||||
encOutputsTs = encOutputs.MustShallowClone()
|
||||
} else {
|
||||
shape := []int64{sz1, MaxLength - sz2, sz3}
|
||||
zerosTs := ts.MustZeros(shape, gotch.Float, d.device)
|
||||
encOutputsTs = ts.MustCat([]ts.Tensor{encOutputs, zerosTs}, 1)
|
||||
}
|
||||
|
||||
attnApplied := attnWeights.MustBmm(encOutputsTs, true).MustSqueeze1(1, true)
|
||||
fmt.Printf("attnApplied shape: %v\n", attnApplied.MustSize())
|
||||
fmt.Printf("xs shape: %v\n", forwardTs.MustSize())
|
||||
|
||||
xsTs := ts.MustCat([]ts.Tensor{forwardTs, attnApplied}, 1).Apply(d.attnCombine).MustRelu(true)
|
||||
|
||||
retState = d.gru.Step(xsTs, state).(nn.GRUState)
|
||||
|
||||
retTs = d.linear.Forward(retState.Value()).MustLogSoftmax(-1, gotch.Float, true)
|
||||
|
||||
return retTs, retState
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
encoder Encoder
|
||||
decoder Decoder
|
||||
decoderStart ts.Tensor
|
||||
decoderEos int64
|
||||
device gotch.Device
|
||||
}
|
||||
|
||||
func newModel(vs nn.Path, ilang Lang, olang Lang, hiddenDim int64) (retVal Model) {
|
||||
return Model{
|
||||
encoder: newEncoder(vs.Sub("enc"), int64(ilang.Len()), hiddenDim),
|
||||
decoder: newDecoder(vs.Sub("dec"), hiddenDim, int64(olang.Len())),
|
||||
decoderStart: ts.MustOfSlice([]int64{int64(olang.SosToken())}).MustTo(vs.Device(), true),
|
||||
decoderEos: int64(olang.EosToken()),
|
||||
device: vs.Device(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
|
||||
state := m.encoder.gru.ZeroState(1)
|
||||
fmt.Printf("state shape: %v\n", state.(nn.GRUState).Value().MustSize())
|
||||
var encOutputs []ts.Tensor
|
||||
|
||||
for _, v := range input {
|
||||
s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
|
||||
|
||||
outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
|
||||
|
||||
encOutputs = append(encOutputs, outTs)
|
||||
state.(nn.GRUState).Tensor.MustDrop()
|
||||
state = outState
|
||||
}
|
||||
|
||||
stackTs := ts.MustStack(encOutputs, 1)
|
||||
for _, t := range encOutputs {
|
||||
t.MustDrop()
|
||||
}
|
||||
|
||||
// TODO: should we implement random here???
|
||||
loss := ts.TensorFrom([]float32{0.0}).MustTo(m.device, true)
|
||||
prev := m.decoderStart.MustShallowClone()
|
||||
|
||||
for _, s := range target {
|
||||
outTs, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true)
|
||||
|
||||
state.(nn.GRUState).Tensor.MustDrop()
|
||||
state = outState
|
||||
|
||||
targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true)
|
||||
fmt.Printf("targetTs shape: %v\n", targetTs.MustSize())
|
||||
fmt.Printf("outTs shape: %v\n", outTs.MustSize())
|
||||
|
||||
// TODO: fix the error: input tensor should be 1D or 2D at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:35
|
||||
ws := ts.NewTensor()
|
||||
currLoss := outTs.MustNllLoss(targetTs, ws, int64(1), -100, false)
|
||||
|
||||
loss.MustAdd_(currLoss)
|
||||
currLoss.MustDrop()
|
||||
|
||||
_, output := outTs.MustTopK(1, -1, true, true)
|
||||
|
||||
if m.decoderEos == outTs.Int64Values()[0] {
|
||||
break
|
||||
}
|
||||
|
||||
prev.MustDrop()
|
||||
prev = output
|
||||
}
|
||||
|
||||
return loss
|
||||
|
||||
}
|
||||
|
||||
func (m *Model) predict(input []int) (retVal []int) {
|
||||
state := m.encoder.gru.ZeroState(1)
|
||||
var encOutputs []ts.Tensor
|
||||
|
||||
for _, v := range input {
|
||||
s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
|
||||
outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
|
||||
|
||||
encOutputs = append(encOutputs, outTs)
|
||||
state.(nn.GRUState).Tensor.MustDrop()
|
||||
state = outState
|
||||
}
|
||||
|
||||
stackTs := ts.MustStack(encOutputs, 1)
|
||||
for _, t := range encOutputs {
|
||||
t.MustDrop()
|
||||
}
|
||||
|
||||
prev := m.decoderStart.MustShallowClone()
|
||||
var outputSeq []int
|
||||
|
||||
for i := 0; i < int(MaxLength); i++ {
|
||||
outTs, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true)
|
||||
_, output := outTs.MustTopK(1, -1, true, true)
|
||||
outputVal := output.Int64Values()[0]
|
||||
outputSeq = append(outputSeq, int(outputVal))
|
||||
|
||||
if m.decoderEos == outTs.Int64Values()[0] {
|
||||
break
|
||||
}
|
||||
|
||||
state.(nn.GRUState).Tensor.MustDrop()
|
||||
state = outState
|
||||
prev.MustDrop()
|
||||
prev = output
|
||||
}
|
||||
|
||||
return outputSeq
|
||||
|
||||
}
|
||||
|
||||
// LossStats accumulates loss values so a running average can be reported.
type LossStats struct {
	totalLoss float64 // sum of the losses passed to update since the last reset
	samples   int     // number of update calls since the last reset
}

// newLossStats returns an empty accumulator.
func newLossStats() (retVal LossStats) {
	return LossStats{}
}

// update records one more loss value.
func (ls *LossStats) update(loss float64) {
	ls.samples++
	ls.totalLoss += loss
}

// avgAndReset returns the mean of the recorded losses and clears the
// accumulator. With no recorded sample the division is 0/0.
func (ls *LossStats) avgAndReset() (retVal float64) {
	retVal = ls.totalLoss / float64(ls.samples)
	*ls = LossStats{}

	return retVal
}
|
||||
|
||||
func main() {
|
||||
|
||||
dataset := newDataset("eng", "fra", int(MaxLength)).Reverse()
|
||||
|
||||
ilang := dataset.InputLang()
|
||||
olang := dataset.OutputLang()
|
||||
pairs := dataset.Pairs()
|
||||
|
||||
fmt.Printf("Input: %v %v words\n", ilang.GetName(), ilang.Len())
|
||||
fmt.Printf("Output: %v %v words\n", olang.GetName(), olang.Len())
|
||||
fmt.Printf("Pairs: %v\n", len(pairs))
|
||||
|
||||
// TODO: should we implement random here??
|
||||
|
||||
cuda := gotch.NewCuda()
|
||||
device := cuda.CudaIfAvailable()
|
||||
|
||||
vs := nn.NewVarStore(device)
|
||||
|
||||
model := newModel(vs.Root(), ilang, olang, HiddenSize)
|
||||
|
||||
optConfig := nn.DefaultAdamConfig()
|
||||
opt, err := optConfig.Build(vs, LearningRate)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
lossStats := newLossStats()
|
||||
|
||||
for i := 1; i < int(Samples); i++ {
|
||||
// randomly choose a pair
|
||||
idx := rand.Intn(len(pairs))
|
||||
pair := pairs[idx]
|
||||
input := pair.Val1
|
||||
target := pair.Val2
|
||||
loss := model.trainLoss(input, target)
|
||||
panic("reached")
|
||||
opt.BackwardStep(loss)
|
||||
lossStats.update(loss.Float64Values()[0] / float64(len(target)))
|
||||
|
||||
if i%1000 == 0 {
|
||||
fmt.Printf("%v %v\n", i, lossStats.avgAndReset())
|
||||
for predIdx := 1; predIdx <= 5; predIdx++ {
|
||||
idx := rand.Intn(len(pairs))
|
||||
in := pairs[idx].Val1
|
||||
tgt := pairs[idx].Val2
|
||||
predict := model.predict(in)
|
||||
|
||||
fmt.Printf("input: %v\n", ilang.SeqToString(in))
|
||||
fmt.Printf("target: %v\n", olang.SeqToString(tgt))
|
||||
fmt.Printf("ouput: %v\n", olang.SeqToString(predict))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -75,16 +75,19 @@ func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tens
|
|||
var err error
|
||||
rand.Seed(86)
|
||||
|
||||
data := make([]float64, ts.FlattenDim(dims))
|
||||
data := make([]float32, ts.FlattenDim(dims))
|
||||
for i := range data {
|
||||
data[i] = rand.NormFloat64()*r.mean + r.stdev
|
||||
// NOTE. tensor will have DType = Float (float32)
|
||||
data[i] = float32(rand.NormFloat64()*r.mean + r.stdev)
|
||||
}
|
||||
|
||||
retVal, err = ts.NewTensorFromData(data, dims)
|
||||
newTs, err := ts.NewTensorFromData(data, dims)
|
||||
if err != nil {
|
||||
log.Fatalf("randInit - InitTensor method call error: %v\n", err)
|
||||
}
|
||||
|
||||
retVal = newTs.MustTo(device, true)
|
||||
|
||||
return retVal
|
||||
|
||||
}
|
||||
|
|
20
nn/rnn.go
20
nn/rnn.go
|
@ -1,6 +1,8 @@
|
|||
package nn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
@ -199,7 +201,7 @@ type GRU struct {
|
|||
}
|
||||
|
||||
// NewGRU create a new GRU layer
|
||||
func NewGRU(vs *Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
|
||||
func NewGRU(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
|
||||
var numDirections int64 = 1
|
||||
if cfg.Bidirectional {
|
||||
numDirections = 2
|
||||
|
@ -210,11 +212,14 @@ func NewGRU(vs *Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
|
|||
|
||||
for i := 0; i < int(cfg.NumLayers); i++ {
|
||||
for n := 0; n < int(numDirections); n++ {
|
||||
if i != 0 {
|
||||
inDim = hiddenDim * numDirections
|
||||
var inputDim int64
|
||||
if i == 0 {
|
||||
inputDim = inDim
|
||||
} else {
|
||||
inputDim = hiddenDim * numDirections
|
||||
}
|
||||
|
||||
wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inDim})
|
||||
wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inputDim})
|
||||
wHh := vs.KaimingUniform("w_hh", []int64{gateDim, hiddenDim})
|
||||
bIh := vs.Zeros("b_ih", []int64{gateDim})
|
||||
bHh := vs.Zeros("b_hh", []int64{gateDim})
|
||||
|
@ -249,9 +254,12 @@ func (g GRU) ZeroState(batchDim int64) (retVal State) {
|
|||
}
|
||||
|
||||
func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
|
||||
ip := input.MustUnsqueeze(1, false)
|
||||
unsqueezedInput := input.MustUnsqueeze(1, false)
|
||||
|
||||
output, state := g.SeqInit(ip, inState)
|
||||
fmt.Printf("unsqueezed input shape: %v\n", unsqueezedInput.MustSize())
|
||||
fmt.Printf("inState Ts size: %v\n", inState.(GRUState).Tensor.MustSize())
|
||||
|
||||
output, state := g.SeqInit(unsqueezedInput, inState)
|
||||
|
||||
// NOTE: though we won't use `output`, it is a Ctensor created in C land, so
|
||||
// it should be cleaned up here to prevent memory hold-up.
|
||||
|
|
|
@ -22,7 +22,7 @@ func gruTest(rnnConfig nn.RNNConfig, t *testing.T) {
|
|||
vs := nn.NewVarStore(gotch.CPU)
|
||||
path := vs.Root()
|
||||
|
||||
gru := nn.NewGRU(&path, inputDim, outputDim, rnnConfig)
|
||||
gru := nn.NewGRU(path, inputDim, outputDim, rnnConfig)
|
||||
|
||||
numDirections := int64(1)
|
||||
if rnnConfig.Bidirectional {
|
||||
|
|
|
@ -297,6 +297,15 @@ func (ts Tensor) Device() (retVal gotch.Device, err error) {
|
|||
return device.OfCInt(int32(cInt)), nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustDevice() (retVal gotch.Device) {
|
||||
retVal, err := ts.Device()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
/*
|
||||
* func (ts Tensor) Eq1(other Tensor, del bool) (retVal Tensor, err error) {
|
||||
*
|
||||
|
|
Loading…
Reference in New Issue
Block a user