From dc4ab3047ca5609e9b917db8bfd897586c264e4d Mon Sep 17 00:00:00 2001 From: sugarme Date: Thu, 24 Feb 2022 12:43:39 +1100 Subject: [PATCH] added subpackage 'pickle' and file-util --- CHANGELOG.md | 2 + example/pickle/main.go | 36 + file-util.go | 289 +++++ init.go | 35 + pickle/pickle.go | 1937 +++++++++++++++++++++++++++++++++ pickle/pickle_example_test.go | 60 + pickle/serialization.go | 519 +++++++++ pickle/storage.go | 667 ++++++++++++ pickle/type.go | 519 +++++++++ pickle/util.go | 112 ++ 10 files changed, 4176 insertions(+) create mode 100644 example/pickle/main.go create mode 100644 file-util.go create mode 100644 init.go create mode 100644 pickle/pickle.go create mode 100644 pickle/pickle_example_test.go create mode 100644 pickle/serialization.go create mode 100644 pickle/storage.go create mode 100644 pickle/type.go create mode 100644 pickle/util.go diff --git a/CHANGELOG.md b/CHANGELOG.md index ac0c78d..22a1efc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +- Added subpackage `pickle`. Now we can load directly Python Pytorch pretrained model without any Python script conversion. +- Added `gotch.CachePath()` and `gotch.ModelUrls` - Remove Travis CI for now. - fixed `tensor.OfSlice()` throw error due to "Unsupported Go type" (e.g. 
[]float32)
 - added `nn.Path.Paths()` method
diff --git a/example/pickle/main.go b/example/pickle/main.go
new file mode 100644
index 0000000..e12d25d
--- /dev/null
+++ b/example/pickle/main.go
@@ -0,0 +1,36 @@
+package main
+
+import (
+	"fmt"
+	"log"
+
+	"github.com/sugarme/gotch"
+	"github.com/sugarme/gotch/nn"
+	"github.com/sugarme/gotch/pickle"
+	"github.com/sugarme/gotch/vision"
+)
+
+func main() {
+	device := gotch.CPU
+	vs := nn.NewVarStore(device)
+	net := vision.VGG16(vs.Root(), 1000)
+
+	modelName := "vgg16"
+	modelUrl, ok := gotch.ModelUrls[modelName]
+	if !ok {
+		log.Fatalf("model name %q not found.", modelName)
+	}
+
+	modelFile, err := gotch.CachedPath(modelUrl)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	err = pickle.LoadAll(vs, modelFile)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	fmt.Printf("%v\n", net)
+	vs.Summary()
+}
diff --git a/file-util.go b/file-util.go
new file mode 100644
index 0000000..9fad808
--- /dev/null
+++ b/file-util.go
@@ -0,0 +1,289 @@
+package gotch
+
+import (
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"os"
+	"path"
+	"strconv"
+	"strings"
+)
+
+// This file provides functions to work with local dataset cache, ...
+
+// ModelUrls maps model name to its pretrained URL.
+// +// This URLS taken from separate models in pytorch/vision repository +// https://github.com/pytorch/vision/tree/main/torchvision/models +var ModelUrls map[string]string = map[string]string{ + "alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", + + "convnext_tiny": "https://download.pytorch.org/models/convnext_tiny-983f1562.pth", + "convnext_small": "https://download.pytorch.org/models/convnext_small-0c510722.pth", + "convnext_base": "https://download.pytorch.org/models/convnext_base-6075fbad.pth", + "convnext_large": "https://download.pytorch.org/models/convnext_large-ea097f82.pth", + + "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", + "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", + "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", + "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth", + + //Weights ported from https://github.com/rwightman/pytorch-image-models/ + "efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", + "efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", + "efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", + "efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", + "efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", + //Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/ + "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", + "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", + "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", + + //GoogLeNet ported from TensorFlow + "googlenet": 
"https://download.pytorch.org/models/googlenet-1378be20.pth", + + //Inception v3 ported from TensorFlow + "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", + + "mnasnet0_5": "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", + "mnasnet0_75": "", + "mnasnet1_0": "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", + "mnasnet1_3": "", + + "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", + "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", + "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", + + "regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", + "regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", + "regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", + "regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", + "regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", + "regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", + "regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", + "regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", + "regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", + "regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", + "regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", + "regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", + "regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", + "regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", + + "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth", + "resnet34": 
"https://download.pytorch.org/models/resnet34-b627a593.pth", + "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth", + "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth", + "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth", + "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", + "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", + + "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", + "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", + "shufflenetv2_x1.5": "", + "shufflenetv2_x2.0": "", + + "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", + "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", + + "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth", + "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth", + "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", + "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", + "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth", + "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", + "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", + "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", + + "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth", + "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", + "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", + "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth", +} + +// CachedPath resolves and caches data based on 
input string, then returns fullpath to the cached data.
+//
+// Parameters:
+// - `filenameOrUrl`: full path to filename or url
+//
+// CachedPath does several things consequently:
+// 1. Resolves input string to a fullpath cached filename candidate.
+// 2. Check it at `CachePath`, if exists, then return the candidate. If not
+// 3. Retrieves and Caches data to `CachePath` and returns path to cached data
+func CachedPath(filenameOrUrl string) (resolvedPath string, err error) {
+	filename := path.Base(filenameOrUrl)
+	// Resolves to "candidate" filename at `CacheDir`
+	cachedFileCandidate := fmt.Sprintf("%s/%s", CacheDir, filename)
+
+	// 1. Cached candidate file exists
+	if _, err := os.Stat(cachedFileCandidate); err == nil {
+		return cachedFileCandidate, nil
+	}
+
+	// 2. If valid fullpath to local file, caches it and return cached filename
+	if _, err := os.Stat(filenameOrUrl); err == nil {
+		err := copyFile(filenameOrUrl, cachedFileCandidate)
+		if err != nil {
+			return "", err
+		}
+		return cachedFileCandidate, nil
+	}
+
+	// 3. Cached candidate file NOT exist. Try to download it and save to `CacheDir`
+	if isValidURL(filenameOrUrl) {
+		if resp, err := http.Head(filenameOrUrl); err == nil {
+			resp.Body.Close()
+			err := downloadFile(filenameOrUrl, cachedFileCandidate)
+			if err != nil {
+				return "", err
+			}
+			return cachedFileCandidate, nil
+		} else {
+			fmt.Printf("Error: %v\n", err)
+			err = fmt.Errorf("Unable to parse %q as a URL or as a local path.\n", filenameOrUrl)
+			return "", err
+		}
+	}
+
+	// Not resolves
+	err = fmt.Errorf("Unable to parse %q as a URL or as a local path.\n", filenameOrUrl)
+	return "", err
+}
+
+func isValidURL(url string) bool {
+
+	// TODO: implement
+	return true
+}
+
+// downloadFile downloads file from URL and stores it in local filepath.
+// It writes to the destination file as it downloads it, without loading
+// the entire file into memory. An `io.TeeReader` is passed into Copy()
+// to report progress on the download.
+func downloadFile(url string, filepath string) error { + // Create path if not existing + dir := path.Dir(filepath) + filename := path.Base(filepath) + if _, err := os.Stat(dir); os.IsNotExist(err) { + if err := os.MkdirAll(dir, 0755); err != nil { + log.Fatal(err) + } + } + + // Create the file with .tmp extension, so that we won't overwrite a + // file until it's downloaded fully + out, err := os.Create(filepath + ".tmp") + if err != nil { + return err + } + defer out.Close() + + // Get the data + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + + // Check server response + if resp.StatusCode != http.StatusOK { + err := fmt.Errorf("bad status: %s(%v)", resp.Status, resp.StatusCode) + if resp.StatusCode == 404 { + err = fmt.Errorf("download file not found: %q for downloading", url) + } else { + err = fmt.Errorf("download file failed: %q", url) + } + return err + } + + // the total file size to download + size, _ := strconv.Atoi(resp.Header.Get("Content-Length")) + downloadSize := uint64(size) + + // Create our bytes counter and pass it to be used alongside our writer + counter := &writeCounter{FileSize: downloadSize} + _, err = io.Copy(out, io.TeeReader(resp.Body, counter)) + if err != nil { + return err + } + + fmt.Printf("\r%s... %s/%s completed", filename, byteCountIEC(counter.Total), byteCountIEC(counter.FileSize)) + // The progress use the same line so print a new line once it's finished downloading + fmt.Println() + + // Rename the tmp file back to the original file + err = os.Rename(filepath+".tmp", filepath) + if err != nil { + return err + } + + return nil +} + +// writeCounter counts the number of bytes written to it. By implementing the Write method, +// it is of the io.Writer interface and we can pass this into io.TeeReader() +// Every write to this writer, will print the progress of the file write. 
+type writeCounter struct { + Total uint64 + FileSize uint64 +} + +func (wc *writeCounter) Write(p []byte) (int, error) { + n := len(p) + wc.Total += uint64(n) + wc.printProgress() + return n, nil +} + +// PrintProgress prints the progress of a file write +func (wc writeCounter) printProgress() { + // Clear the line by using a character return to go back to the start and remove + // the remaining characters by filling it with spaces + fmt.Printf("\r%s", strings.Repeat(" ", 50)) + + // Return again and print current status of download + fmt.Printf("\rDownloading... %s/%s", byteCountIEC(wc.Total), byteCountIEC(wc.FileSize)) +} + +// byteCountIEC converts bytes to human-readable string in binary (IEC) format. +func byteCountIEC(b uint64) string { + const unit = 1024 + if b < unit { + return fmt.Sprintf("%d B", b) + } + div, exp := uint64(unit), 0 + for n := b / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", + float64(b)/float64(div), "KMGTPE"[exp]) +} + +func copyFile(src, dst string) error { + sourceFileStat, err := os.Stat(src) + if err != nil { + return err + } + + if !sourceFileStat.Mode().IsRegular() { + return fmt.Errorf("%s is not a regular file", src) + } + + source, err := os.Open(src) + if err != nil { + return err + } + defer source.Close() + + destination, err := os.Create(dst) + if err != nil { + return err + } + defer destination.Close() + _, err = io.Copy(destination, source) + return err +} diff --git a/init.go b/init.go new file mode 100644 index 0000000..ffdc39a --- /dev/null +++ b/init.go @@ -0,0 +1,35 @@ +package gotch + +import ( + "fmt" + "log" + "os" +) + +var ( + CacheDir string = "NOT_SETTING" + gotchEnvKey string = "GOTCH_CACHE" +) + +func init() { + // default path: {$HOME}/.cache/gotch + homeDir := os.Getenv("HOME") + CacheDir = fmt.Sprintf("%s/.cache/transformer", homeDir) + + initEnv() + + log.Printf("INFO: CacheDir=%q\n", CacheDir) +} + +func initEnv() { + val := os.Getenv(gotchEnvKey) + if val 
!= "" { + CacheDir = val + } + + if _, err := os.Stat(CacheDir); os.IsNotExist(err) { + if err := os.MkdirAll(CacheDir, 0755); err != nil { + log.Fatal(err) + } + } +} diff --git a/pickle/pickle.go b/pickle/pickle.go new file mode 100644 index 0000000..4bd5725 --- /dev/null +++ b/pickle/pickle.go @@ -0,0 +1,1937 @@ +package pickle + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math" + "math/big" + "os" + "strconv" + "strings" +) + +// This file implements Python Pickle Machinery. +// +// Pickle creates portable serialized representations of Python objects. +// See module `copyreg` for a mechanism for registering custom picklers +// See module `pickletools` source for extensive comments +// Ref. https://github.com/python/cpython/blob/main/Lib/pickle.py + +// See more:... +// https://docs.python.org/3/library/pickle.html +// https://docs.python.org/3/library/pickletools.html +// https://pytorch.org/tutorials/beginner/saving_loading_models.html +// https://github.com/pytorch/pytorch/blob/master/torch/serialization.py + +// Pikle Version: +// ============== +// FormatVersion = "4.0" +// CompatibleFormats [ +// "1.0": Original protocol 0 +// "1.1": Protocol 0 with INST added +// "1.2": Original protocol 1 +// "1.3": Protocol 1 with BINFLOAT added +// "2.0": Protocol 2 +// "3.0": Protocol 3 +// "4.0": Protocol 4 +// "5.0": Protocol 5 +// ] + +const HighestProtocol byte = 5 // The highest protocol number pickle currently knows how to read +var DefaultProtocol byte = 4 // The protocol pickle currently used to write by default. + +// Error formatter: +// ================ +func pickleError(msg string) error { + err := fmt.Errorf("PicklingError: %s", msg) + return err +} + +func picklingError(msg string) error { + err := fmt.Errorf("Unpickable Object: %s", msg) + return err +} + +func unpicklingError(msg string) error { + err := fmt.Errorf("UnpicklingError: %s", msg) + return err +} + +// Stop implements error interface. 
It is raised by `Unpickler.LoadStop()` +// in response to the STOP opcode, passing the object that is the result of unpickling. +type Stop struct { + value interface{} // TODO. specific type +} + +func newStop(value interface{}) Stop { return Stop{value} } +func (s Stop) Error() string { return "STOP" } + +var _ error = Stop{} + +// Pickle opcodes: +// ============== +// See pickletools.py for extensive docs. + +var ( + MARK rune = '(' // push special markobject on stack + STOP rune = '.' // every pickle ends with STOP + POP rune = '0' // discard topmost stack item + POP_MARK rune = '1' // discard stack top through topmost markobject + DUP rune = '2' // duplicate top stack item + FLOAT rune = 'F' // push float object; decimal string argument + INT rune = 'I' // push integer or bool; decimal string argument + BININT rune = 'J' // push four-byte signed int + BININT1 rune = 'K' // push 1-byte unsigned int + LONG rune = 'L' // push long; decimal string argument + BININT2 rune = 'M' // push 2-byte unsigned int + NONE rune = 'N' // push None + PERSID rune = 'P' // push persistent object; id is taken from string arg + BINPERSID rune = 'Q' // " " " ; " " " " stack + REDUCE rune = 'R' // apply callable to argtuple, both on stack + STRING rune = 'S' // push string; NL-terminated string argument + BINSTRING rune = 'T' // push string; counted binary string argument + SHORT_BINSTRING rune = 'U' // " " ; " " " " < 256 bytes + UNICODE rune = 'V' // push Unicode string; raw-unicode-escaped'd argument + BINUNICODE rune = 'X' // " " " ; counted UTF-8 string argument + APPEND rune = 'a' // append stack top to list below it + BUILD rune = 'b' // call __setstate__ or __dict__.update() + GLOBAL rune = 'c' // push self.find_class(modname, name); 2 string args + DICT rune = 'd' // build a dict from stack items + EMPTY_DICT rune = '}' // push empty dict + APPENDS rune = 'e' // extend list on stack by topmost stack slice + GET rune = 'g' // push item from memo on stack; index is string arg 
+ BINGET rune = 'h' // " " " " " " ; " " 1-byte arg + INST rune = 'i' // build & push class instance + LONG_BINGET rune = 'j' // push item from memo on stack; index is 4-byte arg + LIST rune = 'l' // build list from topmost stack items + EMPTY_LIST rune = ']' // push empty list + OBJ rune = 'o' // build & push class instance + PUT rune = 'p' // store stack top in memo; index is string arg + BINPUT rune = 'q' // " " " " " ; " " 1-byte arg + LONG_BINPUT rune = 'r' // " " " " " ; " " 4-byte arg + SETITEM rune = 's' // add key+value pair to dict + TUPLE rune = 't' // build tuple from topmost stack items + EMPTY_TUPLE rune = ')' // push empty tuple + SETITEMS rune = 'u' // modify dict by adding topmost key+value pairs + BINFLOAT rune = 'G' // push float; arg is 8-byte float encoding + + // TRUE rune = 'I01\n' // not an opcode; see INT docs in pickletools.py + // FALSE rune = 'I00\n' // not an opcode; see INT docs in pickletools.py + + // Protocol 2 + + PROTO rune = '\x80' // identify pickle protocol + NEWOBJ rune = '\x81' // build object by applying cls.__new__ to argtuple + EXT1 rune = '\x82' // push object from extension registry; 1-byte index + EXT2 rune = '\x83' // ditto, but 2-byte index + EXT4 rune = '\x84' // ditto, but 4-byte index + TUPLE1 rune = '\x85' // build 1-tuple from stack top + TUPLE2 rune = '\x86' // build 2-tuple from two topmost stack items + TUPLE3 rune = '\x87' // build 3-tuple from three topmost stack items + NEWTRUE rune = '\x88' // push True + NEWFALSE rune = '\x89' // push False + LONG1 rune = '\x8a' // push long from < 256 bytes + LONG4 rune = '\x8b' // push really big long + + tuplesize2code []rune = []rune{EMPTY_TUPLE, TUPLE1, TUPLE2, TUPLE3} + + // Protocol 3 (Python 3.x) + + BINBYTES rune = 'B' // push bytes; counted binary string argument + SHORT_BINBYTES rune = 'C' // " " ; " " " " < 256 bytes + + // Protocol 4 + + SHORT_BINUNICODE rune = '\x8c' // push short string; UTF-8 length < 256 bytes + BINUNICODE8 rune = '\x8d' // push very long 
string + BINBYTES8 rune = '\x8e' // push very long bytes string + EMPTY_SET rune = '\x8f' // push empty set on the stack + ADDITEMS rune = '\x90' // modify set by adding topmost stack items + FROZENSET rune = '\x91' // build frozenset from topmost stack items + NEWOBJ_EX rune = '\x92' // like NEWOBJ but work with keyword only arguments + STACK_GLOBAL rune = '\x93' // same as GLOBAL but using names on the stacks + MEMOIZE rune = '\x94' // store top of the stack in memo + FRAME rune = '\x95' // indicate the beginning of a new frame + + // Protocol 5 + + BYTEARRAY8 rune = '\x96' // push bytearray + NEXT_BUFFER rune = '\x97' // push next out-of-band buffer + READONLY_BUFFER rune = '\x98' // make top of stack readonly +) + +// Unpickling Machinery: +// ===================== + +type Unpickler struct { + proto byte // protocol version of the pickle + reader io.Reader // binary file reader + currentFrame *bytes.Reader // buffer frame reader + stack []interface{} // keeps marked objects + metaStack [][]interface{} // keeps stacks of marked objects + + // data structure that remembers which objects the pickler/unpickler has already seen + // so that shared or recursive objects are pickled/unpickled by reference and not by value + // This property is useful when re-using picklers/unpicklers. + memo map[int]interface{} + + FindClass func(module, name string) (interface{}, error) // function to determine data type + PersistentLoad func(interface{}) (interface{}, error) // function how to load pickled objects by its id. + + GetExtension func(code int) (interface{}, error) + NextBufferFunc func() (interface{}, error) + MakeReadOnlyFunc func(interface{}) (interface{}, error) +} + +// NewUnpickler creates a new Unpickler. +func NewUnpickler(r io.Reader) Unpickler { + return Unpickler{ + reader: r, + memo: make(map[int]interface{}, 0), + } +} + +// read reads n bytes from reader. 
+func (up *Unpickler) read(n int) ([]byte, error) { + data := make([]byte, n) + if up.currentFrame != nil { + nbytes, err := io.ReadFull(up.currentFrame, data) + + switch { + case err != nil && err != io.EOF && err != io.ErrUnexpectedEOF: + return nil, err + + case nbytes == 0 && n != 0: // remaining data + up.currentFrame = nil + nbytes, err := io.ReadFull(up.reader, data) + return data[0:nbytes], err + + case nbytes < n: + err := fmt.Errorf("Unpickler.read() failed: pickle exhausted before end of frame") + return nil, err + + default: + return data[0:nbytes], nil + } + } + + nbytes, err := io.ReadFull(up.reader, data) + return data[0:nbytes], err +} + +// readOne reads 1 byte. +func (up *Unpickler) readOne() (byte, error) { + data, err := up.read(1) + if err != nil { + return 0, err + } + + return data[0], nil +} + +// readLine reads one line of data. +func (up *Unpickler) readLine() ([]byte, error) { + if up.currentFrame != nil { + line, err := readLine(up.currentFrame) + if err != nil { + if err == io.EOF && len(line) == 0 { + up.currentFrame = nil + return readLine(up.reader) + } + + return nil, err + } + + if len(line) == 0 { + err := fmt.Errorf("Unpickler.readLine() failed: no data.") + return nil, err + } + if line[len(line)-1] != '\n' { + err := fmt.Errorf("Unpickler.readLine() failed: pickle exhausted before end of frame.") + return nil, err + } + + return line, nil + } + + return readLine(up.reader) +} + +// readLine reads one line of data. Line ends by '\n' byte. +func readLine(r io.Reader) ([]byte, error) { + bufferSize := 64 // just set buffer line = 64. One might change it. + line := make([]byte, 0, bufferSize) + + buf := make([]byte, 1) + for { + nbytes, err := r.Read(buf) + + if nbytes != 1 { + return line, err + } + + line = append(line, buf[0]) + if buf[0] == '\n' || err != nil { + return line, err + } + } +} + +// loadFrame loads new data to currentFrame. It throws error if currentFrame is not empty. 
+func (up *Unpickler) loadFrame(frameSize int) error { + buf := make([]byte, frameSize) + // Throw error if current frame is not empty + if up.currentFrame != nil { + nbytes, err := up.currentFrame.Read(buf) + if nbytes > 0 || err == nil { + err := unpicklingError("beginning of a new frame before end of a current frame") + return err + } + } + + // now, load data to currentFrame + _, err := io.ReadFull(up.reader, buf) + if err != nil { + return err + } + up.currentFrame = bytes.NewReader(buf) + + return nil +} + +// append appends an object to stack. +func (up *Unpickler) append(obj interface{}) { + up.stack = append(up.stack, obj) +} + +// stackPop pops an object out of stack. +func (up *Unpickler) stackPop() (interface{}, error) { + obj, err := up.stackLast() + if err != nil { + return nil, err + } + + up.stack = up.stack[:len(up.stack)-1] + + return obj, nil +} + +// stackLast get last object in stack. +func (up *Unpickler) stackLast() (interface{}, error) { + if len(up.stack) == 0 { + err := fmt.Errorf("Unpickler.stackLast() failed: stack is empty.") + return nil, err + } + + last := up.stack[len(up.stack)-1] + + return last, nil +} + +// metaStackPop pop a stack out from metaStack. +func (up *Unpickler) metaStackPop() ([]interface{}, error) { + stack, err := up.metaStackLast() + if err != nil { + return nil, err + } + + up.metaStack = up.metaStack[:len(up.metaStack)-1] + + return stack, nil +} + +// metaStackLast get last stack in metaStack. +func (up *Unpickler) metaStackLast() ([]interface{}, error) { + if len(up.metaStack) == 0 { + err := fmt.Errorf("Unpickler.metaStackLast() failed: metaStack is empty.") + return nil, err + } + + last := up.metaStack[len(up.metaStack)-1] + + return last, nil +} + +// popMark pops all objects those have been pushed to the stack (after last mask). 
+func (up *Unpickler) popMark() ([]interface{}, error) { + objects := up.stack + newStack, err := up.metaStackPop() + if err != nil { + return nil, err + } + up.stack = newStack + + return objects, nil +} + +func (up *Unpickler) findClass(module, name string) (interface{}, error) { + switch module { + case "collections": + switch name { + case "OrderedDict": + return &OrderedDictClass{}, nil + } + + case "__builtin__": + switch name { + case "object": + return &ObjectClass{}, nil + } + case "copy_reg": + switch name { + case "_reconstructor": + return &Reconstructor{}, nil + } + } + if up.FindClass != nil { + return up.FindClass(module, name) + } + return NewGenericClass(module, name), nil +} + +func (up *Unpickler) persistentLoad(pid interface{}) error { + err := unpicklingError("Unpickler.persistentLoad() failed: unsupported persistent id encountered.") + return err +} + +// Construct dispatch table: +// ========================= +// See https://en.wikipedia.org/wiki/Dispatch_table +// dispatch table is a table of pointers to functions/methods. + +// unpickle dispatch table +var upDispatch [math.MaxUint8]func(*Unpickler) error + +// loadProto reads pickle protocol version. +func loadProto(up *Unpickler) error { + proto, err := up.readOne() + if err != nil { + return err + } + if proto < 0 || proto >= HighestProtocol { + err := fmt.Errorf("loadProto() failed: unsupported pickle protocol (%d)", proto) + return err + } + + up.proto = proto + + return nil +} + +// loadFrame loads new frame. +func loadFrame(up *Unpickler) error { + buf, err := up.read(8) + if err != nil { + return err + } + frameSize := binary.LittleEndian.Uint64(buf) + if frameSize > math.MaxUint64 { + err := fmt.Errorf("loadFrame() failed: frame size > sys.maxsize %v", frameSize) + return err + } + + return up.loadFrame(int(frameSize)) +} + +// loadPersIds load persistent object to stack. 
+func loadPersId(up *Unpickler) error {
+	if up.PersistentLoad == nil {
+		err := fmt.Errorf("loadPersId() failed: unsupported persistent Id encountered.")
+		return err
+	}
+
+	line, err := up.readLine()
+	if err != nil {
+		return err
+	}
+
+	pid := string(line[:len(line)-1])
+	obj, err := up.PersistentLoad(pid)
+	if err != nil {
+		err = fmt.Errorf("loadPersId() failed: %w", err)
+		return err
+	}
+
+	up.append(obj)
+
+	return nil
+}
+
+func loadBinPersId(up *Unpickler) error {
+	if up.PersistentLoad == nil {
+		err := fmt.Errorf("loadBinPersId() failed: unsupported persistent Id encountered.")
+		return err
+	}
+
+	pid, err := up.stackPop()
+	if err != nil {
+		return err
+	}
+
+	obj, err := up.PersistentLoad(pid)
+	if err != nil {
+		err = fmt.Errorf("loadBinPersId() failed: %w", err)
+		return err
+	}
+
+	up.append(obj)
+
+	return nil
+}
+
+// loads nil object
+func loadNone(up *Unpickler) error {
+	up.append(nil)
+	return nil
+}
+
+// loads a bool object with value = false.
+func loadFalse(up *Unpickler) error {
+	up.append(false)
+	return nil
+}
+
+// loads a bool object with value = true
+func loadTrue(up *Unpickler) error {
+	up.append(true)
+	return nil
+}
+
+// loads object of type int (can be integer, bool or decimal string value)
+func loadInt(up *Unpickler) error {
+	line, err := up.readLine()
+	if err != nil {
+		err = fmt.Errorf("loadInt() failed: %w", err)
+		return err
+	}
+
+	data := string(line[:len(line)-1])
+	switch {
+	case len(data) == 2 && data[0] == '0' && data[1] == '0':
+		up.append(false)
+		return nil
+
+	case len(data) == 2 && data[0] == '0' && data[1] == '1':
+		up.append(true)
+		return nil
+
+	default:
+		val, err := strconv.Atoi(data)
+		if err != nil {
+			return fmt.Errorf("loadInt() failed: %w", err)
+		}
+		up.append(val)
+		return nil
+	}
+}
+
+// load 4 bytes of uint.
+func loadBinInt(up *Unpickler) error {
+	buf, err := up.read(4)
+	if err != nil {
+		err = fmt.Errorf("loadBinInt() failed: %w", err)
+		return err
+	}
+
+	uval := binary.LittleEndian.Uint32(buf)
+	val := int(uval)
+	if buf[3]&0x80 != 0 {
+		val = -(int(^uval) + 1)
+	}
+	up.append(val)
+
+	return nil
+}
+
+// loads one byte of uint.
+func loadBinInt1(up *Unpickler) error {
+	b, err := up.readOne()
+	if err != nil {
+		err = fmt.Errorf("loadBinInt1() failed: %w", err)
+		return err
+	}
+	up.append(int(b))
+
+	return nil
+}
+
+// loads 2 bytes of uint.
+func loadBinInt2(up *Unpickler) error {
+	buf, err := up.read(2)
+	if err != nil {
+		err = fmt.Errorf("loadBinInt2() failed: %w", err)
+		return err
+	}
+
+	val := int(binary.LittleEndian.Uint16(buf))
+	up.append(val)
+
+	return nil
+}
+
+// load long; decimal string argument.
+func loadLong(up *Unpickler) error {
+	line, err := up.readLine()
+	if err != nil {
+		err = fmt.Errorf("loadLong() failed: %w", err)
+		return err
+	}
+
+	// last byte is string dtype.
+	if len(line) == 1 {
+		return fmt.Errorf("loadLong() failed: invalid long data")
+	}
+	data := line[:len(line)-1]
+	if data[len(data)-1] == 'L' {
+		data = data[0 : len(data)-1]
+	}
+
+	val, err := strconv.ParseInt(string(data), 10, 64)
+	if err != nil {
+		// check for overflow, if so, swap to larger range.
+		if numErr, ok := err.(*strconv.NumError); ok && numErr.Err == strconv.ErrRange {
+			bigInt, ok := new(big.Int).SetString(string(data), 10)
+			if !ok {
+				err = fmt.Errorf("loadLong() failed: invalid long data")
+				return err
+			}
+
+			up.append(bigInt)
+			return nil
+		}
+
+		err = fmt.Errorf("loadLong() failed: %w", err)
+		return err
+	}
+	up.append(int(val))
+
+	return nil
+}
+
+// loads long integer of less than 256 bytes.
+func loadLong1(up *Unpickler) error {
+	n, err := up.readOne()
+	if err != nil {
+		err = fmt.Errorf("loadLong1() failed: %w", err)
+		return err
+	}
+
+	buf, err := up.read(int(n))
+	if err != nil {
+		err = fmt.Errorf("loadLong1() failed: %w", err)
+		return err
+	}
+
+	val := decodeLong(buf)
+	up.append(val)
+
+	return nil
+}
+
+// loads object of really big long integer.
+func loadLong4(up *Unpickler) error {
+	buf, err := up.read(4)
+	if err != nil {
+		err = fmt.Errorf("loadLong4() failed: %w", err)
+		return err
+	}
+
+	n := decodeInt32(buf)
+	if n < 0 {
+		return fmt.Errorf("loadLong4() failed: LONG pickle has negative byte count")
+	}
+	data, err := up.read(n)
+	if err != nil {
+		err = fmt.Errorf("loadLong4() failed: %w", err)
+		return err
+	}
+	val := decodeLong(data)
+	up.append(val)
+
+	return nil
+}
+
+// loads float object or decimal string argument.
+func loadFloat(up *Unpickler) error {
+	line, err := up.readLine()
+	if err != nil {
+		err = fmt.Errorf("loadFloat() failed: %w", err)
+		return err
+	}
+
+	val, err := strconv.ParseFloat(string(line[:len(line)-1]), 64)
+	if err != nil {
+		err = fmt.Errorf("loadFloat() failed: %w", err)
+		return err
+	}
+	up.append(val)
+
+	return nil
+}
+
+// loads float object of 8-byte encoding.
+func loadBinFloat(up *Unpickler) error {
+	buf, err := up.read(8)
+	if err != nil {
+		err = fmt.Errorf("loadBinFloat() failed: %w", err)
+		return err
+	}
+
+	val := math.Float64frombits(binary.BigEndian.Uint64(buf))
+	up.append(val)
+
+	return nil
+}
+
+// loads object of string value.
+func loadString(up *Unpickler) error { + line, err := up.readLine() + if err != nil { + err = fmt.Errorf("loadString() failed: %w", err) + return err + } + + data := line[:len(line)-1] + + // strip outermost quotes + if len(data) >= 2 && data[0] == data[len(data)-1] && (data[0] == '\'' || data[0] == '"') { + data = data[1 : len(data)-1] + } else { + err = unpicklingError("the STRING opcode argument must be quoted.") + err = fmt.Errorf("loadString() failed: %w", err) + return err + } + up.append(data) + + return nil +} + +// loads object of counted binary string. +func loadBinString(up *Unpickler) error { + // Deprecated BINSTRING uses signed 32-bit length + buf, err := up.read(4) + if err != nil { + err = fmt.Errorf("loadBinString() failed: %w", err) + return err + } + + len := decodeInt32(buf) + if len < 0 { + err = unpicklingError("loadBinString() failed: BINSTRING pickle has negative byte count.") + return err + } + + data, err := up.read(len) + if err != nil { + err = fmt.Errorf("loadBinString() failed: %w", err) + return err + } + + val := string(data) + up.append(val) + + return nil +} + +// loads object of bytes +func loadBinBytes(up *Unpickler) error { + buf, err := up.read(4) + if err != nil { + err := fmt.Errorf("loadBinBytes() failed: %w", err) + return err + } + + len := int(binary.LittleEndian.Uint32(buf)) + buf, err = up.read(len) + if err != nil { + err := fmt.Errorf("loadBinBytes() failed: %w", err) + return err + } + up.append(buf) + + return nil +} + +// loads object of Unicode string value (raw-unicode-escaped). 
+func loadUnicode(up *Unpickler) error { + line, err := up.readLine() + if err != nil { + err := fmt.Errorf("loadUnicode() failed: %w", err) + return err + } + val := string(line[:len(line)-1]) + up.append(val) + + return nil +} + +// loads objects of Unicode string (counted UTF-8 string) +func loadBinUnicode(up *Unpickler) error { + buf, err := up.read(4) + if err != nil { + err = fmt.Errorf("loadBinUnicode() failed: %w", err) + return err + } + + len := int(binary.LittleEndian.Uint32(buf)) + buf, err = up.read(len) + if err != nil { + err = fmt.Errorf("loadBinUnicode() failed: %w", err) + return err + } + val := string(buf) + up.append(val) + + return nil +} + +// loads a object of very long string value. +func loadBinUnicode8(up *Unpickler) error { + buf, err := up.read(8) + if err != nil { + err = fmt.Errorf("loadBinUnicode8() failed: %w", err) + return err + } + + len := int(binary.LittleEndian.Uint64(buf)) + if len > math.MaxInt64 { + err = unpicklingError("loadBinUnicode8() failed: BINUNICODE8 exceeds system's maximum size") + return err + } + buf, err = up.read(len) + if err != nil { + err = fmt.Errorf("loadBinUnicode8() failed: %w", err) + return err + } + val := string(buf) + up.append(val) + + return nil +} + +// loads object of very long bytes string value. 
+func loadBinBytes8(up *Unpickler) error { + buf, err := up.read(8) + if err != nil { + err = fmt.Errorf("loadBinBytes8() failed: %w", err) + return err + } + + len := binary.LittleEndian.Uint64(buf) + if len > math.MaxInt64 { + err = unpicklingError("loadBinBytes8() failed: BINBYTES8 exceeds system's maximum size") + return err + } + buf, err = up.read(int(len)) + if err != nil { + err = fmt.Errorf("loadBinBytes8() failed: %w", err) + return err + } + up.append(buf) + + return nil +} + +func loadByteArray8(up *Unpickler) error { + buf, err := up.read(8) + if err != nil { + err = fmt.Errorf("loadBinBytes8() failed: %w", err) + return err + } + + len := binary.LittleEndian.Uint64(buf) + if len > math.MaxInt64 { + err = unpicklingError("loadBinBytes8() failed: BINBYTES8 exceeds system's maximum size.") + return err + } + buf, err = up.read(int(len)) + if err != nil { + err = fmt.Errorf("loadBinBytes8() failed: %w", err) + return err + } + + val := NewByteArrayFromSlice(buf) + up.append(val) + + return nil +} + +// loads next out-of-band buffer. +func loadNextBuffer(up *Unpickler) error { + if up.NextBufferFunc == nil { + err := fmt.Errorf("loadNextBuffer() failed: Pickle stream refers to out-of-band data but NextBufferFunc was not given") + return err + } + + buf, err := up.NextBufferFunc() + if err != nil { + err = fmt.Errorf("loadNextBuffer() failed: %w", err) + return err + } + + up.append(buf) + + return nil +} + +// makes top of stack readonly. +func loadReadOnlyBuffer(up *Unpickler) error { + if up.MakeReadOnlyFunc == nil { + return nil + } + + buf, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadReadOnlyBuffer() failed: %w", err) + return err + } + + buf, err = up.MakeReadOnlyFunc(buf) + if err != nil { + err = fmt.Errorf("loadReadOnlyBuffer() failed: %w", err) + return err + } + up.append(buf) + + return nil +} + +// loads counted binary string object (< 256 bytes). 
+func loadShortBinString(up *Unpickler) error { + len, err := up.readOne() + if err != nil { + err = fmt.Errorf("loadShortBinString() failed: %w", err) + return err + } + data, err := up.read(int(len)) + if err != nil { + err = fmt.Errorf("loadShortBinString() failed: %w", err) + return err + } + up.append(string(data)) + + return nil +} + +// loads bytes object with counted binary string < 256 bytes. +func loadShortBinBytes(up *Unpickler) error { + len, err := up.readOne() + if err != nil { + err = fmt.Errorf("loadShortBinBytes() failed: %w", err) + return err + } + buf, err := up.read(int(len)) + if err != nil { + err = fmt.Errorf("loadShortBinBytes() failed: %w", err) + return err + } + up.append(buf) + + return nil +} + +// load short string object; UTF-8 length < 256 bytes. +func loadShortBinUnicode(up *Unpickler) error { + len, err := up.readOne() + if err != nil { + err = fmt.Errorf("loadShortBinUnicode() failed: %w", err) + return err + } + + buf, err := up.read(int(len)) + if err != nil { + err = fmt.Errorf("loadShortBinUnicode() failed: %w", err) + return err + } + up.append(string(buf)) + + return nil +} + +// loads tuple from last-mark stack objects. +func loadTuple(up *Unpickler) error { + objects, err := up.popMark() + if err != nil { + err = fmt.Errorf("loadTuple() failed: %w", err) + return err + } + val := NewTupleFromSlice(objects) + up.append(val) + + return nil +} + +// load empty tuple. +func loadEmptyTuple(up *Unpickler) error { + t := NewTupleFromSlice([]interface{}{}) + up.append(t) + + return nil +} + +// load one tuple from stack top. +func loadTuple1(up *Unpickler) error { + obj, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadTuple() failed: %w", err) + return err + } + val := NewTupleFromSlice([]interface{}{obj}) + up.append(val) + + return nil +} + +// load 2-tuple object from 2 topmost stack objects. 
+func loadTuple2(up *Unpickler) error { + obj2, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadTuple2() failed: %w", err) + return err + } + obj1, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadTuple2() failed: %w", err) + return err + } + val := NewTupleFromSlice([]interface{}{obj1, obj2}) + up.append(val) + + return nil +} + +// loads 3-tuple object from 3 most stack objects +func loadTuple3(up *Unpickler) error { + objects := make([]interface{}, 3) + var ( + err error + n int = 0 + ) + for i := 2; i >= 0; i-- { + objects[n], err = up.stackPop() + if err != nil { + err = fmt.Errorf("loadTuple3() failed: %w", err) + } + } + val := NewTupleFromSlice(objects) + up.append(val) + + return nil +} + +// loads empty list object. +func loadEmptyList(up *Unpickler) error { + up.append(NewList()) + + return nil +} + +// loads empty dict. +func loadEmptyDict(up *Unpickler) error { + up.append(NewDict()) + return nil +} + +// loads empty set on the stack. +func loadEmptySet(up *Unpickler) error { + up.append(NewSet()) + return nil +} + +// loads frozenset from topmost stack objects. +func loadFrozenSet(up *Unpickler) error { + objects, err := up.popMark() + if err != nil { + err = fmt.Errorf("loadFrozenSet() failed: %w", err) + return err + } + up.append(NewFrozenSetFromSlice(objects)) + return nil +} + +// loads list from topmost stack objects +func loadList(up *Unpickler) error { + objects, err := up.popMark() + if err != nil { + err = fmt.Errorf("loadList() failed: %w", err) + return err + } + up.append(NewListFromSlice(objects)) + return nil +} + +// loads a dict from stack objects +func loadDict(up *Unpickler) error { + objects, err := up.popMark() + if err != nil { + err = fmt.Errorf("loadDict() failed: %w", err) + return err + } + d := NewDict() + objectsLen := len(objects) + for i := 0; i < objectsLen; i += 2 { + d.Set(objects[i], objects[i+1]) + } + up.append(d) + return nil +} + +// loads class instance. 
+func loadInst(up *Unpickler) error { + line, err := up.readLine() + if err != nil { + err = fmt.Errorf("loadInst() failed: %w", err) + return err + } + module := string(line[0 : len(line)-1]) + + line, err = up.readLine() + if err != nil { + err = fmt.Errorf("loadInst() failed: %w", err) + return err + } + name := string(line[0 : len(line)-1]) + + class, err := up.findClass(module, name) + if err != nil { + err = fmt.Errorf("loadInst() failed: %w", err) + return err + } + + args, err := up.popMark() + if err != nil { + err = fmt.Errorf("loadInst() failed: %w", err) + return err + } + + return up.instantiate(class, args) +} + +// loads class instance +func loadObj(up *Unpickler) error { + // Stack is ... markobject classobject arg1 arg2 ... + args, err := up.popMark() + if err != nil { + err = fmt.Errorf("loadObj() failed: %w", err) + return err + } + if len(args) == 0 { + return fmt.Errorf("OBJ class missing") + } + class := args[0] + args = args[1:len(args)] + return up.instantiate(class, args) +} + +// instantiates a object based on input dtype and arguments. +func (up *Unpickler) instantiate(class interface{}, args []interface{}) error { + var err error + var value interface{} + switch ct := class.(type) { + case Callable: + value, err = ct.Call(args...) + case PyNewable: + value, err = ct.PyNew(args...) 
+ default: + return fmt.Errorf("cannot instantiate %#v", class) + } + + if err != nil { + err = fmt.Errorf("instantiate() failed: %w", err) + return err + } + up.append(value) + return nil +} + +// loads object by applying cls.__new__ to argtuple +func loadNewObj(up *Unpickler) error { + args, err := up.stackPop() + if err != nil { + return err + } + argsTuple, argsOk := args.(*Tuple) + if !argsOk { + err := fmt.Errorf("NEWOBJ args must be *Tuple") + err = fmt.Errorf("loadNewObj() failed: %w", err) + return err + } + + rawClass, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadNewObj() failed: %w", err) + return err + } + class, classOk := rawClass.(PyNewable) + if !classOk { + err := fmt.Errorf("NEWOBJ requires a PyNewable object: %#v", rawClass) + err = fmt.Errorf("loadNewObj() failed: %w", err) + return err + } + + result, err := class.PyNew(*argsTuple...) + if err != nil { + return err + } + up.append(result) + return nil +} + +// like NEWOBJ but work with keyword only arguments +func loadNewObjEx(up *Unpickler) error { + kwargs, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadNewObjEx() failed: %w", err) + return err + } + + args, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadNewObjEx() failed: %w", err) + return err + } + argsTuple, argsOk := args.(*Tuple) + if !argsOk { + err := fmt.Errorf("NEWOBJ_EX args must be *Tuple") + err = fmt.Errorf("loadNewObjEx() failed: %w", err) + return err + } + + rawClass, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadNewObjEx() failed: %w", err) + return err + } + class, classOk := rawClass.(PyNewable) + if !classOk { + err := fmt.Errorf("NEWOBJ_EX requires a PyNewable object") + err = fmt.Errorf("loadNewObjEx() failed: %w", err) + return err + } + + allArgs := []interface{}(*argsTuple) + allArgs = append(allArgs, kwargs) + + result, err := class.PyNew(allArgs...) 
+ if err != nil { + err = fmt.Errorf("loadNewObjEx() failed: %w", err) + return err + } + up.append(result) + return nil +} + +// loads 'self.find_class(module, name)'; 2 string args. +// It decodes "module" and "name" of the object from binary file +// and find object class, then push to stack. +// +// NOTE. Pytorch rebuilds tensor (legacy) triggers from here +// with: module "torch._utils" - name "_rebuild_tensor" or "_rebuild_tensor_v2" +// rebuild tensor based on '_rebuild_tensor_v2' hook may break in the future. +// ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L138 +func loadGlobal(up *Unpickler) error { + line, err := up.readLine() + if err != nil { + err = fmt.Errorf("loadGlobal() failed: %w", err) + return err + } + module := string(line[0 : len(line)-1]) + + line, err = up.readLine() + if err != nil { + err = fmt.Errorf("loadGlobal() failed: %w", err) + return err + } + name := string(line[0 : len(line)-1]) + + class, err := up.findClass(module, name) + if err != nil { + err = fmt.Errorf("loadGlobal() failed: %w", err) + return err + } + up.append(class) + return nil +} + +// same as GLOBAL but using names on the stacks +func loadStackGlobal(up *Unpickler) error { + rawName, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadStackGlobal() failed: %w", err) + return err + } + name, nameOk := rawName.(string) + if !nameOk { + err := fmt.Errorf("STACK_GLOBAL requires str name: %#v", rawName) + err = fmt.Errorf("loadStackGlobal() failed: %w", err) + return err + } + + rawModule, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadStackGlobal() failed: %w", err) + return err + } + module, moduleOk := rawModule.(string) + if !moduleOk { + err := fmt.Errorf("STACK_GLOBAL requires str module: %#v", rawModule) + err = fmt.Errorf("loadStackGlobal() failed: %w", err) + return err + } + + class, err := up.findClass(module, name) + if err != nil { + err = fmt.Errorf("loadStackGlobal() 
failed: %w", err) + return err + } + up.append(class) + + return nil +} + +// loads object from extension registry; 1-byte index +func opExt1(up *Unpickler) error { + if up.GetExtension == nil { + err := fmt.Errorf("unsupported extension code encountered") + err = fmt.Errorf("loadStackGlobal() failed: %w", err) + return err + } + i, err := up.readOne() + if err != nil { + err = fmt.Errorf("loadStackGlobal() failed: %w", err) + return err + } + obj, err := up.GetExtension(int(i)) + if err != nil { + err = fmt.Errorf("loadStackGlobal() failed: %w", err) + return err + } + up.append(obj) + + return nil +} + +// ditto, but 2-byte index +func opExt2(up *Unpickler) error { + if up.GetExtension == nil { + err := fmt.Errorf("unsupported extension code encountered") + err = fmt.Errorf("opExt2() failed: %w", err) + return err + } + buf, err := up.read(2) + if err != nil { + err = fmt.Errorf("opExt2() failed: %w", err) + return err + } + code := int(binary.LittleEndian.Uint16(buf)) + obj, err := up.GetExtension(code) + if err != nil { + err = fmt.Errorf("opExt2() failed: %w", err) + return err + } + up.append(obj) + + return nil +} + +// ditto, but 4-byte index +func opExt4(up *Unpickler) error { + if up.GetExtension == nil { + err := fmt.Errorf("unsupported extension code encountered") + err = fmt.Errorf("opExt4() failed: %w", err) + return err + } + buf, err := up.read(4) + if err != nil { + err = fmt.Errorf("opExt4() failed: %w", err) + return err + } + code := int(binary.LittleEndian.Uint32(buf)) + obj, err := up.GetExtension(code) + if err != nil { + err = fmt.Errorf("opExt4() failed: %w", err) + return err + } + up.append(obj) + + return nil +} + +// apply callable to argtuple, both on stack +func loadReduce(up *Unpickler) error { + args, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadReduce() failed: %w", err) + return err + } + argsTuple, argsOk := args.(*Tuple) + if !argsOk { + err := fmt.Errorf("REDUCE args must be *Tuple") + err = 
fmt.Errorf("loadReduce() failed: %w", err) + return err + } + + function, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadReduce() failed: %w", err) + return err + } + callable, callableOk := function.(Callable) + if !callableOk { + err := fmt.Errorf("REDUCE requires a Callable object: %#v", function) + err = fmt.Errorf("loadReduce() failed: %w", err) + return err + } + + result, err := callable.Call(*argsTuple...) + if err != nil { + err = fmt.Errorf("loadReduce() failed: %w", err) + return err + } + up.append(result) + + return nil +} + +// discards topmost stack item +func loadPop(up *Unpickler) error { + if len(up.stack) == 0 { + _, err := up.popMark() + err = fmt.Errorf("loadPop() failed: %w", err) + return err + } + up.stack = up.stack[:len(up.stack)-1] + + return nil +} + +// discards stack top through topmost markobject +func loadPopMark(up *Unpickler) error { + _, err := up.popMark() + if err != nil { + err = fmt.Errorf("loadPopMark() failed: %w", err) + return err + } + + return nil +} + +// duplicate top stack item +func loadDup(up *Unpickler) error { + item, err := up.stackLast() + if err != nil { + err = fmt.Errorf("loadDup() failed: %w", err) + return err + } + up.append(item) + + return nil +} + +// loads object from memo on stack; index is string arg +func loadGet(up *Unpickler) error { + line, err := up.readLine() + if err != nil { + err = fmt.Errorf("loadGet() failed: %w", err) + return err + } + i, err := strconv.Atoi(string(line[:len(line)-1])) + if err != nil { + err = fmt.Errorf("loadGet() failed: %w", err) + return err + } + up.append(up.memo[i]) + + return nil +} + +// loads object from memo on stack; index is 1-byte arg +func loadBinGet(up *Unpickler) error { + i, err := up.readOne() + if err != nil { + err = fmt.Errorf("loadBinGet() failed: %w", err) + return err + } + up.append(up.memo[int(i)]) + + return nil +} + +// load object from memo on stack; index is 4-byte arg +func loadLongBinGet(up *Unpickler) error { + buf, err 
:= up.read(4) + if err != nil { + err = fmt.Errorf("loadBinGet() failed: %w", err) + return err + } + i := int(binary.LittleEndian.Uint32(buf)) + + up.append(up.memo[i]) + + return nil +} + +// store stack top in memo; index is string arg +func loadPut(up *Unpickler) error { + line, err := up.readLine() + if err != nil { + err = fmt.Errorf("loadPut() failed: %w", err) + return err + } + i, err := strconv.Atoi(string(line[:len(line)-1])) + if err != nil { + err = fmt.Errorf("loadPut() failed: %w", err) + return err + } + if i < 0 { + err := fmt.Errorf("negative PUT argument") + err = fmt.Errorf("loadPut() failed: %w", err) + return err + } + up.memo[i], err = up.stackLast() + if err != nil { + err = fmt.Errorf("loadPut() failed: %w", err) + return err + } + + return nil +} + +// store stack top in memo; index is 1-byte arg +func loadBinPut(up *Unpickler) error { + i, err := up.readOne() + if err != nil { + err = fmt.Errorf("loadBinPut() failed: %w", err) + return err + } + up.memo[int(i)], err = up.stackLast() + if err != nil { + err = fmt.Errorf("loadBinPut() failed: %w", err) + } + + return nil +} + +// stores stack top in memo; index is 4-byte arg +func loadLongBinPut(up *Unpickler) error { + buf, err := up.read(4) + if err != nil { + err = fmt.Errorf("loadLongBinPut() failed: %w", err) + return err + } + i := int(binary.LittleEndian.Uint32(buf)) + up.memo[i], err = up.stackLast() + if err != nil { + err = fmt.Errorf("loadLongBinPut() failed: %w", err) + return err + } + + return nil +} + +// stores top of the stack in memo +func loadMemoize(up *Unpickler) error { + value, err := up.stackLast() + if err != nil { + err = fmt.Errorf("loadMemoize() failed: %w", err) + return err + } + up.memo[len(up.memo)] = value + + return nil +} + +// appends stack top to list below it +func loadAppend(up *Unpickler) error { + value, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadAppend() failed: %w", err) + return err + } + obj, err := up.stackPop() + if err != 
nil { + err = fmt.Errorf("loadAppend() failed: %w", err) + return err + } + list, listOk := obj.(ListAppender) + if !listOk { + err := fmt.Errorf("APPEND requires ListAppender") + err = fmt.Errorf("loadAppend() failed: %w", err) + return err + } + list.Append(value) + up.append(list) + + return nil +} + +// extends list on stack by topmost stack slice +func loadAppends(up *Unpickler) error { + items, err := up.popMark() + if err != nil { + err = fmt.Errorf("loadAppends() failed: %w", err) + return err + } + obj, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadAppends() failed: %w", err) + return err + } + list, listOk := obj.(ListAppender) + if !listOk { + err := fmt.Errorf("APPEND requires List") + err = fmt.Errorf("loadAppends() failed: %w", err) + return err + } + for _, item := range items { + list.Append(item) + } + up.append(list) + + return nil +} + +// adds key+value pair to dict +func loadSetItem(up *Unpickler) error { + value, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadSetItem() failed: %w", err) + return err + } + key, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadSetItem() failed: %w", err) + return err + } + obj, err := up.stackLast() + if err != nil { + err = fmt.Errorf("loadSetItem() failed: %w", err) + return err + } + dict, dictOk := obj.(DictSetter) + if !dictOk { + err := fmt.Errorf("SETITEM requires DictSetter") + err = fmt.Errorf("loadSetItem() failed: %w", err) + return err + } + dict.Set(key, value) + + return nil +} + +// modifies dict by adding topmost key+value pairs +func loadSetItems(up *Unpickler) error { + items, err := up.popMark() + if err != nil { + err = fmt.Errorf("loadSetItems() failed: %w", err) + return err + } + obj, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadSetItems() failed: %w", err) + return err + } + dict, dictOk := obj.(DictSetter) + if !dictOk { + err := fmt.Errorf("SETITEMS requires DictSetter") + err = fmt.Errorf("loadSetItems() failed: %w", err) 
+ return err + } + itemsLen := len(items) + for i := 0; i < itemsLen; i += 2 { + dict.Set(items[i], items[i+1]) + } + up.append(dict) + + return nil +} + +// modifies set by adding topmost stack items +func loadAddItems(up *Unpickler) error { + items, err := up.popMark() + if err != nil { + err = fmt.Errorf("loadAddItems() failed: %w", err) + return err + } + obj, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadAddItems() failed: %w", err) + return err + } + set, setOk := obj.(SetAdder) + if !setOk { + err := fmt.Errorf("ADDITEMS requires SetAdder") + err = fmt.Errorf("loadAddItems() failed: %w", err) + return err + } + for _, item := range items { + set.Add(item) + } + up.append(set) + + return nil +} + +// calls __setstate__ or __dict__.update() +func loadBuild(up *Unpickler) error { + state, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadBuild() failed: %w", err) + return err + } + inst, err := up.stackLast() + if err != nil { + err = fmt.Errorf("loadBuild() failed: %w", err) + return err + } + if obj, ok := inst.(PyStateSettable); ok { + return obj.PySetState(state) + } + + var slotState interface{} + if tuple, ok := state.(*Tuple); ok && tuple.Len() == 2 { + state = tuple.Get(0) + slotState = tuple.Get(1) + } + + if stateDict, ok := state.(*Dict); ok { + instPds, instPdsOk := inst.(PyDictSettable) + if !instPdsOk { + err := fmt.Errorf("BUILD requires a PyDictSettable instance: %#v", inst) + err = fmt.Errorf("loadBuild() failed: %w", err) + return err + } + for _, entry := range *stateDict { + err := instPds.PyDictSet(entry.Key, entry.Value) + if err != nil { + err = fmt.Errorf("loadBuild() failed: %w", err) + return err + } + } + } + + if slotStateDict, ok := slotState.(*Dict); ok { + instSa, instOk := inst.(PyAttrSettable) + if !instOk { + err := fmt.Errorf("BUILD requires a PyAttrSettable instance: %#v", inst) + err = fmt.Errorf("loadBuild() failed: %w", err) + return err + } + + for _, entry := range *slotStateDict { + sk, 
keyOk := entry.Key.(string) + if !keyOk { + err := fmt.Errorf("BUILD requires string slot state keys") + err = fmt.Errorf("loadBuild() failed: %w", err) + return err + } + err := instSa.PySetAttr(sk, entry.Value) + if err != nil { + err = fmt.Errorf("loadBuild() failed: %w", err) + return err + } + } + } + + return nil +} + +// loads special markobject on stack +func loadMark(up *Unpickler) error { + up.metaStack = append(up.metaStack, up.stack) + up.stack = make([]interface{}, 0) + + return nil +} + +// every pickle ends with STOP +func loadStop(up *Unpickler) error { + value, err := up.stackPop() + if err != nil { + err = fmt.Errorf("loadStop() failed: %w", err) + return err + } + + return Stop{value: value} +} + +// initUnpickleDispatch creates a dispatch table for unpickling machinery. +func initUnpicklerDispatch() { + upDispatch[PROTO] = loadProto + upDispatch[FRAME] = loadFrame + upDispatch[PERSID] = loadPersId + upDispatch[BINPERSID] = loadBinPersId + upDispatch[NONE] = loadNone + upDispatch[NEWFALSE] = loadFalse + upDispatch[NEWTRUE] = loadTrue + upDispatch[INT] = loadInt + upDispatch[BININT] = loadBinInt + upDispatch[BININT1] = loadBinInt1 + upDispatch[BININT2] = loadBinInt2 + upDispatch[LONG] = loadLong + upDispatch[LONG1] = loadLong1 + upDispatch[LONG4] = loadLong4 + upDispatch[FLOAT] = loadFloat + upDispatch[BINFLOAT] = loadBinFloat + upDispatch[STRING] = loadString + upDispatch[BINSTRING] = loadBinString + upDispatch[BINBYTES] = loadBinBytes + upDispatch[UNICODE] = loadUnicode + upDispatch[BINUNICODE] = loadBinUnicode + upDispatch[BINUNICODE8] = loadBinUnicode8 + upDispatch[BINBYTES8] = loadBinBytes8 + upDispatch[BYTEARRAY8] = loadByteArray8 + upDispatch[NEXT_BUFFER] = loadNextBuffer + upDispatch[READONLY_BUFFER] = loadReadOnlyBuffer + upDispatch[SHORT_BINSTRING] = loadShortBinString + upDispatch[SHORT_BINBYTES] = loadShortBinBytes + upDispatch[SHORT_BINUNICODE] = loadShortBinUnicode + upDispatch[TUPLE] = loadTuple + upDispatch[EMPTY_TUPLE] = 
loadEmptyTuple + upDispatch[TUPLE1] = loadTuple1 + upDispatch[TUPLE2] = loadTuple2 + upDispatch[TUPLE3] = loadTuple3 + upDispatch[EMPTY_LIST] = loadEmptyList + upDispatch[EMPTY_DICT] = loadEmptyDict + upDispatch[EMPTY_SET] = loadEmptySet + upDispatch[FROZENSET] = loadFrozenSet + upDispatch[LIST] = loadList + upDispatch[DICT] = loadDict + upDispatch[INST] = loadInst + upDispatch[OBJ] = loadObj + upDispatch[NEWOBJ] = loadNewObj + upDispatch[NEWOBJ_EX] = loadNewObjEx + upDispatch[GLOBAL] = loadGlobal + upDispatch[STACK_GLOBAL] = loadStackGlobal + upDispatch[EXT1] = opExt1 + upDispatch[EXT2] = opExt2 + upDispatch[EXT4] = opExt4 + upDispatch[REDUCE] = loadReduce + upDispatch[POP] = loadPop + upDispatch[POP_MARK] = loadPopMark + upDispatch[DUP] = loadDup + upDispatch[GET] = loadGet + upDispatch[BINGET] = loadBinGet + upDispatch[LONG_BINGET] = loadLongBinGet + upDispatch[PUT] = loadPut + upDispatch[BINPUT] = loadBinPut + upDispatch[LONG_BINPUT] = loadLongBinPut + upDispatch[MEMOIZE] = loadMemoize + upDispatch[APPEND] = loadAppend + upDispatch[APPENDS] = loadAppends + upDispatch[SETITEM] = loadSetItem + upDispatch[SETITEMS] = loadSetItems + upDispatch[ADDITEMS] = loadAddItems + upDispatch[BUILD] = loadBuild + upDispatch[MARK] = loadMark + upDispatch[STOP] = loadStop +} + +func decodeInt32(buf []byte) int { + uval := binary.LittleEndian.Uint32(buf) + val := int(uval) + if buf[3]&0x80 != 0 { + val = -(int(^uval) + 1) + } + + return val +} + +func decodeLong(data []byte) interface{} { + if len(data) == 0 { + return nil + } + + // determine whether most-significant bit (MSB) is set + isMsbSet := data[len(data)-1]&0x80 != 0 + + if len(data) > 8 { + bInt := new(big.Int) + for i := len(data) - 1; i >= 0; i-- { + bInt = bInt.Lsh(bInt, 8) // left shift 8 bits + if isMsbSet { + bInt = bInt.Or(bInt, big.NewInt(int64(^data[i]))) + } else { + bInt = bInt.Or(bInt, big.NewInt(int64(data[i]))) + } + } // for + + if isMsbSet { + bInt = bInt.Add(bInt, big.NewInt(1)) + bInt = bInt.Neg(bInt) 
+ } + + return bInt + + } // if + + var val, bitMask uint64 + for i := len(data) - 1; i >= 0; i-- { + val = (val << 8) | uint64(data[i]) + bitMask = (bitMask << 8) | 0xFF + } + + if isMsbSet { + return -(int(^val & bitMask)) + } + + return int(val) +} + +func init() { + initUnpicklerDispatch() +} + +// Load decodes objects by loading through unpickling machinery. +func (up *Unpickler) Load() (interface{}, error) { + up.metaStack = make([][]interface{}, 0) + up.stack = make([]interface{}, 0) + up.proto = 0 + + for { + opcode, err := up.readOne() + if err != nil { + return nil, err + } + + opFunc := upDispatch[opcode] + if opFunc == nil { + err := fmt.Errorf("Unpickler.Load() failed:unknown opcode: 0x%x '%c'", opcode, opcode) + return nil, err + } + + err = opFunc(up) + if err != nil { + if p, ok := err.(Stop); ok { + return p.value, nil + } + + err := fmt.Errorf("Unpickler.Load() failed: %w", err) + return nil, err + } + } +} + +// Load unpickles a pickled file. +func Load(filename string) (interface{}, error) { + f, err := os.Open(filename) + if err != nil { + err := fmt.Errorf("Load() failed: %w", err) + return nil, err + } + defer f.Close() + + up := NewUnpickler(f) + + return up.Load() +} + +// Loads unpicles a string. 
+func Loads(s string) (interface{}, error) { + sr := strings.NewReader(s) + up := NewUnpickler(sr) + + return up.Load() +} diff --git a/pickle/pickle_example_test.go b/pickle/pickle_example_test.go new file mode 100644 index 0000000..421951b --- /dev/null +++ b/pickle/pickle_example_test.go @@ -0,0 +1,60 @@ +package pickle_test + +import ( + "log" + + "github.com/sugarme/gotch" + "github.com/sugarme/gotch/pickle" +) + +func ExampleLoadInfo() { + modelName := "vgg16" + url, ok := gotch.ModelUrls[modelName] + if !ok { + log.Fatalf("Unsupported model name %q\n", modelName) + } + modelFile, err := gotch.CachedPath(url) + if err != nil { + panic(err) + } + + err = pickle.LoadInfo(modelFile) + if err != nil { + log.Fatal(err) + } + + // Output: + // classifier.0.bias - [4096] + // classifier.0.weight - [4096 25088] + // classifier.3.bias - [4096] + // classifier.3.weight - [4096 4096] + // classifier.6.bias - [1000] + // classifier.6.weight - [1000 4096] + // features.0.bias - [64] + // features.0.weight - [64 3 3 3] + // features.10.bias - [256] + // features.10.weight - [256 128 3 3] + // features.12.bias - [256] + // features.12.weight - [256 256 3 3] + // features.14.bias - [256] + // features.14.weight - [256 256 3 3] + // features.17.bias - [512] + // features.17.weight - [512 256 3 3] + // features.19.bias - [512] + // features.19.weight - [512 512 3 3] + // features.2.bias - [64] + // features.2.weight - [64 64 3 3] + // features.21.bias - [512] + // features.21.weight - [512 512 3 3] + // features.24.bias - [512] + // features.24.weight - [512 512 3 3] + // features.26.bias - [512] + // features.26.weight - [512 512 3 3] + // features.28.bias - [512] + // features.28.weight - [512 512 3 3] + // features.5.bias - [128] + // features.5.weight - [128 64 3 3] + // features.7.bias - [128] + // features.7.weight - [128 128 3 3] + // Num of variables: 32 +} diff --git a/pickle/serialization.go b/pickle/serialization.go new file mode 100644 index 0000000..566d4fb --- 
/dev/null +++ b/pickle/serialization.go @@ -0,0 +1,519 @@ +package pickle + +// Ref. +// https://docs.python.org/3/library/pickle.html +// https://docs.python.org/3/library/pickletools.html +// https://github.com/python/cpython/blob/main/Lib/pickle.py (****real code here****) +// https://pytorch.org/tutorials/beginner/saving_loading_models.html +// https://github.com/pytorch/pytorch/blob/master/torch/serialization.py + +import ( + "archive/tar" + "archive/zip" + "errors" + "fmt" + "io" + "math/big" + "os" + "path" + "reflect" + "sort" + + "github.com/sugarme/gotch/nn" + ts "github.com/sugarme/gotch/tensor" +) + +const hexMagicNumber = "1950a86a20f9469cfc6c" +const protocolVersion = 1001 + +var ErrInvalidMagicNumber = errors.New("invalid pytorch magic number") +var ErrInvalidProtocolVersion = errors.New("invalid pytorch protocol version") + +// Encode encodes model using pickling machinery. +// Output pickled model can be loads with Python Pytorch as `torch.load("pytorch_model.bin")` +// +// TODO. implement pickling part so that model can be exported and load with Python Pytorch. +// See https://github.com/python/cpython/blob/b0de6299a840a397d4fe3e6c98159d9f258d3295/Lib/pickle.py#L407 +func Encode(model ts.Module, outputFile string) error { + panic("NotImplementedError") +} + +// Decode decodes pickled data created by 'torch.save()' with Python Pytorch +// and rebuilds named tensor weights. 
+func Decode(filename string) (map[string]*ts.Tensor, error) { + newUnpickler := func(r io.Reader) Unpickler { + return NewUnpickler(r) + } + result, err := LoadWithUnpickler(filename, newUnpickler) + if err != nil { + err := fmt.Errorf("Decode() failed: %w", err) + return nil, err + } + + // Rebuild tensors from Storage tensors + namedTensors := make(map[string]*ts.Tensor) + dictResult, isOrderedDict := result.(*OrderedDict) + if !isOrderedDict { + err := fmt.Errorf("Decode() failed: expected 'OrderedDict' type, got %v\n", reflect.TypeOf(result).String()) + return nil, err + } + for name, item := range dictResult.Map { + sx, isStorageTensor := item.Value.(*StorageTensor) + if !isStorageTensor { + err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String()) + return nil, err + } + + data := sx.Source.GetData() + size := sx.Size + dtype := sx.Source.DType() + device := sx.Source.Device() + stride := sx.Stride + storageOffset := sx.StorageOffset + + // fmt.Printf("%q - shape: %v - stride: %v - storageOffset: %v\n", sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset) + + x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true) + if sx.RequiresGrad { + x.MustRequiresGrad_(sx.RequiresGrad) + } + + namedTensors[name.(string)] = x + } + + return namedTensors, nil +} + +// LoadWithUnpickler is like Load, but it accepts a newUnpickler function which +// is used to create new customized pickle.Unpickler instances. +func LoadWithUnpickler(filename string, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) { + if !isZipFile(filename) { + return loadLegacyFile(filename, newUnpickler) + } + return loadZipFile(filename, newUnpickler) +} + +func loadZipFile(filename string, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) { + // Open a zip archive for reading. 
+ r, err := zip.OpenReader(filename) + if err != nil { + return nil, err + } + defer r.Close() + + fileRecords := make(map[string]*zip.File, len(r.File)) + for _, f := range r.File { + _, recordName := path.Split(f.Name) + fileRecords[recordName] = f + } + + if _, isTorchScript := fileRecords["constants.pkl"]; isTorchScript { + return nil, fmt.Errorf("TorchScript is not supported") + } + + dataFile, hasDataFile := fileRecords["data.pkl"] + if !hasDataFile { + return nil, fmt.Errorf("data.pkl not found in zip file") + } + df, err := dataFile.Open() + if err != nil { + return nil, err + } + defer df.Close() + + loadedStorages := make(map[string]Storage) + + u := newUnpickler(df) + u.FindClass = makePickleFindClass(u.FindClass) + u.PersistentLoad = func(savedId interface{}) (interface{}, error) { + tuple, tupleOk := savedId.(*Tuple) + if !tupleOk || tuple.Len() == 0 { + return nil, fmt.Errorf("PersistentLoad: non-empty tuple expected, got %#v", savedId) + } + typename, typenameOk := tuple.Get(0).(string) + if !typenameOk { + return nil, fmt.Errorf("PersistentLoad: cannot get typename") + } + if typename != "storage" { + return nil, fmt.Errorf("unknown typename for PersistentLoad, expected 'storage' but got '%s'", typename) + } + if tuple.Len() < 5 { + return nil, fmt.Errorf("PersistentLoad: unexpected storage data length") + } + dataType, dataTypeOk := tuple.Get(1).(StorageClass) + key, keyOk := tuple.Get(2).(string) + location, locationOk := tuple.Get(3).(string) + size, sizeOk := tuple.Get(4).(int) + if !dataTypeOk || !keyOk || !locationOk || !sizeOk { + return nil, fmt.Errorf("PersistentLoad: unexpected data types") + } + storage, storageExists := loadedStorages[key] + if !storageExists { + storage, err = loadTensor(dataType, size, location, key, fileRecords) + if err != nil { + return nil, err + } + loadedStorages[key] = storage + } + return storage, nil + } + return u.Load() +} + +func loadTensor( + dataType StorageClass, + size int, + location, key string, + 
zipFileRecords map[string]*zip.File, +) (Storage, error) { + file, fileOk := zipFileRecords[key] + if !fileOk { + return nil, fmt.Errorf("cannot find zip record '%s'", key) + } + f, err := file.Open() + if err != nil { + return nil, err + } + defer f.Close() + + storage := dataType.New(size, location) + err = storage.SetFromFileWithSize(f, size) + return storage, err +} + +func loadLegacyFile(filename string, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + + tr := tar.NewReader(f) + for { + _, err := tr.Next() + switch err { + case nil: + // TODO: ... + panic("legacy load from tar not implemented") + case io.EOF: + break // End of archive + case tar.ErrHeader, io.ErrUnexpectedEOF: + _, err = f.Seek(0, io.SeekStart) + if err != nil { + return nil, err + } + return loadLegacyNoTar(f, newUnpickler) + default: + return nil, err + } + } +} + +func loadLegacyNoTar(f *os.File, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) { + if err := readAndCheckMagicNumber(f); err != nil { + return nil, err + } + if err := readAndChecProtocolVersion(f); err != nil { + return nil, err + } + if _, err := unpickle(f); err != nil { // sys info + return nil, err + } + + deserializedObjects := make(map[string]Storage) + + u := newUnpickler(f) + u.FindClass = makePickleFindClass(u.FindClass) + u.PersistentLoad = func(savedId interface{}) (interface{}, error) { + tuple, tupleOk := savedId.(*Tuple) + if !tupleOk || tuple.Len() == 0 { + return nil, fmt.Errorf("PersistentLoad: non-empty tuple expected, got %#v", savedId) + } + typename, typenameOk := tuple.Get(0).(string) + if !typenameOk { + return nil, fmt.Errorf("PersistentLoad: cannot get typename") + } + + // fmt.Printf("typename: %s\n", typename) + + switch typename { + case "storage": + if tuple.Len() < 6 { + return nil, fmt.Errorf( + "PersistentLoad: unexpected storage data length") + } + dataType, dataTypeOk := 
tuple.Get(1).(StorageClass) + rootKey, rootKeyOk := tuple.Get(2).(string) + location, locationOk := tuple.Get(3).(string) + size, sizeOk := tuple.Get(4).(int) + viewMetadata := tuple.Get(5) + if !dataTypeOk || !rootKeyOk || !locationOk || !sizeOk { + return nil, fmt.Errorf("PersistentLoad: unexpected data types") + } + + // fmt.Printf("dtype: %v - rootKey: %v - device: %v - size %v - viewMetaData: %v\n", reflect.TypeOf(dataType), rootKey, location, size, viewMetadata) + + storage, storageExists := deserializedObjects[rootKey] + if !storageExists { + storage = dataType.New(size, location) + deserializedObjects[rootKey] = storage + } + switch vm := viewMetadata.(type) { + case nil: + return storage, nil + case []interface{}: + if len(vm) != 3 { + return nil, fmt.Errorf( + "PersistentLoad: unexpected view metadata length") + } + panic("viewMetadata not implemented") + // TODO: ... + // view_key, offset, view_size = view_metadata + // if view_key not in deserialized_objects: + // deserialized_objects[view_key] = storage[offset:offset + view_size] + // return deserialized_objects[view_key] + default: + return nil, fmt.Errorf("PersistentLoad: unexpected view metadata type") + } + case "module": + if tuple.Len() < 2 { + return nil, fmt.Errorf("PersistentLoad: unexpected module data length") + } + return tuple.Get(1), nil + default: + return nil, fmt.Errorf("Unexpected saved ID type: %s", typename) + } + } + + result, err := u.Load() + if err != nil { + return nil, err + } + + rawStorageKeys, err := unpickle(f) + if err != nil { + return nil, err + } + storageKeys, err := makeStorageKeys(rawStorageKeys) + if err != nil { + return nil, err + } + + for _, key := range storageKeys { + storageObj, ok := deserializedObjects[key] + if !ok { + return nil, fmt.Errorf("storage object not found for key '%s'", key) + } + + err = storageObj.SetFromFile(f) + if err != nil { + return nil, err + } + } + + return result, nil +} + +func makeStorageKeys(obj interface{}) ([]string, error) { 
+ list, ok := obj.(*List) + if !ok { + return nil, fmt.Errorf("invalid storage keys data") + } + keys := make([]string, len(*list)) + for i, rawKey := range *list { + key, keyOk := rawKey.(string) + if !keyOk { + return nil, fmt.Errorf("invalid storage key") + } + keys[i] = key + } + return keys, nil +} + +func readAndCheckMagicNumber(r io.Reader) error { + obj, err := unpickle(r) + if err != nil { + return err + } + if n, ok := obj.(*big.Int); !ok || n.Text(16) != hexMagicNumber { + return ErrInvalidMagicNumber + } + return nil +} + +func readAndChecProtocolVersion(r io.Reader) error { + obj, err := unpickle(r) + if err != nil { + return err + } + if n, ok := obj.(int); !ok || n != protocolVersion { + return ErrInvalidProtocolVersion + } + return nil +} + +func unpickle(r io.Reader) (interface{}, error) { + u := NewUnpickler(r) + return u.Load() +} + +func isZipFile(filename string) bool { + r, err := zip.OpenReader(filename) + if err != nil { + return false + } + r.Close() + return true +} + +func makePickleFindClass(fallback func(module, name string) (interface{}, error)) func(module, name string) (interface{}, error) { + return func(module, name string) (interface{}, error) { + switch module + "." 
+ name { + case "torch._utils._rebuild_tensor": + return &RebuildTensor{}, nil + case "torch._utils._rebuild_tensor_v2": + return &RebuildTensorV2{}, nil + case "torch.FloatStorage": + return &FloatStorageClass{}, nil + case "torch.HalfStorage": + return &HalfStorageClass{}, nil + case "torch.DoubleStorage": + return &DoubleStorageClass{}, nil + case "torch.CharStorage": + return &CharStorageClass{}, nil + case "torch.ShortStorage": + return &ShortStorageClass{}, nil + case "torch.IntStorage": + return &IntStorageClass{}, nil + case "torch.LongStorage": + return &LongStorageClass{}, nil + case "torch.ByteStorage": + return &ByteStorageClass{}, nil + case "torch.BoolStorage": + return &BoolStorageClass{}, nil + case "torch.nn.backends.thnn._get_thnn_function_backend": + // this is for historical pickle deserilaization, it is not used otherwise + return getThnnFunctionBackend{}, nil + default: + if fallback == nil { + return nil, fmt.Errorf("class not found: %s %s", module, name) + } + return fallback(module, name) + } + } +} + +// LoadAll finds and loads all weights from varstore. +// It will throw err if one of weights from varstore cannot find from loaded pretrained model. +func LoadAll(vs *nn.VarStore, modelFile string) error { + weights, err := Decode(modelFile) + if err != nil { + err = fmt.Errorf("LoadAll() failed: %w", err) + return err + } + + // for tsName, _ := range vs.Vars.NamedVariables { + for tsName := range vs.Vars.NamedVariables { + // missing variable + currTs, ok := weights[tsName] + if !ok { + err = fmt.Errorf("LoadAll() failed: Cannot find tensor with name: %v in variable store. 
\n", tsName) + return err + } + + // mismatched shape + sourceShape := currTs.MustSize() + destShape := vs.Vars.NamedVariables[tsName].MustSize() + if !reflect.DeepEqual(destShape, sourceShape) { + err = fmt.Errorf("LoadAll() failed: Mismatched shape error for variable name: %v - At store: %v - At source %v\n", tsName, destShape, sourceShape) + return err + } + + ts.NoGrad(func() { + vs.Vars.NamedVariables[tsName].Copy_(currTs) + }) + } + + for _, x := range weights { + x.MustDrop() + } + + return nil +} + +// LoadPartial finds and loads weights for varstore. +// It returns list of unfound weight names. +func LoadPartial(vs *nn.VarStore, modelFile string) ([]string, error) { + weights, err := Decode(modelFile) + if err != nil { + err = fmt.Errorf("LoadPartial() failed: %w", err) + return nil, err + } + + var missingVariables []string + + // Match and in-place copy value (update) from newly loaded tensors + // to existing named tensors if name is matched. Throw error otherwise. + for tsName := range vs.Vars.NamedVariables { + var currTs *ts.Tensor + var ok bool + + // missing variable + if currTs, ok = weights[tsName]; !ok { + missingVariables = append(missingVariables, tsName) + continue + } + + // mismatched shape + destShape := currTs.MustSize() + sourceShape := vs.Vars.NamedVariables[tsName].MustSize() + if !reflect.DeepEqual(destShape, sourceShape) { + fmt.Printf("WARNING: Mismatched shape error for variable name: %v - At store: %v - At source %v. Skip loading this weight...\n", tsName, destShape, sourceShape) + missingVariables = append(missingVariables, tsName) + continue + } + + ts.NoGrad(func() { + vs.Vars.NamedVariables[tsName].Copy_(currTs) + }) + } + + for _, x := range weights { + x.MustDrop() + } + + return missingVariables, nil +} + +// LoadInfo loads pretrained weights and prints out name and shape of weights. 
+func LoadInfo(modelFile string) error { + weights, err := Decode(modelFile) + if err != nil { + err = fmt.Errorf("LoadInfo() failed: %w", err) + return err + } + + layers := make([]string, 0, len(weights)) + for tsName := range weights { + layers = append(layers, tsName) + } + sort.Strings(layers) + for _, l := range layers { + var x *ts.Tensor + for tsName, tsVal := range weights { + if tsName == l { + x = tsVal + break + } + } + fmt.Printf("%s - %+v\n", l, x.MustSize()) + } + + fmt.Printf("Num of variables: %v\n", len(weights)) + + for _, x := range weights { + x.MustDrop() + } + + return nil +} diff --git a/pickle/storage.go b/pickle/storage.go new file mode 100644 index 0000000..5a58bf5 --- /dev/null +++ b/pickle/storage.go @@ -0,0 +1,667 @@ +package pickle + +import ( + "encoding/binary" + "fmt" + "io" + "math" + + "github.com/sugarme/gotch" +) + +// This file implements Pytorch storage data types. +// ref: https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/storage.py +/* +torch.double: 'DoubleStorage', +torch.float: 'FloatStorage', +torch.half: 'HalfStorage', +torch.long: 'LongStorage', +torch.int: 'IntStorage', +torch.int16: 'ShortStorage', +torch.int8: 'CharStorage', +torch.uint8: 'ByteStorage', +torch.bool: 'BoolStorage', +torch.bfloat16: 'BFloat16Storage', +torch.cdouble: 'ComplexDoubleStorage', +torch.cfloat: 'ComplexFloatStorage', +torch.qint8: 'QInt8Storage', +torch.qint32: 'QInt32Storage', +torch.quint8: 'QUInt8Storage', +torch.quint4x2: 'QUInt4x2Storage', +torch.quint2x4: 'QUInt2x4Storage', +*/ + +// StorageClass defines interface for types to be used in Storage. +type StorageClass interface { + New(size int, location string) Storage +} + +// Storage define Storage interface. +type Storage interface { + SetFromFile(r io.Reader) error + SetFromFileWithSize(r io.Reader, size int) error + DType() gotch.DType + GetData() interface{} + Device() gotch.Device +} + +// BaseStorage represents a base storage. 
+type BaseStorage struct { + Size int + Location string +} + +// HalfStorage: +// ============ + +type HalfStorageClass struct{} + +var _ StorageClass = &HalfStorageClass{} + +func (s *HalfStorageClass) New(size int, location string) Storage { + return &HalfStorage{ + BaseStorage: BaseStorage{Size: size, Location: location}, + Data: nil, + } +} + +type HalfStorage struct { + BaseStorage + Data []float32 +} + +var _ Storage = &HalfStorage{} + +func (s *HalfStorage) SetFromFile(r io.Reader) error { + return setFromFile(s, r) +} + +func (s *HalfStorage) SetFromFileWithSize(r io.Reader, size int) error { + data := make([]float32, size) + br := NewLimitedBufferReader(r, size, 2, 512) + for i := 0; i < size; i++ { + bytes, err := br.ReadNext() + if err != nil { + return err + } + u16 := binary.LittleEndian.Uint16(bytes) + data[i] = math.Float32frombits(FloatBits16to32(u16)) + } + s.Data = data + return nil +} + +func (s *HalfStorage) GetData() interface{} { + return s.Data +} + +func (s *HalfStorage) DType() gotch.DType { + return gotch.Float +} + +func (s *HalfStorage) Device() gotch.Device { + switch s.Location { + case "cuda": + return gotch.CudaIfAvailable() + default: + return gotch.CPU + } +} + +// FloatStorage: +// ============= + +type FloatStorageClass struct{} + +var _ StorageClass = &FloatStorageClass{} + +func (s *FloatStorageClass) New(size int, location string) Storage { + return &FloatStorage{ + BaseStorage: BaseStorage{Size: size, Location: location}, + Data: nil, + } +} + +type FloatStorage struct { + BaseStorage + Data []float32 +} + +var _ Storage = &FloatStorage{} + +func (s *FloatStorage) SetFromFile(r io.Reader) error { + return setFromFile(s, r) +} + +func (s *FloatStorage) SetFromFileWithSize(r io.Reader, size int) error { + data := make([]float32, size) + br := NewLimitedBufferReader(r, size, 4, 512) + for i := 0; i < size; i++ { + bytes, err := br.ReadNext() + if err != nil { + return err + } + data[i] = 
math.Float32frombits(binary.LittleEndian.Uint32(bytes)) + } + s.Data = data + return nil +} + +func (s *FloatStorage) GetData() interface{} { + return s.Data +} + +func (s *FloatStorage) DType() gotch.DType { + return gotch.Float +} + +func (s *FloatStorage) Device() gotch.Device { + switch s.Location { + case "cuda": + return gotch.CudaIfAvailable() + default: + return gotch.CPU + } +} + +// DoubleStorage: +// ============== + +type DoubleStorageClass struct{} + +var _ StorageClass = &DoubleStorageClass{} + +func (s *DoubleStorageClass) New(size int, location string) Storage { + return &DoubleStorage{ + BaseStorage: BaseStorage{Size: size, Location: location}, + Data: nil, + } +} + +type DoubleStorage struct { + BaseStorage + Data []float64 +} + +var _ Storage = &DoubleStorage{} + +func (s *DoubleStorage) SetFromFile(r io.Reader) error { + return setFromFile(s, r) +} + +func (s *DoubleStorage) SetFromFileWithSize(r io.Reader, size int) error { + data := make([]float64, size) + br := NewLimitedBufferReader(r, size, 8, 512) + for i := 0; i < size; i++ { + bytes, err := br.ReadNext() + if err != nil { + return err + } + data[i] = math.Float64frombits(binary.LittleEndian.Uint64(bytes)) + } + s.Data = data + return nil +} + +func (s *DoubleStorage) GetData() interface{} { + return s.Data +} + +func (s *DoubleStorage) DType() gotch.DType { + return gotch.Double +} + +func (s *DoubleStorage) Device() gotch.Device { + switch s.Location { + case "cuda": + return gotch.CudaIfAvailable() + default: + return gotch.CPU + } +} + +// CharStorage: +// ============ + +type CharStorageClass struct{} + +var _ StorageClass = &CharStorageClass{} + +func (s *CharStorageClass) New(size int, location string) Storage { + return &CharStorage{ + BaseStorage: BaseStorage{Size: size, Location: location}, + Data: nil, + } +} + +type CharStorage struct { + BaseStorage + Data []int8 +} + +var _ Storage = &CharStorage{} + +func (s *CharStorage) SetFromFile(r io.Reader) error { + return 
setFromFile(s, r) +} + +func (s *CharStorage) SetFromFileWithSize(r io.Reader, size int) error { + data := make([]int8, size) + br := NewLimitedBufferReader(r, size, 1, 512) + for i := 0; i < size; i++ { + bytes, err := br.ReadNext() + if err != nil { + return err + } + data[i] = int8(bytes[0]) + } + s.Data = data + return nil +} + +func (s *CharStorage) GetData() interface{} { + return s.Data +} + +func (s *CharStorage) DType() gotch.DType { + return gotch.Int8 +} + +func (s *CharStorage) Device() gotch.Device { + switch s.Location { + case "cuda": + return gotch.CudaIfAvailable() + default: + return gotch.CPU + } +} + +// ShortStorage: +// ============= + +type ShortStorageClass struct{} + +var _ StorageClass = &ShortStorageClass{} + +func (s *ShortStorageClass) New(size int, location string) Storage { + return &ShortStorage{ + BaseStorage: BaseStorage{Size: size, Location: location}, + Data: nil, + } +} + +type ShortStorage struct { + BaseStorage + Data []int16 +} + +var _ Storage = &ShortStorage{} + +func (s *ShortStorage) SetFromFile(r io.Reader) error { + return setFromFile(s, r) +} + +func (s *ShortStorage) SetFromFileWithSize(r io.Reader, size int) error { + data := make([]int16, size) + br := NewLimitedBufferReader(r, size, 2, 512) + for i := 0; i < size; i++ { + bytes, err := br.ReadNext() + if err != nil { + return err + } + data[i] = int16(binary.LittleEndian.Uint16(bytes)) + } + s.Data = data + return nil +} + +func (s *ShortStorage) GetData() interface{} { + return s.Data +} + +func (s *ShortStorage) DType() gotch.DType { + return gotch.Int16 +} + +func (s *ShortStorage) Device() gotch.Device { + switch s.Location { + case "cuda": + return gotch.CudaIfAvailable() + default: + return gotch.CPU + } +} + +// IntStorage: +// =========== + +type IntStorageClass struct{} + +var _ StorageClass = &IntStorageClass{} + +func (s *IntStorageClass) New(size int, location string) Storage { + return &IntStorage{ + BaseStorage: BaseStorage{Size: size, Location: 
location}, + Data: nil, + } +} + +type IntStorage struct { + BaseStorage + Data []int32 +} + +var _ Storage = &IntStorage{} + +func (s *IntStorage) SetFromFile(r io.Reader) error { + return setFromFile(s, r) +} + +func (s *IntStorage) SetFromFileWithSize(r io.Reader, size int) error { + data := make([]int32, size) + br := NewLimitedBufferReader(r, size, 4, 512) + for i := 0; i < size; i++ { + bytes, err := br.ReadNext() + if err != nil { + return err + } + data[i] = int32(binary.LittleEndian.Uint32(bytes)) + } + s.Data = data + return nil +} + +func (s *IntStorage) GetData() interface{} { + return s.Data +} + +func (s *IntStorage) DType() gotch.DType { + return gotch.Int +} + +func (s *IntStorage) Device() gotch.Device { + switch s.Location { + case "cuda": + return gotch.CudaIfAvailable() + default: + return gotch.CPU + } +} + +// LongStorage: +// ============ + +type LongStorageClass struct{} + +var _ StorageClass = &LongStorageClass{} + +func (s *LongStorageClass) New(size int, location string) Storage { + return &LongStorage{ + BaseStorage: BaseStorage{Size: size, Location: location}, + Data: nil, + } +} + +type LongStorage struct { + BaseStorage + Data []int64 +} + +var _ Storage = &LongStorage{} + +func (s *LongStorage) SetFromFile(r io.Reader) error { + return setFromFile(s, r) +} + +func (s *LongStorage) SetFromFileWithSize(r io.Reader, size int) error { + data := make([]int64, size) + br := NewLimitedBufferReader(r, size, 8, 512) + for i := 0; i < size; i++ { + bytes, err := br.ReadNext() + if err != nil { + return err + } + data[i] = int64(binary.LittleEndian.Uint64(bytes)) + } + s.Data = data + return nil +} + +func (s *LongStorage) GetData() interface{} { + return s.Data +} + +func (s *LongStorage) DType() gotch.DType { + return gotch.Int64 +} + +func (s *LongStorage) Device() gotch.Device { + switch s.Location { + case "cuda": + return gotch.CudaIfAvailable() + default: + return gotch.CPU + } +} + +// ByteStorage: +// ============ + +type 
ByteStorageClass struct{} + +var _ StorageClass = &ByteStorageClass{} + +func (s *ByteStorageClass) New(size int, location string) Storage { + return &ByteStorage{ + BaseStorage: BaseStorage{Size: size, Location: location}, + Data: nil, + } +} + +type ByteStorage struct { + BaseStorage + Data []uint8 +} + +var _ Storage = &ByteStorage{} + +func (s *ByteStorage) SetFromFile(r io.Reader) error { + return setFromFile(s, r) +} + +func (s *ByteStorage) SetFromFileWithSize(r io.Reader, size int) error { + data := make([]uint8, size) + br := NewLimitedBufferReader(r, size, 1, 512) + for i := 0; i < size; i++ { + bytes, err := br.ReadNext() + if err != nil { + return err + } + data[i] = bytes[0] + } + s.Data = data + return nil +} + +func (s *ByteStorage) GetData() interface{} { + return s.Data +} + +func (s *ByteStorage) DType() gotch.DType { + return gotch.Uint8 +} + +func (s *ByteStorage) Device() gotch.Device { + switch s.Location { + case "cuda": + return gotch.CudaIfAvailable() + default: + return gotch.CPU + } +} + +// BoolStorage: +// ============ + +type BoolStorageClass struct{} + +var _ StorageClass = &BoolStorageClass{} + +func (s *BoolStorageClass) New(size int, location string) Storage { + return &BoolStorage{ + BaseStorage: BaseStorage{Size: size, Location: location}, + Data: nil, + } +} + +type BoolStorage struct { + BaseStorage + Data []bool +} + +var _ Storage = &BoolStorage{} + +func (s *BoolStorage) SetFromFile(r io.Reader) error { + return setFromFile(s, r) +} + +func (s *BoolStorage) SetFromFileWithSize(r io.Reader, size int) error { + data := make([]bool, size) + br := NewLimitedBufferReader(r, size, 1, 512) + for i := 0; i < size; i++ { + bytes, err := br.ReadNext() + if err != nil { + return err + } + data[i] = bytes[0] == 1 + } + s.Data = data + return nil +} + +func (s *BoolStorage) GetData() interface{} { + return s.Data +} + +func (s *BoolStorage) DType() gotch.DType { + return gotch.Float +} + +func (s *BoolStorage) Device() gotch.Device { + 
switch s.Location { + case "cuda": + return gotch.CudaIfAvailable() + default: + return gotch.CPU + } +} + +func setFromFile(s Storage, r io.Reader) error { + sizeBuf := make([]byte, 8) + _, err := r.Read(sizeBuf) + if err != nil { + return err + } + size := int(binary.LittleEndian.Uint64(sizeBuf)) + return s.SetFromFileWithSize(r, size) +} + +// StorageTensor: +//=============== +type StorageTensor struct { + Source Storage + StorageOffset int64 + Size []int64 + Stride []int64 + RequiresGrad bool +} + +// Rebuild Tensor: +// =============== +// ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L132 +// ref. def _rebuild_tensor(storage, storage_offset, size, stride): + +type RebuildTensor struct{} + +var _ Callable = &RebuildTensor{} + +func (r *RebuildTensor) Call(args ...interface{}) (interface{}, error) { + if len(args) != 4 { + return nil, fmt.Errorf("RebuildTensor.Call() failed. Expected 4 args, got %d: %#v", len(args), args) + } + + storage, storageOk := args[0].(Storage) + storageOffset, storageOffsetOk := args[1].(int) + size, sizeOk := args[2].(*Tuple) + stride, strideOk := args[3].(*Tuple) + if !storageOk || !storageOffsetOk || !sizeOk || !strideOk { + return nil, fmt.Errorf("RebuildTensor.Call() unexpected args: %#v", args) + } + + tensor := &StorageTensor{ + Source: storage, + StorageOffset: int64(storageOffset), + RequiresGrad: false, + } + var err error + tensor.Size, err = tupleToInt64Slice(size) + if err != nil { + return nil, err + } + tensor.Stride, err = tupleToInt64Slice(stride) + if err != nil { + return nil, err + } + return tensor, nil +} + +// RebuildTensorV2 represents a struct to rebuild tensor back from pickle object. 
+type RebuildTensorV2 struct{} + +var _ Callable = &RebuildTensorV2{} + +func (r *RebuildTensorV2) Call(args ...interface{}) (interface{}, error) { + if len(args) != 6 { + return nil, fmt.Errorf("RebuildTensorV2 unexpected args: %#v", args) + } + + storage, storageOk := args[0].(Storage) + storageOffset, storageOffsetOk := args[1].(int) + size, sizeOk := args[2].(*Tuple) + stride, strideOk := args[3].(*Tuple) + requiresGrad, requiresGradOk := args[4].(bool) + // arg[5] "backward hooks" is unused + if !storageOk || !storageOffsetOk || !sizeOk || !strideOk || + !requiresGradOk { + return nil, fmt.Errorf("RebuildTensorV2 unexpected args: %#v", args) + } + + tensor := &StorageTensor{ + Source: storage, + StorageOffset: int64(storageOffset), + RequiresGrad: requiresGrad, + } + var err error + tensor.Size, err = tupleToInt64Slice(size) + if err != nil { + return nil, err + } + tensor.Stride, err = tupleToInt64Slice(stride) + if err != nil { + return nil, err + } + return tensor, nil +} + +func tupleToInt64Slice(tuple *Tuple) ([]int64, error) { + length := tuple.Len() + slice := make([]int64, length) + for i := 0; i < length; i++ { + value, ok := tuple.Get(i).(int) + if !ok { + return nil, fmt.Errorf("tuple of ints expected: %#v", tuple) + } + slice[i] = int64(value) + } + return slice, nil +} diff --git a/pickle/type.go b/pickle/type.go new file mode 100644 index 0000000..14f7f4e --- /dev/null +++ b/pickle/type.go @@ -0,0 +1,519 @@ +package pickle + +import ( + "container/list" + "fmt" + "reflect" +) + +// This file contains custom types and interfaces for casting Python data types from Pickle. +// +// Custom Types: +// ============= +// 1. ByteArray +// 2. Dict +// 3. Tuple +// 4. OrderedDict +// 5. List +// 6. Set +// 7. FrozenSet +// 8. Object +// 9. Reconstructor +// 10. GenericClass + +// Interfaces: +// =========== + +// Callable is implemented by any value that can be directly called to get a +// new value. 
+// +// It is usually implemented by Python-like functions (returning a value +// given some arguments), or classes (typically returning an instance given +// some constructor arguments). +type Callable interface { + // Call mimics a direct invocation on a Python value, such as a function + // or class (constructor). + Call(args ...interface{}) (interface{}, error) +} + +// PyNewable is implemented by any value that has a Python-like +// "__new__" method. +// +// It is usually implemented by values representing Python classes. +type PyNewable interface { + // PyNew mimics Python invocation of the "__new__" method, usually + // provided by classes. + // + // See: https://docs.python.org/3/reference/datamodel.html#object.__new__ + PyNew(args ...interface{}) (interface{}, error) +} + +// PyStateSettable is implemented by any value that has a Python-like +// "__setstate__" method. +type PyStateSettable interface { + // PySetState mimics Python invocation of the "__setstate__" method. + // + // See: https://docs.python.org/3/library/pickle.html#object.__setstate__ + PySetState(state interface{}) error +} + +// PyDictSettable is implemented by any value that can store dictionary-like +// key/value pairs. It reflects Python behavior of setting a key/value pair on +// an object's "__dict__" attribute. +type PyDictSettable interface { + // PyDictSet mimics the setting of a key/value pair on an object's + //"__dict__" attribute. + // + // See: https://docs.python.org/3/library/stdtypes.html#object.__dict__ + PyDictSet(key, value interface{}) error +} + +// PyAttrSettable is implemented by any value on which an existing or new +// Python-like attribute can be set. In Python this is done with "setattr" +// builtin function. +type PyAttrSettable interface { + // PySetAttr mimics the setting of an arbitrary value to an object's + // attribute. + // + // In Python this is done with "setattr" function, to which object, + // attribute name, and value are passed. 
For an easy and clear + // implementation, here instead we require this method to be implemented + // on the "object" itself. + // + // See: https://docs.python.org/3/library/functions.html#setattr + PySetAttr(key string, value interface{}) error +} + +// ByteArray: +//=========== + +// ByteArray simulates Python bytearray. +type ByteArray []byte + +func NewByteArray() *ByteArray { + arr := make(ByteArray, 0) + return &arr +} + +func NewByteArrayFromSlice(slice []byte) *ByteArray { + arr := ByteArray(slice) + return &arr +} + +func (a *ByteArray) Get(i int) byte { + return (*a)[i] +} + +func (a *ByteArray) Len() int { + return len(*a) +} + +// Dict: +//====== + +// DictSetter is implemented by any value that exhibits a dict-like behaviour, +// allowing arbitrary key/value pairs to be set. +type DictSetter interface { + Set(key, value interface{}) +} + +// Dict represents a Python "dict" (builtin type). +// +// It is implemented as a slice, instead of a map, because in Go not +// all types can be map's keys (e.g. slices). +type Dict []*DictEntry + +type DictEntry struct { + Key interface{} + Value interface{} +} + +// NewDict makes and returns a new empty Dict. +func NewDict() *Dict { + d := make(Dict, 0) + return &d +} + +// Set sets into the Dict the given key/value pair. +func (d *Dict) Set(key, value interface{}) { + *d = append(*d, &DictEntry{ + Key: key, + Value: value, + }) +} + +// Get returns the value associated with the given key (if any), and whether +// the key is present or not. +func (d *Dict) Get(key interface{}) (interface{}, bool) { + for _, entry := range *d { + if reflect.DeepEqual(entry.Key, key) { + return entry.Value, true + } + } + return nil, false +} + +// MustGet returns the value associated with the given key, if if it exists, +// otherwise it panics. 
+func (d *Dict) MustGet(key interface{}) interface{} { + value, ok := d.Get(key) + if !ok { + panic(fmt.Errorf("key not found in Dict: %#v", key)) + } + return value +} + +// Len returns the length of the Dict, that is, the amount of key/value pairs +// contained by the Dict. +func (d *Dict) Len() int { + return len(*d) +} + +var _ DictSetter = &Dict{} + +// Tuple: +// ====== + +type Tuple []interface{} + +func NewTupleFromSlice(slice []interface{}) *Tuple { + t := Tuple(slice) + return &t +} + +func (t *Tuple) Get(i int) interface{} { + return (*t)[i] +} + +func (t *Tuple) Len() int { + return len(*t) +} + +// OrderedDict: +// ============ + +// OrderedDictClass represent Python "collections.OrderedDict" class. +// +// This class allows the indirect creation of OrderedDict objects. +type OrderedDictClass struct{} + +var _ Callable = &OrderedDictClass{} + +// Call returns a new empty OrderedDict. It is equivalent to Python +// constructor "collections.OrderedDict()". +// +// No arguments are supported. +func (*OrderedDictClass) Call(args ...interface{}) (interface{}, error) { + if len(args) != 0 { + return nil, fmt.Errorf( + "OrderedDictClass.Call args not supported: %#v", args) + } + return NewOrderedDict(), nil +} + +// OrderedDict is a minimal and trivial implementation of an ordered map, +// which represent a Python "collections.OrderedDict" object. +// +// It is composed by a simple unordered Map, and a List to keep the order of +// the entries. The former is useful for direct key lookups, the latter for +// iteration. +type OrderedDict struct { + // Map associates a key of any type (interface{}) to OrderedDictEntry + // pointer values. These values are shared with List. + Map map[interface{}]*OrderedDictEntry + // List is an ordered list of OrderedDictEntry pointers, which are + // also shared with Map. + List *list.List + // PyDict represents Python "object.__dict__" dictionary of attributes. 
+ PyDict map[string]interface{} +} + +var _ DictSetter = &OrderedDict{} +var _ PyDictSettable = &OrderedDict{} + +// OrderedDictEntry is a single key/value pair stored in an OrderedDict. +// +// A pointer to an OrderedDictEntry is always shared between OrderedDict's Map +// and List. +type OrderedDictEntry struct { + // Key of a single OrderedDict's entry. + Key interface{} + // Value of a single OrderedDict's entry. + Value interface{} + // ListElement is a pointer to the OrderedDict's List Element which + // contains this very OrderedDictEntry. + ListElement *list.Element +} + +// NewOrderedDict makes and returns a new empty OrderedDict. +func NewOrderedDict() *OrderedDict { + return &OrderedDict{ + Map: make(map[interface{}]*OrderedDictEntry), + List: list.New(), + PyDict: make(map[string]interface{}), + } +} + +// Set sets into the OrderedDict the given key/value pair. If the key does not +// exist yet, the new pair is positioned at the end (back) of the OrderedDict. +// If the key already exists, the existing associated value is replaced with the +// new one, and the original position is maintained. +func (o *OrderedDict) Set(k, v interface{}) { + if entry, ok := o.Map[k]; ok { + entry.Value = v + return + } + + entry := &OrderedDictEntry{ + Key: k, + Value: v, + } + entry.ListElement = o.List.PushBack(entry) + o.Map[k] = entry +} + +// Get returns the value associated with the given key (if any), and whether +// the key is present or not. +func (o *OrderedDict) Get(k interface{}) (interface{}, bool) { + entry, ok := o.Map[k] + if !ok { + return nil, false + } + return entry.Value, true +} + +// MustGet returns the value associated with the given key, if if it exists, +// otherwise it panics. 
+func (o *OrderedDict) MustGet(key interface{}) interface{} { + value, ok := o.Get(key) + if !ok { + panic(fmt.Errorf("key not found in OrderedDict: %#v", key)) + } + return value +} + +// Len returns the length of the OrderedDict, that is, the amount of key/value +// pairs contained by the OrderedDict. +func (o *OrderedDict) Len() int { + return len(o.Map) +} + +// PyDictSet mimics the setting of a key/value pair on Python "__dict__" +// attribute of the OrderedDict. +func (o *OrderedDict) PyDictSet(key, value interface{}) error { + sKey, keyOk := key.(string) + if !keyOk { + return fmt.Errorf( + "OrderedDict.PyDictSet() requires string key: %#v", key) + } + o.PyDict[sKey] = value + return nil +} + +// List: +// ===== + +// ListAppender is implemented by any value that exhibits a list-like +// behaviour, allowing arbitrary values to be appended. +type ListAppender interface { + Append(v interface{}) +} + +// List represents a Python "list" (builtin type). +type List []interface{} + +var _ ListAppender = &List{} + +// NewList makes and returns a new empty List. +func NewList() *List { + l := make(List, 0) + return &l +} + +// NewListFromSlice makes and returns a new List initialized with the elements +// of the given slice. +// +// The new List is a simple type cast of the input slice; the slice is _not_ +// copied. +func NewListFromSlice(slice []interface{}) *List { + l := List(slice) + return &l +} + +// Append appends one element to the end of the List. +func (l *List) Append(v interface{}) { + *l = append(*l, v) +} + +// Get returns the element of the List at the given index. +// +// It panics if the index is out of range. +func (l *List) Get(i int) interface{} { + return (*l)[i] +} + +// Len returns the length of the List. +func (l *List) Len() int { + return len(*l) +} + +// Set: +// ==== + +// SetAdder is implemented by any value that exhibits a set-like behaviour, +// allowing arbitrary values to be added. 
type SetAdder interface {
	Add(v interface{})
}

// Set represents a Python "set" (builtin type).
//
// It is implemented as a Go map whose keys are the items and whose
// values are zero-width empty structs.
type Set map[interface{}]setEmptyStruct

var _ SetAdder = &Set{}

type setEmptyStruct struct{}

// NewSet makes and returns a new empty Set.
func NewSet() *Set {
	set := make(Set)
	return &set
}

// NewSetFromSlice makes and returns a new Set holding every element of
// the given slice (duplicates collapse into a single item).
func NewSetFromSlice(slice []interface{}) *Set {
	set := make(Set, len(slice))
	for _, it := range slice {
		set[it] = setEmptyStruct{}
	}
	return &set
}

// Len returns the number of items contained in the Set.
func (s *Set) Len() int {
	return len(*s)
}

// Add inserts one item into the Set.
func (s *Set) Add(v interface{}) {
	(*s)[v] = setEmptyStruct{}
}

// Has reports whether the given value is present in the Set (true)
// or not (false).
func (s *Set) Has(v interface{}) bool {
	_, ok := (*s)[v]
	return ok
}

// FrozenSet:
//===========

// FrozenSet represents a Python "frozenset" (builtin type).
//
// Like Set, it is implemented as a Go map with zero-width empty-struct
// values; the item set is represented by the map keys.
type FrozenSet map[interface{}]frozenSetEmptyStruct

type frozenSetEmptyStruct struct{}

// NewFrozenSetFromSlice makes and returns a new FrozenSet holding every
// element of the given slice.
func NewFrozenSetFromSlice(slice []interface{}) *FrozenSet {
	fs := make(FrozenSet, len(slice))
	for _, it := range slice {
		fs[it] = frozenSetEmptyStruct{}
	}
	return &fs
}

// Len returns the number of items contained in the FrozenSet.
func (f *FrozenSet) Len() int {
	return len(*f)
}

// Has returns whether the given value is present in the FrozenSet (true)
// or not (false).
+func (f *FrozenSet) Has(v interface{}) bool {
	_, ok := (*f)[v]
	return ok
}

// Object:
//========

// ObjectClass represents the Python "object" class, the ultimate base of
// all Python classes. Its only role here is to create ("__new__") new
// instances of a class passed as the first argument.
type ObjectClass struct{}

var _ PyNewable = &ObjectClass{}

// PyNew mimics "object.__new__(cls, ...)": the first argument must be a
// PyNewable class, on which PyNew is then invoked with no arguments.
func (o *ObjectClass) PyNew(args ...interface{}) (interface{}, error) {
	if len(args) == 0 {
		return nil, fmt.Errorf("ObjectClass.PyNew called with no arguments")
	}
	switch class := args[0].(type) {
	case PyNewable:
		return class.PyNew()
	default:
		return nil, fmt.Errorf(
			"ObjectClass.PyNew unprocessable args: %#v", args)
	}
}

// Reconstructor:
//===============

// Reconstructor mirrors Python "copyreg._reconstructor", commonly found
// in pickles as the callable that rebuilds an object from (class, base,
// state) arguments.
type Reconstructor struct{}

var _ Callable = &Reconstructor{}

// Call expects at least (class, base) arguments; the base must be
// PyNewable and is asked to create a new instance of class.
func (r *Reconstructor) Call(args ...interface{}) (interface{}, error) {
	if len(args) < 2 {
		return nil, fmt.Errorf("Reconstructor: invalid arguments: %#v", args)
	}
	class := args[0]
	switch base := args[1].(type) {
	case PyNewable:
		return base.PyNew(class)
	default:
		return nil, fmt.Errorf(
			"Reconstructor: unprocessable arguments: %#v", args)
	}
}

// GenericClass:
//==============

// GenericClass is a fallback placeholder for any Python class that has
// no dedicated Go counterpart, identified by its module and name.
type GenericClass struct {
	Module string
	Name   string
}

var _ PyNewable = &GenericClass{}

// GenericObject is an instance of a GenericClass, remembering the
// constructor arguments it was created with.
type GenericObject struct {
	Class           *GenericClass
	ConstructorArgs []interface{}
}

// NewGenericClass makes and returns a new GenericClass with the given
// module and class name.
func NewGenericClass(module, name string) *GenericClass {
	return &GenericClass{Module: module, Name: name}
}

// PyNew creates a GenericObject of this class, capturing the given
// constructor arguments verbatim.
func (g *GenericClass) PyNew(args ...interface{}) (interface{}, error) {
	return &GenericObject{
		Class:           g,
		ConstructorArgs: args,
	}, nil
}

// getThnnFunctionBackend is for historical pickle deserialization, it is not used otherwise
type getThnnFunctionBackend struct{}

var _ Callable = &getThnnFunctionBackend{}

// Call is a no-op: the THNN backend is obsolete, so the callable simply
// resolves to nil.
func (getThnnFunctionBackend) Call(_ ...interface{}) (interface{}, error) {
	return nil, nil
}
diff --git a/pickle/util.go b/pickle/util.go
new file mode 100644
index 0000000..514d882
--- /dev/null
+++ b/pickle/util.go
@@ -0,0 +1,112 @@
+package pickle
+
+import "io"
+
+// Converts the bits representation of a Half Float (16 
bits) number to
// an IEEE 754 float representation (32 bits).
//
// It uses the lookup-table method described in
// http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
func FloatBits16to32(u16 uint16) uint32 {
	return mantissaTable[offsetTable[u16>>10]+(uint32(u16)&0x3ff)] + exponentTable[u16>>10]
}

// Conversion lookup tables, filled once at package initialization.
// They are indexed by the half float's 6 sign+exponent bits (u16>>10)
// and its 10 mantissa bits.
var mantissaTable [2048]uint32
var exponentTable [64]uint32
var offsetTable [64]uint32

func init() {
	initMantissaTable()
	initExponentTable()
	initOffsetTable()
}

// initMantissaTable fills mantissaTable: entries 1..1023 hold the
// normalized float32 form of half-float sub-normal mantissas, entries
// 1024..2047 hold the shifted mantissas of normalized half floats.
func initMantissaTable() {
	mantissaTable[0] = 0
	for i := uint32(1); i < 1024; i++ {
		mantissaTable[i] = convertMantissa(i)
	}
	for i := uint32(1024); i < 2048; i++ {
		mantissaTable[i] = 0x38000000 + ((i - 1024) << 13)
	}
}

// initExponentTable maps the sign+exponent bits of a half float to the
// corresponding float32 sign+exponent bits. Indexes 31 and 63 are the
// (positive/negative) Inf/NaN exponents; index 32 carries only the sign
// bit (negative zero and negative sub-normals).
func initExponentTable() {
	exponentTable[0] = 0
	exponentTable[31] = 0x47800000
	exponentTable[32] = 0x80000000
	exponentTable[63] = 0xC7800000
	for i := uint32(1); i < 31; i++ {
		exponentTable[i] = i << 23
	}
	for i := uint32(33); i < 63; i++ {
		exponentTable[i] = 0x80000000 + (i-32)<<23
	}
}

// initOffsetTable selects which half of mantissaTable each sign+exponent
// pattern uses: offset 0 (sub-normal entries) for the two zero-exponent
// patterns (0 and 32), offset 1024 (normalized entries) for all others.
//
// Fixes two off-by-one bugs in the original transcription of the
// reference algorithm: the first loop stopped at index 30, leaving
// offsetTable[31] (+Inf/NaN) at 0; and the second loop started at 32,
// clobbering offsetTable[32] (negative zero / negative sub-normals)
// with 1024.
func initOffsetTable() {
	for i := uint32(1); i < 32; i++ {
		offsetTable[i] = 1024
	}
	for i := uint32(33); i < 64; i++ {
		offsetTable[i] = 1024
	}
	// Indexes 0 and 32 must stay 0: zeros and sub-normals (of either
	// sign) index the sub-normal part of mantissaTable.
	offsetTable[0] = 0
	offsetTable[32] = 0
}

// convertMantissa normalizes the mantissa bits of a sub-normal half
// float and returns the corresponding float32 fraction+exponent bit
// pattern (sign excluded).
func convertMantissa(i uint32) uint32 {
	m := i << 13 // zero pad mantissa bits
	var e uint32 // zero exponent
	// Shift left until the implicit leading 1 reaches bit 23.
	// Fixed: the original looped on `!= 0`, which is never true for
	// i < 1024, so sub-normals were never normalized.
	for m&0x00800000 == 0 {
		e -= 0x00800000 // decrement exponent (1 << 23); wraps intentionally
		m <<= 1         // shift mantissa
	}
	m &= ^uint32(0x00800000) // clear leading 1 bit
	e += 0x38800000          // adjust bias ((127-14)<<23)
	return m | e             // return combined number
}

// LimitedBufferReader reads a fixed number of fixed-size scalars from an
// io.Reader through an internal buffer, yielding one scalar per ReadNext
// call.
type LimitedBufferReader struct {
	r              io.Reader
	scalarSize     int // size in bytes of a single scalar
	remainingBytes int // bytes still to be consumed from r
	buf            []byte
	bufIndex       int // current read position inside buf
}

// NewLimitedBufferReader returns a reader that will deliver dataSize
// scalars of scalarSize bytes each from r, buffering bufferSize scalars
// at a time.
func NewLimitedBufferReader(
	r io.Reader,
	dataSize, scalarSize, bufferSize int,
) *LimitedBufferReader {
	size := bufferSize * scalarSize
	return &LimitedBufferReader{
		r:              r,
		scalarSize:     scalarSize,
		remainingBytes: scalarSize * dataSize,
		buf:            make([]byte, size),
		// Start with the buffer marked exhausted so the first ReadNext
		// triggers a fill.
		bufIndex: size,
	}
}

// HasNext reports whether at least one more scalar can be read.
func (br *LimitedBufferReader) HasNext() bool {
	return br.remainingBytes != 0
}

// ReadNext returns the next scalar as a sub-slice of the internal
// buffer; the returned slice is only valid until the next call.
// It returns io.EOF once all scalars have been consumed.
func (br *LimitedBufferReader) ReadNext() ([]byte, error) {
	if br.remainingBytes == 0 {
		return nil, io.EOF
	}
	if br.bufIndex == len(br.buf) {
		br.bufIndex = 0
		if br.remainingBytes < len(br.buf) {
			br.buf = br.buf[0:br.remainingBytes]
		}
		// io.ReadFull instead of a bare r.Read: a single Read may
		// legitimately return fewer bytes than len(buf), which would
		// have left stale bytes in the buffer.
		if _, err := io.ReadFull(br.r, br.buf); err != nil {
			return nil, err
		}
	}
	result := br.buf[br.bufIndex : br.bufIndex+br.scalarSize]
	br.bufIndex += br.scalarSize
	br.remainingBytes -= br.scalarSize
	return result, nil
}