added(nn/jit): ForwadT method

This commit is contained in:
sugarme 2021-01-02 14:24:29 +11:00
parent d0727911c4
commit 82113d7225

View File

@ -2,6 +2,7 @@ package nn
import (
"io"
"log"
"strings"
ts "github.com/sugarme/gotch/tensor"
@ -74,3 +75,14 @@ func TrainableCModuleLoadData(p *Path, stream io.Reader) (*TrainableCModule, err
func (m *TrainableCModule) Save(file string) error {
return m.Inner.Save(file)
}
// ForwardT implements ModuleT for TrainableCModule.
// NOTE: train parameter will not be used.
func (m *TrainableCModule) ForwardT(x *ts.Tensor, train bool) *ts.Tensor {
retVal, err := m.Inner.ForwardTs([]ts.Tensor{*x})
if err != nil {
log.Fatal(err)
}
return retVal
}