WIP: tensor io
This commit is contained in:
parent
8f32baff08
commit
049aa29d8a
BIN
example/tensor-io/file.pt
Normal file
BIN
example/tensor-io/file.pt
Normal file
Binary file not shown.
20
example/tensor-io/main.go
Normal file
20
example/tensor-io/main.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
wrapper "github.com/sugarme/gotch/wrapper"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
|
||||||
|
ts, err := wrapper.OfSlice([]float64{1.3, 29.7})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
path := "file.pt"
|
||||||
|
ts.MustSave(path)
|
||||||
|
|
||||||
|
loadedTs := wrapper.MustLoad(path)
|
||||||
|
|
||||||
|
loadedTs.Print()
|
||||||
|
}
|
|
@ -2,6 +2,7 @@ package libtch
|
||||||
|
|
||||||
//#include "stdbool.h"
|
//#include "stdbool.h"
|
||||||
//#include "torch_api.h"
|
//#include "torch_api.h"
|
||||||
|
//#include "stdlib.h"
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -179,3 +180,18 @@ func AtCopy_(dst Ctensor, src Ctensor) {
|
||||||
csrc := (C.tensor)(src)
|
csrc := (C.tensor)(src)
|
||||||
C.at_copy_(cdst, csrc)
|
C.at_copy_(cdst, csrc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// void at_save(tensor, char *filename);
|
||||||
|
func AtSave(ts Ctensor, path string) {
|
||||||
|
ctensor := (C.tensor)(ts)
|
||||||
|
cstringPtr := C.CString(path)
|
||||||
|
defer C.free(unsafe.Pointer(cstringPtr))
|
||||||
|
C.at_save(ctensor, cstringPtr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tensor at_load(char *filename);
|
||||||
|
func AtLoad(path string) Ctensor {
|
||||||
|
cstringPtr := C.CString(path)
|
||||||
|
defer C.free(unsafe.Pointer(cstringPtr))
|
||||||
|
return C.at_load(cstringPtr)
|
||||||
|
}
|
||||||
|
|
|
@ -633,3 +633,64 @@ func MustCopy_(self, src Tensor) {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Save saves a tensor to a file.
|
||||||
|
func (ts Tensor) Save(path string) (err error) {
|
||||||
|
lib.AtSave(ts.ctensor, path)
|
||||||
|
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustSave saves a tensor to a file. It will panic if error
|
||||||
|
func (ts Tensor) MustSave(path string) {
|
||||||
|
if err := ts.Save(path); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load loads a tensor from a file.
|
||||||
|
func Load(path string) (retVal Tensor, err error) {
|
||||||
|
ctensor := lib.AtLoad(path)
|
||||||
|
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustLoad loads a tensor to a file. It will panic if error
|
||||||
|
func MustLoad(path string) (retVal Tensor) {
|
||||||
|
retVal, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
type NamedTensor struct {
|
||||||
|
Name string
|
||||||
|
Tensor Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveMulti saves some named tensors to a file
|
||||||
|
//
|
||||||
|
// The file format is the same as the one used by the PyTorch C++ API.
|
||||||
|
func SaveMulti(namedTensors []NamedTensor, path string) (err error) {
|
||||||
|
var ctensors []Ctensor
|
||||||
|
var names []string
|
||||||
|
|
||||||
|
for _, ts := range namedTensors {
|
||||||
|
ctensors = append(ctensors, ts.Tensor.ctensor)
|
||||||
|
names = append(names, ts.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user