added train and eval atm_ API
This commit is contained in:
parent
82113d7225
commit
2fb8acef07
35
example/jit-train/main.go
Normal file
35
example/jit-train/main.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
// ts "github.com/sugarme/gotch/tensor"
|
||||
// "github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
func main() {
|
||||
runTrainAndSaveModel(gotch.CudaIfAvailable())
|
||||
}
|
||||
|
||||
func runTrainAndSaveModel(device gotch.Device) {
|
||||
|
||||
file := "./model.pt"
|
||||
vs := nn.NewVarStore(device)
|
||||
trainable, err := nn.TrainableCModuleLoad(vs.Root(), file)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Trainable JIT model loaded.\n")
|
||||
|
||||
namedTensors, err := trainable.Inner.NamedParameters()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for _, x := range namedTensors {
|
||||
fmt.Println(x.Name)
|
||||
}
|
||||
}
|
|
@ -776,3 +776,13 @@ func AtmSetProfilingMode(b bool) {
|
|||
cbool := *(*C.int)(unsafe.Pointer(&b))
|
||||
C.atm_set_profiling_mode(cbool)
|
||||
}
|
||||
|
||||
// void atm_eval(module);
|
||||
func AtmEval(m Cmodule) {
|
||||
C.atm_eval(m)
|
||||
}
|
||||
|
||||
// void atm_train(module);
|
||||
func AtmTrain(m Cmodule) {
|
||||
C.atm_train(m)
|
||||
}
|
||||
|
|
|
@ -850,6 +850,18 @@ ivalue atm_method_(module m, char *method_name, ivalue *ivalues, int nivalues) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
void atm_eval(module m) {
|
||||
PROTECT(
|
||||
m->eval();
|
||||
)
|
||||
}
|
||||
|
||||
void atm_train(module m) {
|
||||
PROTECT(
|
||||
m->train();
|
||||
)
|
||||
}
|
||||
|
||||
void atm_free(module m) {
|
||||
delete(m);
|
||||
}
|
||||
|
|
|
@ -157,6 +157,8 @@ int atm_get_profiling_mode();
|
|||
void atm_set_profiling_mode(int);
|
||||
void atm_named_parameters(module, void *data,
|
||||
void (*f)(void *, char *, tensor));
|
||||
void atm_eval(module);
|
||||
void atm_train(module);
|
||||
|
||||
ivalue ati_none();
|
||||
ivalue ati_tensor(tensor);
|
||||
|
|
|
@ -1104,6 +1104,22 @@ func (cm *CModule) SetProfilingMode(b bool) {
|
|||
}
|
||||
}
|
||||
|
||||
// Train set CModule to train mode
|
||||
func (cm *CModule) Train() {
|
||||
lib.AtmTrain()
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Eval set CModule to inference mode
|
||||
func (cm *CModule) Train() {
|
||||
lib.AtmEval()
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Implement Module for CModule:
|
||||
// =============================
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user