added more pickle types
This commit is contained in:
parent
7aa4e50199
commit
36b61bf719
|
@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
- Fixed wrong `cacheDir` and switch off logging.
|
- Fixed wrong `cacheDir` and switch off logging.
|
||||||
|
- Added more pickle classes to handle unpickling
|
||||||
|
|
||||||
## [Nofix]
|
## [Nofix]
|
||||||
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
|
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
|
||||||
|
|
|
@ -1,36 +1,32 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
"github.com/sugarme/gotch"
|
||||||
"github.com/sugarme/gotch/nn"
|
|
||||||
"github.com/sugarme/gotch/pickle"
|
"github.com/sugarme/gotch/pickle"
|
||||||
"github.com/sugarme/gotch/vision"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
device := gotch.CPU
|
// modelName := "vgg16"
|
||||||
vs := nn.NewVarStore(device)
|
// modelName := "mobilenet_v2"
|
||||||
net := vision.VGG16(vs.Root(), 1000)
|
// modelName := "resnet18"
|
||||||
|
// modelName := "alexnet"
|
||||||
|
// modelName := "squeezenet1_1"
|
||||||
|
// modelName := "inception_v3_google"
|
||||||
|
modelName := "efficientnet_b4"
|
||||||
|
|
||||||
modelName := "vgg16"
|
url, ok := gotch.ModelUrls[modelName]
|
||||||
modelUrl, ok := gotch.ModelUrls[modelName]
|
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Fatal("model name %q not found.", modelName)
|
log.Fatalf("Unsupported model name %q\n", modelName)
|
||||||
|
}
|
||||||
|
modelFile, err := gotch.CachedPath(url)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
modelFile, err := gotch.CachedPath(modelUrl)
|
err = pickle.LoadInfo(modelFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pickle.LoadAll(vs, modelFile)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("%v\n", net)
|
|
||||||
vs.Summary()
|
|
||||||
}
|
}
|
||||||
|
|
4
init.go
4
init.go
|
@ -12,12 +12,10 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// default path: {$HOME}/.cache/gotch
|
|
||||||
homeDir := os.Getenv("HOME")
|
homeDir := os.Getenv("HOME")
|
||||||
CacheDir = fmt.Sprintf("%s/.cache/gotch", homeDir)
|
CacheDir = fmt.Sprintf("%s/.cache/gotch", homeDir) // default dir: "{$HOME}/.cache/gotch"
|
||||||
|
|
||||||
initEnv()
|
initEnv()
|
||||||
|
|
||||||
// log.Printf("INFO: CacheDir=%q\n", CacheDir)
|
// log.Printf("INFO: CacheDir=%q\n", CacheDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -71,7 +71,15 @@ func Decode(filename string) (map[string]*ts.Tensor, error) {
|
||||||
stride := sx.Stride
|
stride := sx.Stride
|
||||||
storageOffset := sx.StorageOffset
|
storageOffset := sx.StorageOffset
|
||||||
|
|
||||||
// fmt.Printf("%q - shape: %v - stride: %v - storageOffset: %v\n", sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
|
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
|
||||||
|
// log.Printf("data: %v\n", data)
|
||||||
|
|
||||||
|
// Dealing with Pytorch `..._tracked` variables.
|
||||||
|
// TODO. should we just skip them?
|
||||||
|
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
|
||||||
|
size = []int64{1}
|
||||||
|
stride = []int64{1}
|
||||||
|
}
|
||||||
|
|
||||||
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
||||||
if sx.RequiresGrad {
|
if sx.RequiresGrad {
|
||||||
|
@ -373,6 +381,19 @@ func makePickleFindClass(fallback func(module, name string) (interface{}, error)
|
||||||
return &RebuildTensor{}, nil
|
return &RebuildTensor{}, nil
|
||||||
case "torch._utils._rebuild_tensor_v2":
|
case "torch._utils._rebuild_tensor_v2":
|
||||||
return &RebuildTensorV2{}, nil
|
return &RebuildTensorV2{}, nil
|
||||||
|
case "torch._utils._rebuild_parameter":
|
||||||
|
return &RebuildParameter{}, nil
|
||||||
|
case "torch._utils._sparse_tensor":
|
||||||
|
return &RebuildSparseTensor{}, nil
|
||||||
|
case "torch._utils._rebuild_sparse_csr_tensor":
|
||||||
|
return &RebuildSparseCsrTensor{}, nil
|
||||||
|
case "torch._utils._rebuild_device_tensor_from_numpy":
|
||||||
|
return &RebuildDeviceTensorFromNumpy{}, nil
|
||||||
|
case "torch._utils._rebuild_meta_tensor_no_storage":
|
||||||
|
return &RebuildMetaTensorNoStorage{}, nil
|
||||||
|
case "torch._utils._rebuild_qtensor":
|
||||||
|
return &RebuildQtensor{}, nil
|
||||||
|
|
||||||
case "torch.FloatStorage":
|
case "torch.FloatStorage":
|
||||||
return &FloatStorageClass{}, nil
|
return &FloatStorageClass{}, nil
|
||||||
case "torch.HalfStorage":
|
case "torch.HalfStorage":
|
||||||
|
|
|
@ -653,6 +653,28 @@ func (r *RebuildTensorV2) Call(args ...interface{}) (interface{}, error) {
|
||||||
return tensor, nil
|
return tensor, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rebuild Parameter:
|
||||||
|
// ==================
|
||||||
|
// RebuildTensor represents a struct to rebuild tensor back from pickle object.
|
||||||
|
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L240
|
||||||
|
type RebuildParameter struct{}
|
||||||
|
|
||||||
|
var _ Callable = &RebuildParameter{}
|
||||||
|
|
||||||
|
func (r *RebuildParameter) Call(args ...interface{}) (interface{}, error) {
|
||||||
|
if len(args) != 3 { // data(*StorageTensor), requires_grad, backward_hooks
|
||||||
|
return nil, fmt.Errorf("RebuildParameter unexpected 3 args, got %d: %#v", len(args), args)
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor, ok := args[0].(*StorageTensor)
|
||||||
|
if !ok {
|
||||||
|
err := fmt.Errorf("RebuildParameter.Call() failed: unexpected arg: %#v\n", args)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tensor, nil
|
||||||
|
}
|
||||||
|
|
||||||
func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
|
func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
|
||||||
length := tuple.Len()
|
length := tuple.Len()
|
||||||
slice := make([]int64, length)
|
slice := make([]int64, length)
|
||||||
|
@ -665,3 +687,63 @@ func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
|
||||||
}
|
}
|
||||||
return slice, nil
|
return slice, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rebuild Sparse Tensor:
|
||||||
|
//=======================
|
||||||
|
// ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L178
|
||||||
|
type RebuildSparseTensor struct{}
|
||||||
|
|
||||||
|
var _ Callable = &RebuildSparseTensor{}
|
||||||
|
|
||||||
|
func (r *RebuildSparseTensor) Call(args ...interface{}) (interface{}, error) {
|
||||||
|
// TODO.
|
||||||
|
panic("RebuildSparseTensor.Call(): NotImplementedError")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rebuild Sparse CSR Tensor:
|
||||||
|
// ==========================
|
||||||
|
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L187
|
||||||
|
type RebuildSparseCsrTensor struct{}
|
||||||
|
|
||||||
|
var _ Callable = &RebuildSparseCsrTensor{}
|
||||||
|
|
||||||
|
func (r *RebuildSparseCsrTensor) Call(args ...interface{}) (interface{}, error) {
|
||||||
|
// TODO.
|
||||||
|
panic("RebuildSparseCsrTensor.Call(): NotImplementedError")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rebuild Device Tensor From Numpy:
|
||||||
|
// =================================
|
||||||
|
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L197
|
||||||
|
type RebuildDeviceTensorFromNumpy struct{}
|
||||||
|
|
||||||
|
var _ Callable = &RebuildDeviceTensorFromNumpy{}
|
||||||
|
|
||||||
|
func (r *RebuildDeviceTensorFromNumpy) Call(args ...interface{}) (interface{}, error) {
|
||||||
|
// TODO.
|
||||||
|
panic("RebuildDeviceTensorFromNumpy.Call(): NotImplementedError")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rebuild Meta Tensor No Storage:
|
||||||
|
// ===============================
|
||||||
|
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L208
|
||||||
|
type RebuildMetaTensorNoStorage struct{}
|
||||||
|
|
||||||
|
var _ Callable = &RebuildMetaTensorNoStorage{}
|
||||||
|
|
||||||
|
func (r *RebuildMetaTensorNoStorage) Call(args ...interface{}) (interface{}, error) {
|
||||||
|
// TODO.
|
||||||
|
panic("RebuildMetaTensorNoStorage.Call(): NotImplementedError")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rebuild QTensor:
|
||||||
|
// ================
|
||||||
|
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L214
|
||||||
|
type RebuildQtensor struct{}
|
||||||
|
|
||||||
|
var _ Callable = &RebuildQtensor{}
|
||||||
|
|
||||||
|
func (r *RebuildQtensor) Call(args ...interface{}) (interface{}, error) {
|
||||||
|
// TODO.
|
||||||
|
panic("RebuildQtensor.Call(): NotImplementedError")
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user