From 36b61bf719008c2c5fbe3ea68512307da91ec6d5 Mon Sep 17 00:00:00 2001 From: sugarme Date: Thu, 24 Feb 2022 15:25:15 +1100 Subject: [PATCH] added more pickle types --- CHANGELOG.md | 1 + example/pickle/main.go | 32 +++++++--------- init.go | 4 +- pickle/serialization.go | 23 +++++++++++- pickle/storage.go | 82 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 120 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b1ee5fc..2b43340 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - Fixed wrong `cacheDir` and switch off logging. +- Added more pickle classes to handle unpickling ## [Nofix] - ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box. diff --git a/example/pickle/main.go b/example/pickle/main.go index e12d25d..bec4742 100644 --- a/example/pickle/main.go +++ b/example/pickle/main.go @@ -1,36 +1,32 @@ package main import ( - "fmt" "log" "github.com/sugarme/gotch" - "github.com/sugarme/gotch/nn" "github.com/sugarme/gotch/pickle" - "github.com/sugarme/gotch/vision" ) func main() { - device := gotch.CPU - vs := nn.NewVarStore(device) - net := vision.VGG16(vs.Root(), 1000) + // modelName := "vgg16" + // modelName := "mobilenet_v2" + // modelName := "resnet18" + // modelName := "alexnet" + // modelName := "squeezenet1_1" + // modelName := "inception_v3_google" + modelName := "efficientnet_b4" - modelName := "vgg16" - modelUrl, ok := gotch.ModelUrls[modelName] + url, ok := gotch.ModelUrls[modelName] if !ok { - log.Fatal("model name %q not found.", modelName) + log.Fatalf("Unsupported model name %q\n", modelName) + } + modelFile, err := gotch.CachedPath(url) + if err != nil { + panic(err) } - modelFile, err := gotch.CachedPath(modelUrl) + err = pickle.LoadInfo(modelFile) if err != nil { log.Fatal(err) } - - err = pickle.LoadAll(vs, modelFile) - if err != nil { - log.Fatal(err) - } - - 
fmt.Printf("%v\n", net) - vs.Summary() } diff --git a/init.go b/init.go index 0b00054..1b2f0d9 100644 --- a/init.go +++ b/init.go @@ -12,12 +12,10 @@ var ( ) func init() { - // default path: {$HOME}/.cache/gotch homeDir := os.Getenv("HOME") - CacheDir = fmt.Sprintf("%s/.cache/gotch", homeDir) + CacheDir = fmt.Sprintf("%s/.cache/gotch", homeDir) // default dir: "{$HOME}/.cache/gotch" initEnv() - // log.Printf("INFO: CacheDir=%q\n", CacheDir) } diff --git a/pickle/serialization.go b/pickle/serialization.go index 566d4fb..8ca6278 100644 --- a/pickle/serialization.go +++ b/pickle/serialization.go @@ -71,7 +71,15 @@ func Decode(filename string) (map[string]*ts.Tensor, error) { stride := sx.Stride storageOffset := sx.StorageOffset - // fmt.Printf("%q - shape: %v - stride: %v - storageOffset: %v\n", sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset) + // log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset) + // log.Printf("data: %v\n", data) + + // Dealing with Pytorch `..._tracked` variables. + // TODO. should we just skip them? 
+ if reflect.ValueOf(data).Len() == 1 && len(size) == 0 { + size = []int64{1} + stride = []int64{1} + } x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true) if sx.RequiresGrad { @@ -373,6 +381,19 @@ func makePickleFindClass(fallback func(module, name string) (interface{}, error) return &RebuildTensor{}, nil case "torch._utils._rebuild_tensor_v2": return &RebuildTensorV2{}, nil + case "torch._utils._rebuild_parameter": + return &RebuildParameter{}, nil + case "torch._utils._rebuild_sparse_tensor": + return &RebuildSparseTensor{}, nil + case "torch._utils._rebuild_sparse_csr_tensor": + return &RebuildSparseCsrTensor{}, nil + case "torch._utils._rebuild_device_tensor_from_numpy": + return &RebuildDeviceTensorFromNumpy{}, nil + case "torch._utils._rebuild_meta_tensor_no_storage": + return &RebuildMetaTensorNoStorage{}, nil + case "torch._utils._rebuild_qtensor": + return &RebuildQtensor{}, nil + case "torch.FloatStorage": return &FloatStorageClass{}, nil case "torch.HalfStorage": diff --git a/pickle/storage.go b/pickle/storage.go index 5a58bf5..0374fbb 100644 --- a/pickle/storage.go +++ b/pickle/storage.go @@ -653,6 +653,28 @@ func (r *RebuildTensorV2) Call(args ...interface{}) (interface{}, error) { return tensor, nil } +// Rebuild Parameter: +// ================== +// RebuildParameter represents a struct to rebuild a parameter back from a pickle object. +// Ref. 
https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L240
+type RebuildParameter struct{}
+
+var _ Callable = &RebuildParameter{}
+
+func (r *RebuildParameter) Call(args ...interface{}) (interface{}, error) {
+	if len(args) != 3 { // expects: data (*StorageTensor), requires_grad, backward_hooks
+		return nil, fmt.Errorf("RebuildParameter expected 3 args, got %d: %#v", len(args), args)
+	}
+
+	tensor, ok := args[0].(*StorageTensor)
+	if !ok {
+		err := fmt.Errorf("RebuildParameter.Call() failed: unexpected arg: %#v", args)
+		return nil, err
+	}
+
+	return tensor, nil
+}
+
 func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
 	length := tuple.Len()
 	slice := make([]int64, length)
@@ -665,3 +687,63 @@ func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
 	}
 	return slice, nil
 }
+
+// Rebuild Sparse Tensor:
+// RebuildSparseTensor is a placeholder Callable; Call is not implemented and always returns an error.
+// ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L178
+type RebuildSparseTensor struct{}
+
+var _ Callable = &RebuildSparseTensor{}
+
+func (r *RebuildSparseTensor) Call(args ...interface{}) (interface{}, error) {
+	// TODO: not implemented; return an error rather than panic so the caller can handle it.
+	return nil, fmt.Errorf("RebuildSparseTensor.Call(): NotImplementedError")
+}
+
+// Rebuild Sparse CSR Tensor:
+// RebuildSparseCsrTensor is a placeholder Callable; Call is not implemented and always returns an error.
+// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L187
+type RebuildSparseCsrTensor struct{}
+
+var _ Callable = &RebuildSparseCsrTensor{}
+
+func (r *RebuildSparseCsrTensor) Call(args ...interface{}) (interface{}, error) {
+	// TODO: not implemented; return an error rather than panic so the caller can handle it.
+	return nil, fmt.Errorf("RebuildSparseCsrTensor.Call(): NotImplementedError")
+}
+
+// Rebuild Device Tensor From Numpy:
+// =================================
+// Ref. 
https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L197
+type RebuildDeviceTensorFromNumpy struct{}
+
+var _ Callable = &RebuildDeviceTensorFromNumpy{}
+
+func (r *RebuildDeviceTensorFromNumpy) Call(args ...interface{}) (interface{}, error) {
+	// TODO: not implemented; return an error rather than panic so the caller can handle it.
+	return nil, fmt.Errorf("RebuildDeviceTensorFromNumpy.Call(): NotImplementedError")
+}
+
+// Rebuild Meta Tensor No Storage:
+// RebuildMetaTensorNoStorage is a placeholder Callable; Call is not implemented and always returns an error.
+// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L208
+type RebuildMetaTensorNoStorage struct{}
+
+var _ Callable = &RebuildMetaTensorNoStorage{}
+
+func (r *RebuildMetaTensorNoStorage) Call(args ...interface{}) (interface{}, error) {
+	// TODO: not implemented; return an error rather than panic so the caller can handle it.
+	return nil, fmt.Errorf("RebuildMetaTensorNoStorage.Call(): NotImplementedError")
+}
+
+// Rebuild QTensor:
+// RebuildQtensor is a placeholder Callable; Call is not implemented and always returns an error.
+// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L214
+type RebuildQtensor struct{}
+
+var _ Callable = &RebuildQtensor{}
+
+func (r *RebuildQtensor) Call(args ...interface{}) (interface{}, error) {
+	// TODO: not implemented; return an error rather than panic so the caller can handle it.
+	return nil, fmt.Errorf("RebuildQtensor.Call(): NotImplementedError")
+}