diff --git a/example/tensor-io/file.pt b/example/tensor-io/file.pt new file mode 100644 index 0000000..382a733 Binary files /dev/null and b/example/tensor-io/file.pt differ diff --git a/example/tensor-io/main.go b/example/tensor-io/main.go new file mode 100644 index 0000000..960665c --- /dev/null +++ b/example/tensor-io/main.go @@ -0,0 +1,20 @@ +package main + +import ( + wrapper "github.com/sugarme/gotch/wrapper" +) + +func main() { + + ts, err := wrapper.OfSlice([]float64{1.3, 29.7}) + if err != nil { + panic(err) + } + + path := "file.pt" + ts.MustSave(path) + + loadedTs := wrapper.MustLoad(path) + + loadedTs.Print() +} diff --git a/libtch/tensor.go b/libtch/tensor.go index bd7000e..5d8ef35 100644 --- a/libtch/tensor.go +++ b/libtch/tensor.go @@ -2,6 +2,7 @@ package libtch //#include "stdbool.h" //#include "torch_api.h" +//#include "stdlib.h" import "C" import ( @@ -179,3 +180,18 @@ func AtCopy_(dst Ctensor, src Ctensor) { csrc := (C.tensor)(src) C.at_copy_(cdst, csrc) } + +// void at_save(tensor, char *filename); +func AtSave(ts Ctensor, path string) { + ctensor := (C.tensor)(ts) + cstringPtr := C.CString(path) + defer C.free(unsafe.Pointer(cstringPtr)) + C.at_save(ctensor, cstringPtr) +} + +// tensor at_load(char *filename); +func AtLoad(path string) Ctensor { + cstringPtr := C.CString(path) + defer C.free(unsafe.Pointer(cstringPtr)) + return C.at_load(cstringPtr) +} diff --git a/wrapper/tensor.go b/wrapper/tensor.go index 2d335eb..1d96107 100644 --- a/wrapper/tensor.go +++ b/wrapper/tensor.go @@ -633,3 +633,64 @@ func MustCopy_(self, src Tensor) { log.Fatal(err) } } + +// Save saves a tensor to a file. +func (ts Tensor) Save(path string) (err error) { + lib.AtSave(ts.ctensor, path) + + if err = TorchErr(); err != nil { + return err + } + + return nil +} + +// MustSave saves a tensor to a file. It will panic if error +func (ts Tensor) MustSave(path string) { + if err := ts.Save(path); err != nil { + log.Fatal(err) + } +} + +// Load loads a tensor from a file. +func Load(path string) (retVal Tensor, err error) { + ctensor := lib.AtLoad(path) + + if err = TorchErr(); err != nil { + return retVal, err + } + + retVal = Tensor{ctensor} + + return retVal, nil +} + +// MustLoad loads a tensor to a file. It will panic if error +func MustLoad(path string) (retVal Tensor) { + retVal, err := Load(path) + if err != nil { + log.Fatal(err) + } + + return retVal +} + +type NamedTensor struct { + Name string + Tensor Tensor +} + +// SaveMulti saves some named tensors to a file +// +// The file format is the same as the one used by the PyTorch C++ API. +func SaveMulti(namedTensors []NamedTensor, path string) (err error) { + var ctensors []Ctensor + var names []string + + for _, ts := range namedTensors { + ctensors = append(ctensors, ts.Tensor.ctensor) + names = append(names, ts.Name) + } + + return nil +}