WIP: tensor io

This commit is contained in:
sugarme 2020-06-08 17:06:35 +10:00
parent 8f32baff08
commit 049aa29d8a
4 changed files with 97 additions and 0 deletions

BIN
example/tensor-io/file.pt Normal file

Binary file not shown.

20
example/tensor-io/main.go Normal file
View File

@ -0,0 +1,20 @@
package main
import (
wrapper "github.com/sugarme/gotch/wrapper"
)
func main() {
ts, err := wrapper.OfSlice([]float64{1.3, 29.7})
if err != nil {
panic(err)
}
path := "file.pt"
ts.MustSave(path)
loadedTs := wrapper.MustLoad(path)
loadedTs.Print()
}

View File

@ -2,6 +2,7 @@ package libtch
//#include "stdbool.h"
//#include "torch_api.h"
//#include "stdlib.h"
import "C"
import (
@ -179,3 +180,18 @@ func AtCopy_(dst Ctensor, src Ctensor) {
csrc := (C.tensor)(src)
C.at_copy_(cdst, csrc)
}
// void at_save(tensor, char *filename);
func AtSave(ts Ctensor, path string) {
ctensor := (C.tensor)(ts)
cstringPtr := C.CString(path)
defer C.free(unsafe.Pointer(cstringPtr))
C.at_save(ctensor, cstringPtr)
}
// tensor at_load(char *filename);
func AtLoad(path string) Ctensor {
cstringPtr := C.CString(path)
defer C.free(unsafe.Pointer(cstringPtr))
return C.at_load(cstringPtr)
}

View File

@ -633,3 +633,64 @@ func MustCopy_(self, src Tensor) {
log.Fatal(err)
}
}
// Save saves a tensor to a file.
func (ts Tensor) Save(path string) (err error) {
lib.AtSave(ts.ctensor, path)
if err = TorchErr(); err != nil {
return err
}
return nil
}
// MustSave saves a tensor to a file. It will panic if error
func (ts Tensor) MustSave(path string) {
if err := ts.Save(path); err != nil {
log.Fatal(err)
}
}
// Load loads a tensor from a file.
func Load(path string) (retVal Tensor, err error) {
ctensor := lib.AtLoad(path)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor}
return retVal, nil
}
// MustLoad loads a tensor to a file. It will panic if error
func MustLoad(path string) (retVal Tensor) {
retVal, err := Load(path)
if err != nil {
log.Fatal(err)
}
return retVal
}
type NamedTensor struct {
Name string
Tensor Tensor
}
// SaveMulti saves some named tensors to a file
//
// The file format is the same as the one used by the PyTorch C++ API.
func SaveMulti(namedTensors []NamedTensor, path string) (err error) {
var ctensors []Ctensor
var names []string
for _, ts := range namedTensors {
ctensors = append(ctensors, ts.Tensor.ctensor)
names = append(names, ts.Name)
}
return nil
}