From 049aa29d8a8511f1549d88f5b43a1d98afbf7658 Mon Sep 17 00:00:00 2001 From: sugarme Date: Mon, 8 Jun 2020 17:06:35 +1000 Subject: [PATCH] WIP: tensor io --- example/tensor-io/file.pt | Bin 0 -> 1334 bytes example/tensor-io/main.go | 20 +++++++++++++ libtch/tensor.go | 16 ++++++++++ wrapper/tensor.go | 61 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 97 insertions(+) create mode 100644 example/tensor-io/file.pt create mode 100644 example/tensor-io/main.go diff --git a/example/tensor-io/file.pt b/example/tensor-io/file.pt new file mode 100644 index 0000000000000000000000000000000000000000..382a7332794fcea97976801adb63300a66c5c53b GIT binary patch literal 1334 zcmZ`&&2G~`5MDctlWZFa<&XnH$f|m9q>_}PUXUVCROPb#m{5tRs9HJp(nPp+)?Qa2 zffNpaxN+pj10eO#2j~lw8+UjBj))^-)=m>UiANf(y))mpv)}BPq_Boixr~-<1J%)s zk;_`AEC@%w*X`8M;gaoQla$Jz&)z&xfkC>N%ean;38P~PtF_bZDCqF}rL;fy?npat z->;K2q3=X4!yGl>8t_U}%5%Lv z{&@yG((~o1qiQ0L?K`YByJK_wy4$JDb+TP64E5yVMzS>SVqK?nYz5W{6VQB0^|tV33e>lPGc!W>b}(jtv>eSbD}5k?BcT&!Bqn7bBBdS_G!wNHqieng+^SE<)2? zs%M}VFcI;YOHXQ%Bdxlg;dlkgtmc?^cabfvZ7sw0d(M{Xn@KcndzyE%aBGIpre9a< gL{2!^ft*S(z{xR2KV%cZWY{FdI9)>;lKDyQe|W_iI{*Lx literal 0 HcmV?d00001 diff --git a/example/tensor-io/main.go b/example/tensor-io/main.go new file mode 100644 index 0000000..960665c --- /dev/null +++ b/example/tensor-io/main.go @@ -0,0 +1,20 @@ +package main + +import ( + wrapper "github.com/sugarme/gotch/wrapper" +) + +func main() { + + ts, err := wrapper.OfSlice([]float64{1.3, 29.7}) + if err != nil { + panic(err) + } + + path := "file.pt" + ts.MustSave(path) + + loadedTs := wrapper.MustLoad(path) + + loadedTs.Print() +} diff --git a/libtch/tensor.go b/libtch/tensor.go index bd7000e..5d8ef35 100644 --- a/libtch/tensor.go +++ b/libtch/tensor.go @@ -2,6 +2,7 @@ package libtch //#include "stdbool.h" //#include "torch_api.h" +//#include "stdlib.h" import "C" import ( @@ -179,3 +180,18 @@ func AtCopy_(dst Ctensor, src Ctensor) { csrc := (C.tensor)(src) C.at_copy_(cdst, csrc) } + +// void at_save(tensor, char *filename); +func AtSave(ts Ctensor, path string) { + ctensor := (C.tensor)(ts) + cstringPtr := C.CString(path) + defer C.free(unsafe.Pointer(cstringPtr)) + C.at_save(ctensor, cstringPtr) +} + +// tensor at_load(char *filename); +func AtLoad(path string) Ctensor { + cstringPtr := C.CString(path) + defer C.free(unsafe.Pointer(cstringPtr)) + return C.at_load(cstringPtr) +} diff --git a/wrapper/tensor.go b/wrapper/tensor.go index 2d335eb..1d96107 100644 --- a/wrapper/tensor.go +++ b/wrapper/tensor.go @@ -633,3 +633,64 @@ func MustCopy_(self, src Tensor) { log.Fatal(err) } } + +// Save saves a tensor to a file. +func (ts Tensor) Save(path string) (err error) { + lib.AtSave(ts.ctensor, path) + + if err = TorchErr(); err != nil { + return err + } + + return nil +} + +// MustSave saves a tensor to a file. It will panic if error +func (ts Tensor) MustSave(path string) { + if err := ts.Save(path); err != nil { + log.Fatal(err) + } +} + +// Load loads a tensor from a file. +func Load(path string) (retVal Tensor, err error) { + ctensor := lib.AtLoad(path) + + if err = TorchErr(); err != nil { + return retVal, err + } + + retVal = Tensor{ctensor} + + return retVal, nil +} + +// MustLoad loads a tensor to a file. It will panic if error +func MustLoad(path string) (retVal Tensor) { + retVal, err := Load(path) + if err != nil { + log.Fatal(err) + } + + return retVal +} + +type NamedTensor struct { + Name string + Tensor Tensor +} + +// SaveMulti saves some named tensors to a file +// +// The file format is the same as the one used by the PyTorch C++ API. +func SaveMulti(namedTensors []NamedTensor, path string) (err error) { + var ctensors []Ctensor + var names []string + + for _, ts := range namedTensors { + ctensors = append(ctensors, ts.Tensor.ctensor) + names = append(names, ts.Name) + } + + return nil +}