feat(pointer-store): moved pointer-store.go to libtch package
This commit is contained in:
parent
13f18878ae
commit
a3f2b4c2c9
|
@ -9,8 +9,6 @@ package libtch
|
|||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
// "time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
|
@ -18,6 +16,17 @@ import (
|
|||
type Ctensor = C.tensor
|
||||
type Cscalar = C.scalar
|
||||
|
||||
type NamedCtensor struct {
|
||||
Name string
|
||||
Ctensor C.tensor
|
||||
}
|
||||
|
||||
type LoadData struct {
|
||||
NamedCtensors []NamedCtensor
|
||||
}
|
||||
|
||||
var PStore = NewPointerStore()
|
||||
|
||||
func AtNewTensor() Ctensor {
|
||||
return C.at_new_tensor()
|
||||
}
|
||||
|
@ -228,9 +237,6 @@ func AtLoadMulti(tensors []Ctensor, tensor_names []string, ntensors int, filenam
|
|||
// TODO: implement this
|
||||
}
|
||||
|
||||
var pStore PointerStore = NewPointerStore()
|
||||
var namedCtensors []NamedCtensor = make([]NamedCtensor, 0)
|
||||
|
||||
// void at_load_callback(char *filename, void *data, void (*f)(void *, char *, tensor));
|
||||
/*
|
||||
* void at_load_callback(char *filename, void *data, void (*f)(void *, char *, tensor)) {
|
||||
|
@ -243,55 +249,22 @@ var namedCtensors []NamedCtensor = make([]NamedCtensor, 0)
|
|||
* )
|
||||
* }
|
||||
* */
|
||||
// func AtLoadCallback(filename string, data unsafe.Pointer, callbackFn *func(unsafe.Pointer, *C.char, C.tensor)) {
|
||||
func AtLoadCallback(filename string) (retVal []NamedCtensor) {
|
||||
|
||||
func AtLoadCallback(filename string, dataPtr unsafe.Pointer) {
|
||||
cfilename := C.CString(filename)
|
||||
defer C.free(unsafe.Pointer(cfilename))
|
||||
|
||||
var data Data = Data{NamedCtensors: make([]NamedCtensor, 0)}
|
||||
dataPtr := pStore.Set(&data)
|
||||
|
||||
C.at_load_callback(cfilename, dataPtr, C.f(C.callback_fn))
|
||||
|
||||
data = *pStore.Get(dataPtr).(*Data)
|
||||
|
||||
retVal = data.NamedCtensors
|
||||
fmt.Println(retVal)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type Data struct {
|
||||
NamedCtensors []NamedCtensor
|
||||
}
|
||||
|
||||
/*
|
||||
* func (d *Data) Set(v NamedCtensor) {
|
||||
* d.NamedCtensors = append(d.NamedCtensors, v)
|
||||
* }
|
||||
*
|
||||
* func (d *Data) Get() []NamedCtensor {
|
||||
* return d.NamedCtensors
|
||||
* }
|
||||
* */
|
||||
|
||||
type NamedCtensor struct {
|
||||
Name string
|
||||
Ctensor C.tensor
|
||||
}
|
||||
//TODO: move `callback_fn` to wrapper package???
|
||||
|
||||
//export callback_fn
|
||||
func callback_fn(dataPtr unsafe.Pointer, name *C.char, ctensor C.tensor) {
|
||||
// TODO: do something here
|
||||
tsName := C.GoString(name)
|
||||
fmt.Println(tsName)
|
||||
namedCtensor := NamedCtensor{
|
||||
Name: tsName,
|
||||
Ctensor: ctensor,
|
||||
}
|
||||
|
||||
data := pStore.Get(dataPtr).(*Data)
|
||||
// data.Set(namedCtensor)
|
||||
data := PStore.Get(dataPtr).(*LoadData)
|
||||
data.NamedCtensors = append(data.NamedCtensors, namedCtensor)
|
||||
}
|
||||
|
|
|
@ -732,12 +732,14 @@ func MustSaveMulti(namedTensors []NamedTensor, path string) {
|
|||
// The file format is the same as the one used by the PyTorch C++ API.
|
||||
func LoadMulti(path string) (retVal []NamedTensor, err error) {
|
||||
|
||||
data := lib.AtLoadCallback(path)
|
||||
var data lib.LoadData
|
||||
dataPtr := lib.PStore.Set(&data)
|
||||
lib.AtLoadCallback(path, dataPtr)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
fmt.Println(data)
|
||||
fmt.Println(data.NamedCtensors)
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user