feat(nn/sparse): append ForwardT method for Embedding
This commit is contained in:
parent
4721ffe670
commit
79a0c2e6bf
|
@ -40,8 +40,15 @@ func NewEmbedding(vs Path, numEmbeddings int64, embeddingDim int64, config Embed
|
|||
}
|
||||
}
|
||||
|
||||
// Implement Module interface for Embedding:
|
||||
// Implement Module, ModuleT interfaces for Embedding:
|
||||
// =========================================
|
||||
|
||||
// Forward implements Module interface for Embedding
|
||||
func (e Embedding) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||
return ts.MustEmbedding(e.Ws, xs, e.config.PaddingIdx, e.config.ScaleGradByFreq, e.config.Sparse)
|
||||
}
|
||||
|
||||
// ForwardT implements ModuleT interface for Embedding
|
||||
func (e Embedding) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||
return ts.MustEmbedding(e.Ws, xs, e.config.PaddingIdx, e.config.ScaleGradByFreq, e.config.Sparse)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user