feat(pointer-store): moved pointer-store.go to libtch package

This commit is contained in:
sugarme 2020-06-10 17:38:14 +10:00
parent 13f18878ae
commit a3f2b4c2c9
2 changed files with 18 additions and 43 deletions

View File

@ -9,8 +9,6 @@ package libtch
import "C"
import (
"fmt"
// "time"
"unsafe"
)
@ -18,6 +16,17 @@ import (
type Ctensor = C.tensor
type Cscalar = C.scalar
type NamedCtensor struct {
Name string
Ctensor C.tensor
}
type LoadData struct {
NamedCtensors []NamedCtensor
}
var PStore = NewPointerStore()
func AtNewTensor() Ctensor {
return C.at_new_tensor()
}
@ -228,9 +237,6 @@ func AtLoadMulti(tensors []Ctensor, tensor_names []string, ntensors int, filenam
// TODO: implement this
}
var pStore PointerStore = NewPointerStore()
var namedCtensors []NamedCtensor = make([]NamedCtensor, 0)
// void at_load_callback(char *filename, void *data, void (*f)(void *, char *, tensor));
/*
* void at_load_callback(char *filename, void *data, void (*f)(void *, char *, tensor)) {
@ -243,55 +249,22 @@ var namedCtensors []NamedCtensor = make([]NamedCtensor, 0)
* )
* }
* */
// func AtLoadCallback(filename string, data unsafe.Pointer, callbackFn *func(unsafe.Pointer, *C.char, C.tensor)) {
func AtLoadCallback(filename string) (retVal []NamedCtensor) {
func AtLoadCallback(filename string, dataPtr unsafe.Pointer) {
cfilename := C.CString(filename)
defer C.free(unsafe.Pointer(cfilename))
var data Data = Data{NamedCtensors: make([]NamedCtensor, 0)}
dataPtr := pStore.Set(&data)
C.at_load_callback(cfilename, dataPtr, C.f(C.callback_fn))
data = *pStore.Get(dataPtr).(*Data)
retVal = data.NamedCtensors
fmt.Println(retVal)
return
}
type Data struct {
NamedCtensors []NamedCtensor
}
/*
* func (d *Data) Set(v NamedCtensor) {
* d.NamedCtensors = append(d.NamedCtensors, v)
* }
*
* func (d *Data) Get() []NamedCtensor {
* return d.NamedCtensors
* }
* */
type NamedCtensor struct {
Name string
Ctensor C.tensor
}
//TODO: move `callback_fn` to wrapper package???
//export callback_fn
func callback_fn(dataPtr unsafe.Pointer, name *C.char, ctensor C.tensor) {
// TODO: do something here
tsName := C.GoString(name)
fmt.Println(tsName)
namedCtensor := NamedCtensor{
Name: tsName,
Ctensor: ctensor,
}
data := pStore.Get(dataPtr).(*Data)
// data.Set(namedCtensor)
data := PStore.Get(dataPtr).(*LoadData)
data.NamedCtensors = append(data.NamedCtensors, namedCtensor)
}

View File

@ -732,12 +732,14 @@ func MustSaveMulti(namedTensors []NamedTensor, path string) {
// The file format is the same as the one used by the PyTorch C++ API.
func LoadMulti(path string) (retVal []NamedTensor, err error) {
data := lib.AtLoadCallback(path)
var data lib.LoadData
dataPtr := lib.PStore.Set(&data)
lib.AtLoadCallback(path, dataPtr)
if err = TorchErr(); err != nil {
return retVal, err
}
fmt.Println(data)
fmt.Println(data.NamedCtensors)
return retVal, nil
}