added train and eval atm_ API

This commit is contained in:
sugarme 2021-01-02 15:40:52 +11:00
parent 82113d7225
commit 2fb8acef07
5 changed files with 75 additions and 0 deletions

35
example/jit-train/main.go Normal file
View File

@ -0,0 +1,35 @@
package main
import (
"fmt"
"log"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
// ts "github.com/sugarme/gotch/tensor"
// "github.com/sugarme/gotch/vision"
)
func main() {
runTrainAndSaveModel(gotch.CudaIfAvailable())
}
func runTrainAndSaveModel(device gotch.Device) {
file := "./model.pt"
vs := nn.NewVarStore(device)
trainable, err := nn.TrainableCModuleLoad(vs.Root(), file)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Trainable JIT model loaded.\n")
namedTensors, err := trainable.Inner.NamedParameters()
if err != nil {
log.Fatal(err)
}
for _, x := range namedTensors {
fmt.Println(x.Name)
}
}

View File

@ -776,3 +776,13 @@ func AtmSetProfilingMode(b bool) {
cbool := *(*C.int)(unsafe.Pointer(&b))
C.atm_set_profiling_mode(cbool)
}
// void atm_eval(module);
func AtmEval(m Cmodule) {
C.atm_eval(m)
}
// void atm_train(module);
func AtmTrain(m Cmodule) {
C.atm_train(m)
}

View File

@ -850,6 +850,18 @@ ivalue atm_method_(module m, char *method_name, ivalue *ivalues, int nivalues) {
return nullptr;
}
void atm_eval(module m) {
PROTECT(
m->eval();
)
}
void atm_train(module m) {
PROTECT(
m->train();
)
}
void atm_free(module m) {
delete(m);
}

View File

@ -157,6 +157,8 @@ int atm_get_profiling_mode();
void atm_set_profiling_mode(int);
void atm_named_parameters(module, void *data,
void (*f)(void *, char *, tensor));
void atm_eval(module);
void atm_train(module);
ivalue ati_none();
ivalue ati_tensor(tensor);

View File

@ -1104,6 +1104,22 @@ func (cm *CModule) SetProfilingMode(b bool) {
}
}
// Train set CModule to train mode
func (cm *CModule) Train() {
lib.AtmTrain()
if err := TorchErr(); err != nil {
log.Fatal(err)
}
}
// Eval set CModule to inference mode
func (cm *CModule) Train() {
lib.AtmEval()
if err := TorchErr(); err != nil {
log.Fatal(err)
}
}
// Implement Module for CModule:
// =============================