added(nn/jit): ForwadT method
This commit is contained in:
parent
d0727911c4
commit
82113d7225
12
nn/jit.go
12
nn/jit.go
|
@ -2,6 +2,7 @@ package nn
|
|||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
|
@ -74,3 +75,14 @@ func TrainableCModuleLoadData(p *Path, stream io.Reader) (*TrainableCModule, err
|
|||
func (m *TrainableCModule) Save(file string) error {
|
||||
return m.Inner.Save(file)
|
||||
}
|
||||
|
||||
// ForwardT implements ModuleT for TrainableCModule.
|
||||
// NOTE: train parameter will not be used.
|
||||
func (m *TrainableCModule) ForwardT(x *ts.Tensor, train bool) *ts.Tensor {
|
||||
retVal, err := m.Inner.ForwardTs([]ts.Tensor{*x})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user