WIP: tensor io
This commit is contained in:
parent
8f32baff08
commit
049aa29d8a
BIN
example/tensor-io/file.pt
Normal file
BIN
example/tensor-io/file.pt
Normal file
Binary file not shown.
20
example/tensor-io/main.go
Normal file
20
example/tensor-io/main.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
wrapper "github.com/sugarme/gotch/wrapper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
ts, err := wrapper.OfSlice([]float64{1.3, 29.7})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
path := "file.pt"
|
||||
ts.MustSave(path)
|
||||
|
||||
loadedTs := wrapper.MustLoad(path)
|
||||
|
||||
loadedTs.Print()
|
||||
}
|
|
@ -2,6 +2,7 @@ package libtch
|
|||
|
||||
//#include "stdbool.h"
|
||||
//#include "torch_api.h"
|
||||
//#include "stdlib.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
|
@ -179,3 +180,18 @@ func AtCopy_(dst Ctensor, src Ctensor) {
|
|||
csrc := (C.tensor)(src)
|
||||
C.at_copy_(cdst, csrc)
|
||||
}
|
||||
|
||||
// void at_save(tensor, char *filename);
|
||||
func AtSave(ts Ctensor, path string) {
|
||||
ctensor := (C.tensor)(ts)
|
||||
cstringPtr := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cstringPtr))
|
||||
C.at_save(ctensor, cstringPtr)
|
||||
}
|
||||
|
||||
// tensor at_load(char *filename);
|
||||
func AtLoad(path string) Ctensor {
|
||||
cstringPtr := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cstringPtr))
|
||||
return C.at_load(cstringPtr)
|
||||
}
|
||||
|
|
|
@ -633,3 +633,64 @@ func MustCopy_(self, src Tensor) {
|
|||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Save saves a tensor to a file.
|
||||
func (ts Tensor) Save(path string) (err error) {
|
||||
lib.AtSave(ts.ctensor, path)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MustSave saves a tensor to a file. It will panic if error
|
||||
func (ts Tensor) MustSave(path string) {
|
||||
if err := ts.Save(path); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load loads a tensor from a file.
|
||||
func Load(path string) (retVal Tensor, err error) {
|
||||
ctensor := lib.AtLoad(path)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// MustLoad loads a tensor to a file. It will panic if error
|
||||
func MustLoad(path string) (retVal Tensor) {
|
||||
retVal, err := Load(path)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
type NamedTensor struct {
|
||||
Name string
|
||||
Tensor Tensor
|
||||
}
|
||||
|
||||
// SaveMulti saves some named tensors to a file
|
||||
//
|
||||
// The file format is the same as the one used by the PyTorch C++ API.
|
||||
func SaveMulti(namedTensors []NamedTensor, path string) (err error) {
|
||||
var ctensors []Ctensor
|
||||
var names []string
|
||||
|
||||
for _, ts := range namedTensors {
|
||||
ctensors = append(ctensors, ts.Tensor.ctensor)
|
||||
names = append(names, ts.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user