feat(nn/sparse): append ForwardT method for Embedding

This commit is contained in:
sugarme 2020-08-04 15:16:21 +10:00
parent 4721ffe670
commit 79a0c2e6bf

View File

@ -40,8 +40,15 @@ func NewEmbedding(vs Path, numEmbeddings int64, embeddingDim int64, config Embed
}
}
// Implement Module interface for Embedding:
// Implement Module, ModuleT interfaces for Embedding:
// =========================================
// Forward implements Module interface for Embedding
func (e Embedding) Forward(xs ts.Tensor) (retVal ts.Tensor) {
return ts.MustEmbedding(e.Ws, xs, e.config.PaddingIdx, e.config.ScaleGradByFreq, e.config.Sparse)
}
// ForwardT implements ModuleT interface for Embedding
func (e Embedding) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
return ts.MustEmbedding(e.Ws, xs, e.config.PaddingIdx, e.config.ScaleGradByFreq, e.config.Sparse)
}