From 2fb8acef0799a81e0f25bf2a3da7f5be448f15df Mon Sep 17 00:00:00 2001 From: sugarme Date: Sat, 2 Jan 2021 15:40:52 +1100 Subject: [PATCH] added train and eval atm_ API --- example/jit-train/main.go | 35 +++++++++++++++++++++++++++++++++++ libtch/tensor.go | 10 ++++++++++ libtch/torch_api.cpp | 12 ++++++++++++ libtch/torch_api.h | 2 ++ tensor/jit.go | 16 ++++++++++++++++ 5 files changed, 75 insertions(+) create mode 100644 example/jit-train/main.go diff --git a/example/jit-train/main.go b/example/jit-train/main.go new file mode 100644 index 0000000..a077514 --- /dev/null +++ b/example/jit-train/main.go @@ -0,0 +1,35 @@ +package main + +import ( + "fmt" + "log" + + "github.com/sugarme/gotch" + "github.com/sugarme/gotch/nn" + // ts "github.com/sugarme/gotch/tensor" + // "github.com/sugarme/gotch/vision" +) + +func main() { + runTrainAndSaveModel(gotch.CudaIfAvailable()) +} + +func runTrainAndSaveModel(device gotch.Device) { + + file := "./model.pt" + vs := nn.NewVarStore(device) + trainable, err := nn.TrainableCModuleLoad(vs.Root(), file) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Trainable JIT model loaded.\n") + + namedTensors, err := trainable.Inner.NamedParameters() + if err != nil { + log.Fatal(err) + } + + for _, x := range namedTensors { + fmt.Println(x.Name) + } +} diff --git a/libtch/tensor.go b/libtch/tensor.go index 6dcb5c6..f0dfce3 100644 --- a/libtch/tensor.go +++ b/libtch/tensor.go @@ -776,3 +776,13 @@ func AtmSetProfilingMode(b bool) { cbool := *(*C.int)(unsafe.Pointer(&b)) C.atm_set_profiling_mode(cbool) } + +// void atm_eval(module); +func AtmEval(m Cmodule) { + C.atm_eval(m) +} + +// void atm_train(module); +func AtmTrain(m Cmodule) { + C.atm_train(m) +} diff --git a/libtch/torch_api.cpp b/libtch/torch_api.cpp index 0462519..addd6b0 100644 --- a/libtch/torch_api.cpp +++ b/libtch/torch_api.cpp @@ -850,6 +850,18 @@ ivalue atm_method_(module m, char *method_name, ivalue *ivalues, int nivalues) { return nullptr; } +void atm_eval(module m) { + PROTECT( + m->eval(); + ) +} + +void atm_train(module m) { + PROTECT( + m->train(); + ) +} + void atm_free(module m) { delete(m); } diff --git a/libtch/torch_api.h b/libtch/torch_api.h index 8e29e46..a9420cc 100644 --- a/libtch/torch_api.h +++ b/libtch/torch_api.h @@ -157,6 +157,8 @@ int atm_get_profiling_mode(); void atm_set_profiling_mode(int); void atm_named_parameters(module, void *data, void (*f)(void *, char *, tensor)); +void atm_eval(module); +void atm_train(module); ivalue ati_none(); ivalue ati_tensor(tensor); diff --git a/tensor/jit.go b/tensor/jit.go index f0bdb1b..acde6dc 100644 --- a/tensor/jit.go +++ b/tensor/jit.go @@ -1104,6 +1104,22 @@ func (cm *CModule) SetProfilingMode(b bool) { } } +// Train set CModule to train mode +func (cm *CModule) Train() { + lib.AtmTrain() + if err := TorchErr(); err != nil { + log.Fatal(err) + } +} + +// Eval set CModule to inference mode +func (cm *CModule) Train() { + lib.AtmEval() + if err := TorchErr(); err != nil { + log.Fatal(err) + } +} + // Implement Module for CModule: // =============================