added train and eval atm_ API
This commit is contained in:
parent
82113d7225
commit
2fb8acef07
35
example/jit-train/main.go
Normal file
35
example/jit-train/main.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
"github.com/sugarme/gotch/nn"
|
||||||
|
// ts "github.com/sugarme/gotch/tensor"
|
||||||
|
// "github.com/sugarme/gotch/vision"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
runTrainAndSaveModel(gotch.CudaIfAvailable())
|
||||||
|
}
|
||||||
|
|
||||||
|
func runTrainAndSaveModel(device gotch.Device) {
|
||||||
|
|
||||||
|
file := "./model.pt"
|
||||||
|
vs := nn.NewVarStore(device)
|
||||||
|
trainable, err := nn.TrainableCModuleLoad(vs.Root(), file)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("Trainable JIT model loaded.\n")
|
||||||
|
|
||||||
|
namedTensors, err := trainable.Inner.NamedParameters()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, x := range namedTensors {
|
||||||
|
fmt.Println(x.Name)
|
||||||
|
}
|
||||||
|
}
|
|
@ -776,3 +776,13 @@ func AtmSetProfilingMode(b bool) {
|
||||||
cbool := *(*C.int)(unsafe.Pointer(&b))
|
cbool := *(*C.int)(unsafe.Pointer(&b))
|
||||||
C.atm_set_profiling_mode(cbool)
|
C.atm_set_profiling_mode(cbool)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// void atm_eval(module);
|
||||||
|
func AtmEval(m Cmodule) {
|
||||||
|
C.atm_eval(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// void atm_train(module);
|
||||||
|
func AtmTrain(m Cmodule) {
|
||||||
|
C.atm_train(m)
|
||||||
|
}
|
||||||
|
|
|
@ -850,6 +850,18 @@ ivalue atm_method_(module m, char *method_name, ivalue *ivalues, int nivalues) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void atm_eval(module m) {
|
||||||
|
PROTECT(
|
||||||
|
m->eval();
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
void atm_train(module m) {
|
||||||
|
PROTECT(
|
||||||
|
m->train();
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
void atm_free(module m) {
|
void atm_free(module m) {
|
||||||
delete(m);
|
delete(m);
|
||||||
}
|
}
|
||||||
|
|
|
@ -157,6 +157,8 @@ int atm_get_profiling_mode();
|
||||||
void atm_set_profiling_mode(int);
|
void atm_set_profiling_mode(int);
|
||||||
void atm_named_parameters(module, void *data,
|
void atm_named_parameters(module, void *data,
|
||||||
void (*f)(void *, char *, tensor));
|
void (*f)(void *, char *, tensor));
|
||||||
|
void atm_eval(module);
|
||||||
|
void atm_train(module);
|
||||||
|
|
||||||
ivalue ati_none();
|
ivalue ati_none();
|
||||||
ivalue ati_tensor(tensor);
|
ivalue ati_tensor(tensor);
|
||||||
|
|
|
@ -1104,6 +1104,22 @@ func (cm *CModule) SetProfilingMode(b bool) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Train set CModule to train mode
|
||||||
|
func (cm *CModule) Train() {
|
||||||
|
lib.AtmTrain()
|
||||||
|
if err := TorchErr(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Eval set CModule to inference mode
|
||||||
|
func (cm *CModule) Train() {
|
||||||
|
lib.AtmEval()
|
||||||
|
if err := TorchErr(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Implement Module for CModule:
|
// Implement Module for CModule:
|
||||||
// =============================
|
// =============================
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user