added subpackage 'pickle' and file-util

This commit is contained in:
sugarme 2022-02-24 12:43:39 +11:00
parent 47c6c60561
commit dc4ab3047c
10 changed files with 4176 additions and 0 deletions

View File

@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
- Added subpackage `pickle`. Python PyTorch pretrained models can now be loaded directly, without any Python script conversion.
- Added `gotch.CachedPath()` and `gotch.ModelUrls`
- Remove Travis CI for now.
- Fixed `tensor.OfSlice()` throwing an "Unsupported Go type" error (e.g. for []float32)
- added `nn.Path.Paths()` method

36
example/pickle/main.go Normal file
View File

@ -0,0 +1,36 @@
package main
import (
"fmt"
"log"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
"github.com/sugarme/gotch/pickle"
"github.com/sugarme/gotch/vision"
)
// main downloads (or reuses the cached copy of) the pretrained VGG16
// weights, loads them into a VarStore via the pickle subpackage, and
// prints the network plus a summary of loaded variables.
func main() {
	device := gotch.CPU
	vs := nn.NewVarStore(device)
	net := vision.VGG16(vs.Root(), 1000)

	modelName := "vgg16"
	modelUrl, ok := gotch.ModelUrls[modelName]
	if !ok {
		// log.Fatalf (not log.Fatal): the original call passed a %q verb to
		// the non-formatting log.Fatal, which prints it literally.
		log.Fatalf("model name %q not found.", modelName)
	}
	modelFile, err := gotch.CachedPath(modelUrl)
	if err != nil {
		log.Fatal(err)
	}

	err = pickle.LoadAll(vs, modelFile)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Printf("%v\n", net)
	vs.Summary()
}

289
file-util.go Normal file
View File

@ -0,0 +1,289 @@
package gotch
import (
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"path"
	"strconv"
	"strings"
)
// This file provides functions to work with local dataset cache, ...
// ModelUrls maps model name to its pretrained URL.
//
// These URLs are taken from separate models in the pytorch/vision repository
// https://github.com/pytorch/vision/tree/main/torchvision/models
//
// NOTE: a few entries map to an empty string — those variants have no
// published pretrained weights upstream.
var ModelUrls map[string]string = map[string]string{
	"alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",

	"convnext_tiny":  "https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
	"convnext_small": "https://download.pytorch.org/models/convnext_small-0c510722.pth",
	"convnext_base":  "https://download.pytorch.org/models/convnext_base-6075fbad.pth",
	"convnext_large": "https://download.pytorch.org/models/convnext_large-ea097f82.pth",

	"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
	"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
	"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
	"densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",

	// Weights ported from https://github.com/rwightman/pytorch-image-models/
	"efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
	"efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
	"efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
	"efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
	"efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
	// Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
	"efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
	"efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
	"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",

	// GoogLeNet ported from TensorFlow
	"googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",

	// Inception v3 ported from TensorFlow
	"inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",

	"mnasnet0_5":  "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
	"mnasnet0_75": "",
	"mnasnet1_0":  "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
	"mnasnet1_3":  "",

	"mobilenet_v2":       "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
	"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
	"mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",

	"regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
	"regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
	"regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
	"regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
	"regnet_y_8gf":   "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
	"regnet_y_16gf":  "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
	"regnet_y_32gf":  "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
	"regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
	"regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
	"regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
	"regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
	"regnet_x_8gf":   "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
	"regnet_x_16gf":  "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
	"regnet_x_32gf":  "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",

	"resnet18":  "https://download.pytorch.org/models/resnet18-f37072fd.pth",
	"resnet34":  "https://download.pytorch.org/models/resnet34-b627a593.pth",
	"resnet50":  "https://download.pytorch.org/models/resnet50-0676ba61.pth",
	"resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
	"resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",

	"resnext50_32x4d":  "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
	"resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
	"wide_resnet50_2":  "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
	"wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",

	"shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
	"shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
	"shufflenetv2_x1.5": "",
	"shufflenetv2_x2.0": "",

	"squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
	"squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",

	"vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
	"vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
	"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
	"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",

	"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
	"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
	"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
	"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",

	"vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth",
	"vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
	"vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
	"vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth",
}
// CachedPath resolves and caches data based on input string, then returns fullpath to the cached data.
//
// Parameters:
//   - `filenameOrUrl`: full path to filename or url
//
// CachedPath does several things consequently:
//  1. Resolves input string to a fullpath cached filename candidate.
//  2. Checks it at `CacheDir`; if it exists, returns the candidate. If not,
//  3. Retrieves and caches data to `CacheDir` and returns the path to the cached data.
func CachedPath(filenameOrUrl string) (resolvedPath string, err error) {
	filename := path.Base(filenameOrUrl)
	// Resolve to a "candidate" filename at `CacheDir`.
	cachedFileCandidate := fmt.Sprintf("%s/%s", CacheDir, filename)

	// 1. Cached candidate file already exists.
	if _, err := os.Stat(cachedFileCandidate); err == nil {
		return cachedFileCandidate, nil
	}

	// 2. Valid fullpath to a local file: copy it into the cache.
	if _, err := os.Stat(filenameOrUrl); err == nil {
		if err := copyFile(filenameOrUrl, cachedFileCandidate); err != nil {
			return "", err
		}
		return cachedFileCandidate, nil
	}

	// 3. Not cached and not local: try to download it into `CacheDir`.
	// NOTE: the previous pre-flight http.Get() here leaked its response body
	// and fetched the payload twice; downloadFile() already reports
	// unreachable URLs and bad statuses, so call it directly.
	if isValidURL(filenameOrUrl) {
		if err := downloadFile(filenameOrUrl, cachedFileCandidate); err != nil {
			return "", err
		}
		return cachedFileCandidate, nil
	}

	// Not resolvable as a cache hit, a local file, or a URL.
	return "", fmt.Errorf("unable to parse %q as a URL or as a local path", filenameOrUrl)
}
// isValidURL reports whether s parses as an absolute URL (has both a
// scheme and a host). The previous implementation was a stub that always
// returned true, which made CachedPath treat every unresolvable string as
// a download candidate.
func isValidURL(s string) bool {
	u, err := url.Parse(s)
	return err == nil && u.Scheme != "" && u.Host != ""
}
// downloadFile downloads file from URL and stores it in local filepath.
// It writes to the destination file as it downloads it, without loading
// the entire file into memory. An `io.TeeReader` is passed into Copy()
// to report progress on the download.
func downloadFile(url string, filepath string) error {
	// Create the destination directory if it does not exist yet.
	dir := path.Dir(filepath)
	filename := path.Base(filepath)
	if _, err := os.Stat(dir); os.IsNotExist(err) {
		if err := os.MkdirAll(dir, 0755); err != nil {
			// NOTE(review): log.Fatal in library code aborts the whole
			// process; consider returning the error instead.
			log.Fatal(err)
		}
	}

	// Create the file with a .tmp extension so we won't overwrite (or leave
	// behind as apparently complete) a file until it's fully downloaded.
	tmpPath := filepath + ".tmp"
	out, err := os.Create(tmpPath)
	if err != nil {
		return err
	}

	// Get the data.
	resp, err := http.Get(url)
	if err != nil {
		out.Close()
		os.Remove(tmpPath)
		return err
	}
	defer resp.Body.Close()

	// Check server response. (The previous version constructed an unused
	// error value here before unconditionally overwriting it.)
	if resp.StatusCode != http.StatusOK {
		out.Close()
		os.Remove(tmpPath)
		if resp.StatusCode == http.StatusNotFound {
			return fmt.Errorf("download file not found: %q for downloading", url)
		}
		return fmt.Errorf("download file failed: %q", url)
	}

	// The total file size to download (0 when Content-Length is absent).
	size, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
	downloadSize := uint64(size)

	// Create our byte counter and tee the body through it so progress is
	// reported as the copy proceeds.
	counter := &writeCounter{FileSize: downloadSize}
	if _, err = io.Copy(out, io.TeeReader(resp.Body, counter)); err != nil {
		out.Close()
		os.Remove(tmpPath)
		return err
	}

	fmt.Printf("\r%s... %s/%s completed", filename, byteCountIEC(counter.Total), byteCountIEC(counter.FileSize))
	// The progress reuses one terminal line; finish it with a newline.
	fmt.Println()

	// Close before renaming so buffered data is flushed and the file is not
	// still open during the rename (which fails on some platforms).
	if err := out.Close(); err != nil {
		os.Remove(tmpPath)
		return err
	}

	// Atomically move the tmp file to its final name.
	return os.Rename(tmpPath, filepath)
}
// writeCounter accumulates the number of bytes streamed through it. By
// implementing Write it satisfies io.Writer, so it can be wired into
// io.TeeReader() to report download progress as bytes flow by.
type writeCounter struct {
	Total    uint64 // bytes written so far
	FileSize uint64 // expected total size, used for the progress display
}

// Write records the size of p, refreshes the progress line, and always
// reports the full length as consumed with a nil error.
func (wc *writeCounter) Write(p []byte) (int, error) {
	written := len(p)
	wc.Total += uint64(written)
	wc.printProgress()
	return written, nil
}

// printProgress redraws the single-line download status.
func (wc writeCounter) printProgress() {
	// Return to the start of the line and blank out any previous status,
	// then return again and print the current one.
	fmt.Printf("\r%s", strings.Repeat(" ", 50))
	fmt.Printf("\rDownloading... %s/%s", byteCountIEC(wc.Total), byteCountIEC(wc.FileSize))
}
// byteCountIEC renders a byte count as a human-readable string in binary
// (IEC) units, e.g. 1536 -> "1.5 KiB". Values below 1024 are printed as
// plain bytes.
func byteCountIEC(b uint64) string {
	const unit = 1024
	if b < unit {
		return fmt.Sprintf("%d B", b)
	}
	const suffixes = "KMGTPE"
	value := float64(b)
	idx := 0
	// Keep one residual factor of `unit` so the final division lands the
	// value in [1, 1024).
	for value >= unit*unit && idx < len(suffixes)-1 {
		value /= unit
		idx++
	}
	return fmt.Sprintf("%.1f %ciB", value/unit, suffixes[idx])
}
func copyFile(src, dst string) error {
sourceFileStat, err := os.Stat(src)
if err != nil {
return err
}
if !sourceFileStat.Mode().IsRegular() {
return fmt.Errorf("%s is not a regular file", src)
}
source, err := os.Open(src)
if err != nil {
return err
}
defer source.Close()
destination, err := os.Create(dst)
if err != nil {
return err
}
defer destination.Close()
_, err = io.Copy(destination, source)
return err
}

35
init.go Normal file
View File

@ -0,0 +1,35 @@
package gotch
import (
"fmt"
"log"
"os"
)
var (
	// CacheDir is the local directory where downloaded/copied model data is
	// cached. It is assigned a real path in init() and may be overridden via
	// the GOTCH_CACHE environment variable (see initEnv).
	CacheDir string = "NOT_SETTING"
	// gotchEnvKey is the environment variable consulted to override CacheDir.
	gotchEnvKey string = "GOTCH_CACHE"
)
// init sets the default cache directory and makes sure it exists.
func init() {
	// Default path: {$HOME}/.cache/gotch
	homeDir := os.Getenv("HOME")
	// NOTE: this previously pointed at ".cache/transformer" (left over from
	// another project), contradicting the documented default; fixed to "gotch"
	// to match the comment above and the GOTCH_CACHE env key.
	CacheDir = fmt.Sprintf("%s/.cache/gotch", homeDir)

	initEnv()
	log.Printf("INFO: CacheDir=%q\n", CacheDir)
}
// initEnv overrides CacheDir from the GOTCH_CACHE environment variable
// (when set) and ensures the cache directory exists on disk.
func initEnv() {
	if val := os.Getenv(gotchEnvKey); val != "" {
		CacheDir = val
	}

	if _, err := os.Stat(CacheDir); os.IsNotExist(err) {
		if err := os.MkdirAll(CacheDir, 0755); err != nil {
			log.Fatal(err)
		}
	}
}

1937
pickle/pickle.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,60 @@
package pickle_test
import (
"log"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/pickle"
)
// ExampleLoadInfo resolves the pretrained VGG16 weights (downloading them
// on first run, then serving from cache) and prints each variable's name
// and shape. The Output block below doubles as the example's assertion.
func ExampleLoadInfo() {
	modelName := "vgg16"
	url, ok := gotch.ModelUrls[modelName]
	if !ok {
		log.Fatalf("Unsupported model name %q\n", modelName)
	}
	modelFile, err := gotch.CachedPath(url)
	if err != nil {
		panic(err)
	}

	err = pickle.LoadInfo(modelFile)
	if err != nil {
		log.Fatal(err)
	}

	// Output:
	// classifier.0.bias - [4096]
	// classifier.0.weight - [4096 25088]
	// classifier.3.bias - [4096]
	// classifier.3.weight - [4096 4096]
	// classifier.6.bias - [1000]
	// classifier.6.weight - [1000 4096]
	// features.0.bias - [64]
	// features.0.weight - [64 3 3 3]
	// features.10.bias - [256]
	// features.10.weight - [256 128 3 3]
	// features.12.bias - [256]
	// features.12.weight - [256 256 3 3]
	// features.14.bias - [256]
	// features.14.weight - [256 256 3 3]
	// features.17.bias - [512]
	// features.17.weight - [512 256 3 3]
	// features.19.bias - [512]
	// features.19.weight - [512 512 3 3]
	// features.2.bias - [64]
	// features.2.weight - [64 64 3 3]
	// features.21.bias - [512]
	// features.21.weight - [512 512 3 3]
	// features.24.bias - [512]
	// features.24.weight - [512 512 3 3]
	// features.26.bias - [512]
	// features.26.weight - [512 512 3 3]
	// features.28.bias - [512]
	// features.28.weight - [512 512 3 3]
	// features.5.bias - [128]
	// features.5.weight - [128 64 3 3]
	// features.7.bias - [128]
	// features.7.weight - [128 128 3 3]
	// Num of variables: 32
}

519
pickle/serialization.go Normal file
View File

@ -0,0 +1,519 @@
package pickle
// Ref.
// https://docs.python.org/3/library/pickle.html
// https://docs.python.org/3/library/pickletools.html
// https://github.com/python/cpython/blob/main/Lib/pickle.py (****real code here****)
// https://pytorch.org/tutorials/beginner/saving_loading_models.html
// https://github.com/pytorch/pytorch/blob/master/torch/serialization.py
import (
"archive/tar"
"archive/zip"
"errors"
"fmt"
"io"
"math/big"
"os"
"path"
"reflect"
"sort"
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
)
// hexMagicNumber is the pytorch legacy-serialization magic number (stored
// at the head of the file as a pickled big integer), in base-16 text form.
const hexMagicNumber = "1950a86a20f9469cfc6c"

// protocolVersion is the supported pytorch legacy serialization protocol version.
const protocolVersion = 1001

// ErrInvalidMagicNumber is returned when the file header's magic number
// does not match hexMagicNumber.
var ErrInvalidMagicNumber = errors.New("invalid pytorch magic number")

// ErrInvalidProtocolVersion is returned when the file header's protocol
// version does not match protocolVersion.
var ErrInvalidProtocolVersion = errors.New("invalid pytorch protocol version")
// Encode encodes a model using the pickling machinery.
// The pickled output is intended to be loadable from Python Pytorch as
// `torch.load("pytorch_model.bin")`.
//
// TODO: implement the pickling part so that a model can be exported and
// loaded with Python Pytorch.
// See https://github.com/python/cpython/blob/b0de6299a840a397d4fe3e6c98159d9f258d3295/Lib/pickle.py#L407
func Encode(model ts.Module, outputFile string) error {
	panic("NotImplementedError")
}
// Decode decodes pickled data created by 'torch.save()' with Python Pytorch
// and rebuilds named tensor weights.
//
// It returns a map from variable name to a freshly constructed tensor; the
// caller owns the tensors (see LoadAll/LoadInfo, which MustDrop them).
func Decode(filename string) (map[string]*ts.Tensor, error) {
	newUnpickler := func(r io.Reader) Unpickler {
		return NewUnpickler(r)
	}
	result, err := LoadWithUnpickler(filename, newUnpickler)
	if err != nil {
		err := fmt.Errorf("Decode() failed: %w", err)
		return nil, err
	}

	// Rebuild tensors from Storage tensors. The unpickled result is expected
	// to be an OrderedDict of name -> *StorageTensor.
	namedTensors := make(map[string]*ts.Tensor)
	dictResult, isOrderedDict := result.(*OrderedDict)
	if !isOrderedDict {
		err := fmt.Errorf("Decode() failed: expected 'OrderedDict' type, got %v\n", reflect.TypeOf(result).String())
		return nil, err
	}
	for name, item := range dictResult.Map {
		sx, isStorageTensor := item.Value.(*StorageTensor)
		if !isStorageTensor {
			err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
			return nil, err
		}

		data := sx.Source.GetData()
		size := sx.Size
		dtype := sx.Source.DType()
		device := sx.Source.Device()
		stride := sx.Stride
		storageOffset := sx.StorageOffset

		// fmt.Printf("%q - shape: %v - stride: %v - storageOffset: %v\n", sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)

		// Materialize the flat storage, then re-view it with the saved
		// size/stride/offset, dtype and device.
		x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
		if sx.RequiresGrad {
			x.MustRequiresGrad_(sx.RequiresGrad)
		}

		namedTensors[name.(string)] = x
	}

	return namedTensors, nil
}
// LoadWithUnpickler is like Load, but it accepts a newUnpickler function
// which is used to create new customized pickle.Unpickler instances.
// Zip-based checkpoints go through loadZipFile; anything else is treated
// as the legacy (pre-zip) pytorch format.
func LoadWithUnpickler(filename string, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) {
	if isZipFile(filename) {
		return loadZipFile(filename, newUnpickler)
	}
	return loadLegacyFile(filename, newUnpickler)
}
// loadZipFile loads a modern (zip-based) pytorch checkpoint: it unpickles
// "data.pkl" from the archive, resolving persistent storage references to
// the tensor data records stored alongside it.
func loadZipFile(filename string, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) {
	// Open a zip archive for reading.
	r, err := zip.OpenReader(filename)
	if err != nil {
		return nil, err
	}
	defer r.Close()

	// Index archive entries by base name; storages reference them by key.
	fileRecords := make(map[string]*zip.File, len(r.File))
	for _, f := range r.File {
		_, recordName := path.Split(f.Name)
		fileRecords[recordName] = f
	}

	// "constants.pkl" only appears in TorchScript archives, which this
	// package does not handle.
	if _, isTorchScript := fileRecords["constants.pkl"]; isTorchScript {
		return nil, fmt.Errorf("TorchScript is not supported")
	}

	dataFile, hasDataFile := fileRecords["data.pkl"]
	if !hasDataFile {
		return nil, fmt.Errorf("data.pkl not found in zip file")
	}
	df, err := dataFile.Open()
	if err != nil {
		return nil, err
	}
	defer df.Close()

	// Cache storages by key so repeated persistent references share one value.
	loadedStorages := make(map[string]Storage)

	u := newUnpickler(df)
	u.FindClass = makePickleFindClass(u.FindClass)
	// PersistentLoad resolves persistent IDs of the form
	// ("storage", StorageClass, key, location, size) into Storage values
	// backed by the matching zip record.
	u.PersistentLoad = func(savedId interface{}) (interface{}, error) {
		tuple, tupleOk := savedId.(*Tuple)
		if !tupleOk || tuple.Len() == 0 {
			return nil, fmt.Errorf("PersistentLoad: non-empty tuple expected, got %#v", savedId)
		}
		typename, typenameOk := tuple.Get(0).(string)
		if !typenameOk {
			return nil, fmt.Errorf("PersistentLoad: cannot get typename")
		}
		if typename != "storage" {
			return nil, fmt.Errorf("unknown typename for PersistentLoad, expected 'storage' but got '%s'", typename)
		}
		if tuple.Len() < 5 {
			return nil, fmt.Errorf("PersistentLoad: unexpected storage data length")
		}
		dataType, dataTypeOk := tuple.Get(1).(StorageClass)
		key, keyOk := tuple.Get(2).(string)
		location, locationOk := tuple.Get(3).(string)
		size, sizeOk := tuple.Get(4).(int)
		if !dataTypeOk || !keyOk || !locationOk || !sizeOk {
			return nil, fmt.Errorf("PersistentLoad: unexpected data types")
		}
		storage, storageExists := loadedStorages[key]
		if !storageExists {
			storage, err = loadTensor(dataType, size, location, key, fileRecords)
			if err != nil {
				return nil, err
			}
			loadedStorages[key] = storage
		}
		return storage, nil
	}
	return u.Load()
}
// loadTensor allocates a Storage of the given class/size/location and
// fills it from the zip record identified by key.
func loadTensor(
	dataType StorageClass,
	size int,
	location, key string,
	zipFileRecords map[string]*zip.File,
) (Storage, error) {
	file, fileOk := zipFileRecords[key]
	if !fileOk {
		return nil, fmt.Errorf("cannot find zip record '%s'", key)
	}
	f, err := file.Open()
	if err != nil {
		return nil, err
	}
	defer f.Close()

	storage := dataType.New(size, location)
	err = storage.SetFromFileWithSize(f, size)
	return storage, err
}
// loadLegacyFile loads a legacy (pre-zip) pytorch checkpoint. A legacy
// file may be a tar archive (unsupported for now) or the plain sequential
// format handled by loadLegacyNoTar.
func loadLegacyFile(filename string, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) {
	f, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	tr := tar.NewReader(f)
	for {
		_, err := tr.Next()
		switch err {
		case nil:
			// TODO: implement legacy tar deserialization.
			panic("legacy load from tar not implemented")
		case io.EOF:
			// End of a (contentless) tar archive: nothing was deserialized.
			// BUGFIX: the original `break` here only exited the switch, so
			// the loop called tr.Next() forever (it keeps returning io.EOF).
			return nil, fmt.Errorf("legacy tar archive %q contains no entries", filename)
		case tar.ErrHeader, io.ErrUnexpectedEOF:
			// Not a tar archive at all: rewind and try the plain legacy format.
			if _, err := f.Seek(0, io.SeekStart); err != nil {
				return nil, err
			}
			return loadLegacyNoTar(f, newUnpickler)
		default:
			return nil, err
		}
	}
}
// loadLegacyNoTar loads the plain (non-tar) legacy pytorch format:
// magic number, protocol version, sys-info, the pickled object graph,
// the ordered list of storage keys, and finally the raw storage data.
func loadLegacyNoTar(f *os.File, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) {
	if err := readAndCheckMagicNumber(f); err != nil {
		return nil, err
	}
	if err := readAndChecProtocolVersion(f); err != nil {
		return nil, err
	}
	if _, err := unpickle(f); err != nil { // sys info
		return nil, err
	}

	// Storages deserialized so far, keyed by their root key.
	deserializedObjects := make(map[string]Storage)

	u := newUnpickler(f)
	u.FindClass = makePickleFindClass(u.FindClass)
	// PersistentLoad resolves persistent IDs; "storage" entries allocate
	// (empty) Storage values whose data is read after the object graph.
	u.PersistentLoad = func(savedId interface{}) (interface{}, error) {
		tuple, tupleOk := savedId.(*Tuple)
		if !tupleOk || tuple.Len() == 0 {
			return nil, fmt.Errorf("PersistentLoad: non-empty tuple expected, got %#v", savedId)
		}
		typename, typenameOk := tuple.Get(0).(string)
		if !typenameOk {
			return nil, fmt.Errorf("PersistentLoad: cannot get typename")
		}
		// fmt.Printf("typename: %s\n", typename)
		switch typename {
		case "storage":
			if tuple.Len() < 6 {
				return nil, fmt.Errorf(
					"PersistentLoad: unexpected storage data length")
			}
			dataType, dataTypeOk := tuple.Get(1).(StorageClass)
			rootKey, rootKeyOk := tuple.Get(2).(string)
			location, locationOk := tuple.Get(3).(string)
			size, sizeOk := tuple.Get(4).(int)
			viewMetadata := tuple.Get(5)
			if !dataTypeOk || !rootKeyOk || !locationOk || !sizeOk {
				return nil, fmt.Errorf("PersistentLoad: unexpected data types")
			}
			// fmt.Printf("dtype: %v - rootKey: %v - device: %v - size %v - viewMetaData: %v\n", reflect.TypeOf(dataType), rootKey, location, size, viewMetadata)
			storage, storageExists := deserializedObjects[rootKey]
			if !storageExists {
				storage = dataType.New(size, location)
				deserializedObjects[rootKey] = storage
			}
			switch vm := viewMetadata.(type) {
			case nil:
				return storage, nil
			case []interface{}:
				if len(vm) != 3 {
					return nil, fmt.Errorf(
						"PersistentLoad: unexpected view metadata length")
				}
				panic("viewMetadata not implemented")
				// TODO: ...
				// view_key, offset, view_size = view_metadata
				// if view_key not in deserialized_objects:
				//   deserialized_objects[view_key] = storage[offset:offset + view_size]
				// return deserialized_objects[view_key]
			default:
				return nil, fmt.Errorf("PersistentLoad: unexpected view metadata type")
			}
		case "module":
			if tuple.Len() < 2 {
				return nil, fmt.Errorf("PersistentLoad: unexpected module data length")
			}
			return tuple.Get(1), nil
		default:
			return nil, fmt.Errorf("Unexpected saved ID type: %s", typename)
		}
	}

	result, err := u.Load()
	if err != nil {
		return nil, err
	}

	// The object graph is followed by the ordered list of storage keys...
	rawStorageKeys, err := unpickle(f)
	if err != nil {
		return nil, err
	}
	storageKeys, err := makeStorageKeys(rawStorageKeys)
	if err != nil {
		return nil, err
	}

	// ...and then the raw data of each storage, in that same order.
	for _, key := range storageKeys {
		storageObj, ok := deserializedObjects[key]
		if !ok {
			return nil, fmt.Errorf("storage object not found for key '%s'", key)
		}
		err = storageObj.SetFromFile(f)
		if err != nil {
			return nil, err
		}
	}

	return result, nil
}
// makeStorageKeys converts an unpickled *List of storage keys into a plain
// []string, validating that every element is a string.
func makeStorageKeys(obj interface{}) ([]string, error) {
	list, ok := obj.(*List)
	if !ok {
		return nil, fmt.Errorf("invalid storage keys data")
	}
	keys := make([]string, 0, len(*list))
	for _, rawKey := range *list {
		key, isString := rawKey.(string)
		if !isString {
			return nil, fmt.Errorf("invalid storage key")
		}
		keys = append(keys, key)
	}
	return keys, nil
}
// readAndCheckMagicNumber unpickles the next object from r and verifies it
// is the expected pytorch magic number (a big integer compared in hex).
func readAndCheckMagicNumber(r io.Reader) error {
	obj, err := unpickle(r)
	if err != nil {
		return err
	}
	magic, ok := obj.(*big.Int)
	if !ok || magic.Text(16) != hexMagicNumber {
		return ErrInvalidMagicNumber
	}
	return nil
}
// readAndChecProtocolVersion unpickles the next object from r and verifies
// it equals the supported pytorch serialization protocol version.
// NOTE(review): the name is missing a 'k' ("Chec"); kept as-is because the
// function is called elsewhere in this package.
func readAndChecProtocolVersion(r io.Reader) error {
	obj, err := unpickle(r)
	if err != nil {
		return err
	}
	version, ok := obj.(int)
	if !ok || version != protocolVersion {
		return ErrInvalidProtocolVersion
	}
	return nil
}
// unpickle decodes a single pickled object from r using a default Unpickler.
func unpickle(r io.Reader) (interface{}, error) {
	unpickler := NewUnpickler(r)
	return unpickler.Load()
}
func isZipFile(filename string) bool {
r, err := zip.OpenReader(filename)
if err != nil {
return false
}
r.Close()
return true
}
// makePickleFindClass maps pytorch pickle global references (module.name)
// to this package's rebuild helpers and storage classes, delegating any
// unknown reference to the provided fallback resolver.
func makePickleFindClass(fallback func(module, name string) (interface{}, error)) func(module, name string) (interface{}, error) {
	return func(module, name string) (interface{}, error) {
		switch module + "." + name {
		case "torch._utils._rebuild_tensor":
			return &RebuildTensor{}, nil
		case "torch._utils._rebuild_tensor_v2":
			return &RebuildTensorV2{}, nil
		case "torch.FloatStorage":
			return &FloatStorageClass{}, nil
		case "torch.HalfStorage":
			return &HalfStorageClass{}, nil
		case "torch.DoubleStorage":
			return &DoubleStorageClass{}, nil
		case "torch.CharStorage":
			return &CharStorageClass{}, nil
		case "torch.ShortStorage":
			return &ShortStorageClass{}, nil
		case "torch.IntStorage":
			return &IntStorageClass{}, nil
		case "torch.LongStorage":
			return &LongStorageClass{}, nil
		case "torch.ByteStorage":
			return &ByteStorageClass{}, nil
		case "torch.BoolStorage":
			return &BoolStorageClass{}, nil
		case "torch.nn.backends.thnn._get_thnn_function_backend":
			// this is for historical pickle deserialization, it is not used otherwise
			return getThnnFunctionBackend{}, nil
		default:
			if fallback == nil {
				return nil, fmt.Errorf("class not found: %s %s", module, name)
			}
			return fallback(module, name)
		}
	}
}
// LoadAll finds and loads all weights from varstore.
// It returns an error if any variable in the varstore cannot be found in
// the loaded pretrained model, or if a found tensor's shape mismatches.
func LoadAll(vs *nn.VarStore, modelFile string) error {
	weights, err := Decode(modelFile)
	if err != nil {
		err = fmt.Errorf("LoadAll() failed: %w", err)
		return err
	}
	// for tsName, _ := range vs.Vars.NamedVariables {
	for tsName := range vs.Vars.NamedVariables {
		// missing variable
		currTs, ok := weights[tsName]
		if !ok {
			err = fmt.Errorf("LoadAll() failed: Cannot find tensor with name: %v in variable store. \n", tsName)
			return err
		}

		// mismatched shape (source = loaded tensor, dest = varstore variable)
		sourceShape := currTs.MustSize()
		destShape := vs.Vars.NamedVariables[tsName].MustSize()
		if !reflect.DeepEqual(destShape, sourceShape) {
			err = fmt.Errorf("LoadAll() failed: Mismatched shape error for variable name: %v - At store: %v - At source %v\n", tsName, destShape, sourceShape)
			return err
		}

		// Copy without tracking gradients.
		ts.NoGrad(func() {
			vs.Vars.NamedVariables[tsName].Copy_(currTs)
		})
	}

	// Free the decoded tensors now that their values have been copied.
	for _, x := range weights {
		x.MustDrop()
	}

	return nil
}
// LoadPartial finds and loads weights for varstore.
// Variables that are missing from the pretrained file, or whose shapes do
// not match the varstore variable, are skipped; their names are returned.
func LoadPartial(vs *nn.VarStore, modelFile string) ([]string, error) {
	weights, err := Decode(modelFile)
	if err != nil {
		err = fmt.Errorf("LoadPartial() failed: %w", err)
		return nil, err
	}

	var missingVariables []string

	// Match and in-place copy value (update) from newly loaded tensors
	// to existing named tensors if name is matched. Skip otherwise.
	for tsName := range vs.Vars.NamedVariables {
		var currTs *ts.Tensor
		var ok bool

		// missing variable
		if currTs, ok = weights[tsName]; !ok {
			missingVariables = append(missingVariables, tsName)
			continue
		}

		// mismatched shape. NOTE: source = loaded pretrained tensor, dest =
		// varstore variable (these two were swapped before, mislabeling the
		// shapes in the warning message relative to LoadAll).
		sourceShape := currTs.MustSize()
		destShape := vs.Vars.NamedVariables[tsName].MustSize()
		if !reflect.DeepEqual(destShape, sourceShape) {
			fmt.Printf("WARNING: Mismatched shape error for variable name: %v - At store: %v - At source %v. Skip loading this weight...\n", tsName, destShape, sourceShape)
			missingVariables = append(missingVariables, tsName)
			continue
		}

		// Copy without tracking gradients.
		ts.NoGrad(func() {
			vs.Vars.NamedVariables[tsName].Copy_(currTs)
		})
	}

	// Free the decoded tensors now that their values have been copied.
	for _, x := range weights {
		x.MustDrop()
	}

	return missingVariables, nil
}
// LoadInfo loads pretrained weights and prints out the name and shape of
// each weight, sorted by name, followed by the total variable count.
func LoadInfo(modelFile string) error {
	weights, err := Decode(modelFile)
	if err != nil {
		err = fmt.Errorf("LoadInfo() failed: %w", err)
		return err
	}

	// Sort names so output is deterministic.
	layers := make([]string, 0, len(weights))
	for tsName := range weights {
		layers = append(layers, tsName)
	}
	sort.Strings(layers)
	for _, l := range layers {
		// Direct map lookup; the previous linear scan per layer was O(n^2).
		x := weights[l]
		fmt.Printf("%s - %+v\n", l, x.MustSize())
	}

	fmt.Printf("Num of variables: %v\n", len(weights))

	// Free the decoded tensors.
	for _, x := range weights {
		x.MustDrop()
	}

	return nil
}

667
pickle/storage.go Normal file
View File

@ -0,0 +1,667 @@
package pickle
import (
"encoding/binary"
"fmt"
"io"
"math"
"github.com/sugarme/gotch"
)
// This file implements Pytorch storage data types.
// ref: https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/storage.py
/*
torch.double: 'DoubleStorage',
torch.float: 'FloatStorage',
torch.half: 'HalfStorage',
torch.long: 'LongStorage',
torch.int: 'IntStorage',
torch.int16: 'ShortStorage',
torch.int8: 'CharStorage',
torch.uint8: 'ByteStorage',
torch.bool: 'BoolStorage',
torch.bfloat16: 'BFloat16Storage',
torch.cdouble: 'ComplexDoubleStorage',
torch.cfloat: 'ComplexFloatStorage',
torch.qint8: 'QInt8Storage',
torch.qint32: 'QInt32Storage',
torch.quint8: 'QUInt8Storage',
torch.quint4x2: 'QUInt4x2Storage',
torch.quint2x4: 'QUInt2x4Storage',
*/
// StorageClass defines the interface for types to be used in Storage.
// A StorageClass acts as a factory: New allocates an empty Storage of the
// matching concrete type with the given element count and location tag.
type StorageClass interface {
	New(size int, location string) Storage
}
// Storage defines the Storage interface. A Storage reads its raw elements
// from a reader and exposes them along with their gotch dtype and device.
type Storage interface {
	// SetFromFile reads the storage contents using the package's shared helper.
	SetFromFile(r io.Reader) error
	// SetFromFileWithSize reads exactly `size` elements from r.
	SetFromFileWithSize(r io.Reader, size int) error
	// DType reports the gotch element type of the loaded data.
	DType() gotch.DType
	// GetData returns the loaded elements as a typed slice (e.g. []float32).
	GetData() interface{}
	// Device maps the storage location tag to a gotch device.
	Device() gotch.Device
}
// BaseStorage represents a base storage embedded by every concrete storage type.
type BaseStorage struct {
	Size     int    // number of elements (not bytes)
	Location string // pytorch storage location tag; "cuda" selects CUDA, others map to CPU
}
// HalfStorage:
// ============

// HalfStorageClass is the factory for HalfStorage (torch.HalfStorage).
type HalfStorageClass struct{}

var _ StorageClass = &HalfStorageClass{}

// New creates an empty HalfStorage with the given element count and location.
func (s *HalfStorageClass) New(size int, location string) Storage {
	return &HalfStorage{
		BaseStorage: BaseStorage{Size: size, Location: location},
		Data:        nil,
	}
}

// HalfStorage holds float16 elements, widened to float32 on load.
type HalfStorage struct {
	BaseStorage
	Data []float32
}

var _ Storage = &HalfStorage{}

// SetFromFile reads the storage contents via the shared helper.
func (s *HalfStorage) SetFromFile(r io.Reader) error {
	return setFromFile(s, r)
}

// SetFromFileWithSize reads `size` little-endian 2-byte float16 values
// from r and stores them widened to float32.
func (s *HalfStorage) SetFromFileWithSize(r io.Reader, size int) error {
	data := make([]float32, size)
	br := NewLimitedBufferReader(r, size, 2, 512)
	for i := 0; i < size; i++ {
		bytes, err := br.ReadNext()
		if err != nil {
			return err
		}
		u16 := binary.LittleEndian.Uint16(bytes)
		// Widen f16 bits to f32 bits, then reinterpret as float32.
		data[i] = math.Float32frombits(FloatBits16to32(u16))
	}
	s.Data = data
	return nil
}

// GetData returns the loaded elements as []float32.
func (s *HalfStorage) GetData() interface{} {
	return s.Data
}

// DType reports gotch.Float: half data is widened to float32 on load.
func (s *HalfStorage) DType() gotch.DType {
	return gotch.Float
}

// Device maps the storage location to a gotch device ("cuda" -> CUDA if
// available, anything else -> CPU).
func (s *HalfStorage) Device() gotch.Device {
	switch s.Location {
	case "cuda":
		return gotch.CudaIfAvailable()
	default:
		return gotch.CPU
	}
}
// FloatStorage:
// =============

// FloatStorageClass is the factory for FloatStorage (torch.FloatStorage).
type FloatStorageClass struct{}

var _ StorageClass = &FloatStorageClass{}

// New creates an empty FloatStorage with the given element count and location.
func (s *FloatStorageClass) New(size int, location string) Storage {
	return &FloatStorage{
		BaseStorage: BaseStorage{Size: size, Location: location},
		Data:        nil,
	}
}

// FloatStorage holds float32 elements.
type FloatStorage struct {
	BaseStorage
	Data []float32
}

var _ Storage = &FloatStorage{}

// SetFromFile reads the storage contents via the shared helper.
func (s *FloatStorage) SetFromFile(r io.Reader) error {
	return setFromFile(s, r)
}

// SetFromFileWithSize reads `size` little-endian 4-byte float32 values from r.
func (s *FloatStorage) SetFromFileWithSize(r io.Reader, size int) error {
	data := make([]float32, size)
	br := NewLimitedBufferReader(r, size, 4, 512)
	for i := 0; i < size; i++ {
		bytes, err := br.ReadNext()
		if err != nil {
			return err
		}
		data[i] = math.Float32frombits(binary.LittleEndian.Uint32(bytes))
	}
	s.Data = data
	return nil
}

// GetData returns the loaded elements as []float32.
func (s *FloatStorage) GetData() interface{} {
	return s.Data
}

// DType reports gotch.Float.
func (s *FloatStorage) DType() gotch.DType {
	return gotch.Float
}

// Device maps the storage location to a gotch device ("cuda" -> CUDA if
// available, anything else -> CPU).
func (s *FloatStorage) Device() gotch.Device {
	switch s.Location {
	case "cuda":
		return gotch.CudaIfAvailable()
	default:
		return gotch.CPU
	}
}
// DoubleStorage:
// ==============

// DoubleStorageClass is the factory for DoubleStorage (torch.DoubleStorage).
type DoubleStorageClass struct{}

var _ StorageClass = &DoubleStorageClass{}

// New creates an empty DoubleStorage with the given element count and location.
func (s *DoubleStorageClass) New(size int, location string) Storage {
	return &DoubleStorage{
		BaseStorage: BaseStorage{Size: size, Location: location},
		Data:        nil,
	}
}

// DoubleStorage holds float64 elements.
type DoubleStorage struct {
	BaseStorage
	Data []float64
}

var _ Storage = &DoubleStorage{}

// SetFromFile reads the storage contents via the shared helper.
func (s *DoubleStorage) SetFromFile(r io.Reader) error {
	return setFromFile(s, r)
}

// SetFromFileWithSize reads `size` little-endian 8-byte float64 values from r.
func (s *DoubleStorage) SetFromFileWithSize(r io.Reader, size int) error {
	data := make([]float64, size)
	br := NewLimitedBufferReader(r, size, 8, 512)
	for i := 0; i < size; i++ {
		bytes, err := br.ReadNext()
		if err != nil {
			return err
		}
		data[i] = math.Float64frombits(binary.LittleEndian.Uint64(bytes))
	}
	s.Data = data
	return nil
}

// GetData returns the loaded elements as []float64.
func (s *DoubleStorage) GetData() interface{} {
	return s.Data
}

// DType reports gotch.Double.
func (s *DoubleStorage) DType() gotch.DType {
	return gotch.Double
}

// Device maps the storage location to a gotch device ("cuda" -> CUDA if
// available, anything else -> CPU).
func (s *DoubleStorage) Device() gotch.Device {
	switch s.Location {
	case "cuda":
		return gotch.CudaIfAvailable()
	default:
		return gotch.CPU
	}
}
// CharStorage:
// ============

// CharStorageClass builds CharStorage values.
type CharStorageClass struct{}

var _ StorageClass = &CharStorageClass{}

// New creates an empty CharStorage with the given element count and location.
func (c *CharStorageClass) New(size int, location string) Storage {
	return &CharStorage{
		BaseStorage: BaseStorage{Size: size, Location: location},
		Data:        nil,
	}
}

// CharStorage holds int8 elements read from a pickled PyTorch file.
type CharStorage struct {
	BaseStorage
	Data []int8
}

var _ Storage = &CharStorage{}

// SetFromFile reads the 8-byte element count, then the payload, from r.
func (c *CharStorage) SetFromFile(r io.Reader) error {
	return setFromFile(c, r)
}

// SetFromFileWithSize reads size one-byte signed values from r.
func (c *CharStorage) SetFromFileWithSize(r io.Reader, size int) error {
	values := make([]int8, size)
	br := NewLimitedBufferReader(r, size, 1, 512)
	for i := range values {
		raw, err := br.ReadNext()
		if err != nil {
			return err
		}
		values[i] = int8(raw[0])
	}
	c.Data = values
	return nil
}

// GetData returns the underlying []int8.
func (c *CharStorage) GetData() interface{} {
	return c.Data
}

// DType reports gotch.Int8.
func (c *CharStorage) DType() gotch.DType {
	return gotch.Int8
}

// Device reports CUDA when requested and available, otherwise CPU.
func (c *CharStorage) Device() gotch.Device {
	if c.Location == "cuda" {
		return gotch.CudaIfAvailable()
	}
	return gotch.CPU
}
// ShortStorage:
// =============

// ShortStorageClass builds ShortStorage values.
type ShortStorageClass struct{}

var _ StorageClass = &ShortStorageClass{}

// New creates an empty ShortStorage with the given element count and location.
func (c *ShortStorageClass) New(size int, location string) Storage {
	return &ShortStorage{
		BaseStorage: BaseStorage{Size: size, Location: location},
		Data:        nil,
	}
}

// ShortStorage holds int16 elements read from a pickled PyTorch file.
type ShortStorage struct {
	BaseStorage
	Data []int16
}

var _ Storage = &ShortStorage{}

// SetFromFile reads the 8-byte element count, then the payload, from r.
func (s *ShortStorage) SetFromFile(r io.Reader) error {
	return setFromFile(s, r)
}

// SetFromFileWithSize reads size little-endian int16 values from r.
func (s *ShortStorage) SetFromFileWithSize(r io.Reader, size int) error {
	values := make([]int16, size)
	br := NewLimitedBufferReader(r, size, 2, 512)
	for i := range values {
		raw, err := br.ReadNext()
		if err != nil {
			return err
		}
		values[i] = int16(binary.LittleEndian.Uint16(raw))
	}
	s.Data = values
	return nil
}

// GetData returns the underlying []int16.
func (s *ShortStorage) GetData() interface{} {
	return s.Data
}

// DType reports gotch.Int16.
func (s *ShortStorage) DType() gotch.DType {
	return gotch.Int16
}

// Device reports CUDA when requested and available, otherwise CPU.
func (s *ShortStorage) Device() gotch.Device {
	if s.Location == "cuda" {
		return gotch.CudaIfAvailable()
	}
	return gotch.CPU
}
// IntStorage:
// ===========

// IntStorageClass builds IntStorage values.
type IntStorageClass struct{}

var _ StorageClass = &IntStorageClass{}

// New creates an empty IntStorage with the given element count and location.
func (c *IntStorageClass) New(size int, location string) Storage {
	return &IntStorage{
		BaseStorage: BaseStorage{Size: size, Location: location},
		Data:        nil,
	}
}

// IntStorage holds int32 elements read from a pickled PyTorch file.
type IntStorage struct {
	BaseStorage
	Data []int32
}

var _ Storage = &IntStorage{}

// SetFromFile reads the 8-byte element count, then the payload, from r.
func (s *IntStorage) SetFromFile(r io.Reader) error {
	return setFromFile(s, r)
}

// SetFromFileWithSize reads size little-endian int32 values from r.
func (s *IntStorage) SetFromFileWithSize(r io.Reader, size int) error {
	values := make([]int32, size)
	br := NewLimitedBufferReader(r, size, 4, 512)
	for i := range values {
		raw, err := br.ReadNext()
		if err != nil {
			return err
		}
		values[i] = int32(binary.LittleEndian.Uint32(raw))
	}
	s.Data = values
	return nil
}

// GetData returns the underlying []int32.
func (s *IntStorage) GetData() interface{} {
	return s.Data
}

// DType reports gotch.Int.
func (s *IntStorage) DType() gotch.DType {
	return gotch.Int
}

// Device reports CUDA when requested and available, otherwise CPU.
func (s *IntStorage) Device() gotch.Device {
	if s.Location == "cuda" {
		return gotch.CudaIfAvailable()
	}
	return gotch.CPU
}
// LongStorage:
// ============

// LongStorageClass builds LongStorage values.
type LongStorageClass struct{}

var _ StorageClass = &LongStorageClass{}

// New creates an empty LongStorage with the given element count and location.
func (c *LongStorageClass) New(size int, location string) Storage {
	return &LongStorage{
		BaseStorage: BaseStorage{Size: size, Location: location},
		Data:        nil,
	}
}

// LongStorage holds int64 elements read from a pickled PyTorch file.
type LongStorage struct {
	BaseStorage
	Data []int64
}

var _ Storage = &LongStorage{}

// SetFromFile reads the 8-byte element count, then the payload, from r.
func (s *LongStorage) SetFromFile(r io.Reader) error {
	return setFromFile(s, r)
}

// SetFromFileWithSize reads size little-endian int64 values from r.
func (s *LongStorage) SetFromFileWithSize(r io.Reader, size int) error {
	values := make([]int64, size)
	br := NewLimitedBufferReader(r, size, 8, 512)
	for i := range values {
		raw, err := br.ReadNext()
		if err != nil {
			return err
		}
		values[i] = int64(binary.LittleEndian.Uint64(raw))
	}
	s.Data = values
	return nil
}

// GetData returns the underlying []int64.
func (s *LongStorage) GetData() interface{} {
	return s.Data
}

// DType reports gotch.Int64.
func (s *LongStorage) DType() gotch.DType {
	return gotch.Int64
}

// Device reports CUDA when requested and available, otherwise CPU.
func (s *LongStorage) Device() gotch.Device {
	if s.Location == "cuda" {
		return gotch.CudaIfAvailable()
	}
	return gotch.CPU
}
// ByteStorage:
// ============

// ByteStorageClass builds ByteStorage values.
type ByteStorageClass struct{}

var _ StorageClass = &ByteStorageClass{}

// New creates an empty ByteStorage with the given element count and location.
func (c *ByteStorageClass) New(size int, location string) Storage {
	return &ByteStorage{
		BaseStorage: BaseStorage{Size: size, Location: location},
		Data:        nil,
	}
}

// ByteStorage holds uint8 elements read from a pickled PyTorch file.
type ByteStorage struct {
	BaseStorage
	Data []uint8
}

var _ Storage = &ByteStorage{}

// SetFromFile reads the 8-byte element count, then the payload, from r.
func (s *ByteStorage) SetFromFile(r io.Reader) error {
	return setFromFile(s, r)
}

// SetFromFileWithSize reads size raw bytes from r.
func (s *ByteStorage) SetFromFileWithSize(r io.Reader, size int) error {
	values := make([]uint8, size)
	br := NewLimitedBufferReader(r, size, 1, 512)
	for i := range values {
		raw, err := br.ReadNext()
		if err != nil {
			return err
		}
		values[i] = raw[0]
	}
	s.Data = values
	return nil
}

// GetData returns the underlying []uint8.
func (s *ByteStorage) GetData() interface{} {
	return s.Data
}

// DType reports gotch.Uint8.
func (s *ByteStorage) DType() gotch.DType {
	return gotch.Uint8
}

// Device reports CUDA when requested and available, otherwise CPU.
func (s *ByteStorage) Device() gotch.Device {
	if s.Location == "cuda" {
		return gotch.CudaIfAvailable()
	}
	return gotch.CPU
}
// BoolStorage:
// ============

// BoolStorageClass builds BoolStorage values.
type BoolStorageClass struct{}

var _ StorageClass = &BoolStorageClass{}

// New creates an empty BoolStorage with the given element count and location.
func (s *BoolStorageClass) New(size int, location string) Storage {
	return &BoolStorage{
		BaseStorage: BaseStorage{Size: size, Location: location},
		Data:        nil,
	}
}

// BoolStorage holds bool elements read from a pickled PyTorch file.
type BoolStorage struct {
	BaseStorage
	Data []bool
}

var _ Storage = &BoolStorage{}

// SetFromFile reads the 8-byte element count, then the payload, from r.
func (s *BoolStorage) SetFromFile(r io.Reader) error {
	return setFromFile(s, r)
}

// SetFromFileWithSize reads size one-byte boolean values from r
// (a stored byte equal to 1 means true).
func (s *BoolStorage) SetFromFileWithSize(r io.Reader, size int) error {
	data := make([]bool, size)
	br := NewLimitedBufferReader(r, size, 1, 512)
	for i := 0; i < size; i++ {
		bytes, err := br.ReadNext()
		if err != nil {
			return err
		}
		data[i] = bytes[0] == 1
	}
	s.Data = data
	return nil
}

// GetData returns the underlying []bool.
func (s *BoolStorage) GetData() interface{} {
	return s.Data
}

// DType reports gotch.Bool.
//
// FIX: this previously returned gotch.Float, which mislabeled boolean
// tensors; Data is []bool, so the element type is Bool (consistent with
// every other storage type in this file reporting its own element type).
func (s *BoolStorage) DType() gotch.DType {
	return gotch.Bool
}

// Device reports CUDA when requested and available, otherwise CPU.
func (s *BoolStorage) Device() gotch.Device {
	switch s.Location {
	case "cuda":
		return gotch.CudaIfAvailable()
	default:
		return gotch.CPU
	}
}
// setFromFile reads an 8-byte little-endian element count from r, then
// delegates to s.SetFromFileWithSize to read that many elements.
func setFromFile(s Storage, r io.Reader) error {
	sizeBuf := make([]byte, 8)
	// io.ReadFull guarantees all 8 header bytes arrive; a bare r.Read may
	// legally return fewer bytes, corrupting the decoded count.
	if _, err := io.ReadFull(r, sizeBuf); err != nil {
		return err
	}
	size := int(binary.LittleEndian.Uint64(sizeBuf))
	return s.SetFromFileWithSize(r, size)
}
// StorageTensor:
//===============

// StorageTensor is the raw result of unpickling a PyTorch tensor: the
// backing Storage plus the view metadata (element offset into the storage,
// per-dimension sizes and strides) and the requires_grad flag, as produced
// by RebuildTensor / RebuildTensorV2 below.
type StorageTensor struct {
	Source Storage
	StorageOffset int64
	Size []int64
	Stride []int64
	RequiresGrad bool
}
// Rebuild Tensor:
// ===============
// ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L132
// ref. def _rebuild_tensor(storage, storage_offset, size, stride):

// RebuildTensor mirrors torch._utils._rebuild_tensor, turning the four
// pickled arguments into a *StorageTensor.
type RebuildTensor struct{}

var _ Callable = &RebuildTensor{}

// Call expects (storage, storage_offset, size, stride) and returns a
// *StorageTensor with RequiresGrad set to false.
func (r *RebuildTensor) Call(args ...interface{}) (interface{}, error) {
	if len(args) != 4 {
		return nil, fmt.Errorf("RebuildTensor.Call() failed. Expected 4 args, got %d: %#v", len(args), args)
	}
	source, okStorage := args[0].(Storage)
	offset, okOffset := args[1].(int)
	sizeTuple, okSize := args[2].(*Tuple)
	strideTuple, okStride := args[3].(*Tuple)
	if !(okStorage && okOffset && okSize && okStride) {
		return nil, fmt.Errorf("RebuildTensor.Call() unexpected args: %#v", args)
	}

	size, err := tupleToInt64Slice(sizeTuple)
	if err != nil {
		return nil, err
	}
	stride, err := tupleToInt64Slice(strideTuple)
	if err != nil {
		return nil, err
	}
	return &StorageTensor{
		Source:        source,
		StorageOffset: int64(offset),
		Size:          size,
		Stride:        stride,
		RequiresGrad:  false,
	}, nil
}
// RebuildTensorV2 represents a struct to rebuild tensor back from pickle object.
type RebuildTensorV2 struct{}

var _ Callable = &RebuildTensorV2{}

// Call expects (storage, storage_offset, size, stride, requires_grad,
// backward_hooks) and returns a *StorageTensor; backward_hooks is ignored.
func (r *RebuildTensorV2) Call(args ...interface{}) (interface{}, error) {
	if len(args) != 6 {
		return nil, fmt.Errorf("RebuildTensorV2 unexpected args: %#v", args)
	}
	source, okStorage := args[0].(Storage)
	offset, okOffset := args[1].(int)
	sizeTuple, okSize := args[2].(*Tuple)
	strideTuple, okStride := args[3].(*Tuple)
	grad, okGrad := args[4].(bool)
	// args[5] ("backward hooks") is intentionally unused.
	if !(okStorage && okOffset && okSize && okStride && okGrad) {
		return nil, fmt.Errorf("RebuildTensorV2 unexpected args: %#v", args)
	}

	size, err := tupleToInt64Slice(sizeTuple)
	if err != nil {
		return nil, err
	}
	stride, err := tupleToInt64Slice(strideTuple)
	if err != nil {
		return nil, err
	}
	return &StorageTensor{
		Source:        source,
		StorageOffset: int64(offset),
		Size:          size,
		Stride:        stride,
		RequiresGrad:  grad,
	}, nil
}
// tupleToInt64Slice converts a pickle Tuple of Go ints into a []int64,
// failing if any element is not an int.
func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
	out := make([]int64, tuple.Len())
	for i := range out {
		v, ok := tuple.Get(i).(int)
		if !ok {
			return nil, fmt.Errorf("tuple of ints expected: %#v", tuple)
		}
		out[i] = int64(v)
	}
	return out, nil
}

519
pickle/type.go Normal file
View File

@ -0,0 +1,519 @@
package pickle
import (
"container/list"
"fmt"
"reflect"
)
// This file contains custom types and interfaces for casting Python data types from Pickle.
//
// Custom Types:
// =============
// 1. ByteArray
// 2. Dict
// 3. Tuple
// 4. OrderedDict
// 5. List
// 6. Set
// 7. FrozenSet
// 8. Object
// 9. Reconstructor
// 10. GenericClass
// Interfaces:
// ===========
// Callable is implemented by any value that can be directly called to get a
// new value.
//
// It is usually implemented by Python-like functions (returning a value
// given some arguments), or classes (typically returning an instance given
// some constructor arguments).
type Callable interface {
	// Call mimics a direct invocation on a Python value, such as a function
	// or class (constructor).
	Call(args ...interface{}) (interface{}, error)
}

// PyNewable is implemented by any value that has a Python-like
// "__new__" method.
//
// It is usually implemented by values representing Python classes.
type PyNewable interface {
	// PyNew mimics Python invocation of the "__new__" method, usually
	// provided by classes.
	//
	// See: https://docs.python.org/3/reference/datamodel.html#object.__new__
	PyNew(args ...interface{}) (interface{}, error)
}

// PyStateSettable is implemented by any value that has a Python-like
// "__setstate__" method.
type PyStateSettable interface {
	// PySetState mimics Python invocation of the "__setstate__" method.
	//
	// See: https://docs.python.org/3/library/pickle.html#object.__setstate__
	PySetState(state interface{}) error
}

// PyDictSettable is implemented by any value that can store dictionary-like
// key/value pairs. It reflects Python behavior of setting a key/value pair on
// an object's "__dict__" attribute.
type PyDictSettable interface {
	// PyDictSet mimics the setting of a key/value pair on an object's
	// "__dict__" attribute.
	//
	// See: https://docs.python.org/3/library/stdtypes.html#object.__dict__
	PyDictSet(key, value interface{}) error
}

// PyAttrSettable is implemented by any value on which an existing or new
// Python-like attribute can be set. In Python this is done with the "setattr"
// builtin function.
type PyAttrSettable interface {
	// PySetAttr mimics the setting of an arbitrary value to an object's
	// attribute.
	//
	// In Python this is done with the "setattr" function, to which object,
	// attribute name, and value are passed. For an easy and clear
	// implementation, here instead we require this method to be implemented
	// on the "object" itself.
	//
	// See: https://docs.python.org/3/library/functions.html#setattr
	PySetAttr(key string, value interface{}) error
}
// ByteArray:
//===========

// ByteArray simulates Python "bytearray".
type ByteArray []byte

// NewByteArray makes and returns a new empty ByteArray.
func NewByteArray() *ByteArray {
	ba := make(ByteArray, 0)
	return &ba
}

// NewByteArrayFromSlice wraps the given byte slice (without copying) in a
// ByteArray.
func NewByteArrayFromSlice(slice []byte) *ByteArray {
	ba := ByteArray(slice)
	return &ba
}

// Get returns the byte at index i; it panics if i is out of range.
func (ba *ByteArray) Get(i int) byte {
	return (*ba)[i]
}

// Len returns the number of bytes in the ByteArray.
func (ba *ByteArray) Len() int {
	return len(*ba)
}
// Dict:
//======

// DictSetter is implemented by any value that exhibits a dict-like behaviour,
// allowing arbitrary key/value pairs to be set.
type DictSetter interface {
	Set(key, value interface{})
}

// DictEntry is a single key/value pair held by a Dict.
type DictEntry struct {
	Key   interface{}
	Value interface{}
}

// Dict represents a Python "dict" (builtin type).
//
// It is a slice rather than a map because not all Go types can be map
// keys (e.g. slices).
type Dict []*DictEntry

var _ DictSetter = &Dict{}

// NewDict makes and returns a new empty Dict.
func NewDict() *Dict {
	d := make(Dict, 0)
	return &d
}

// Set appends the given key/value pair to the Dict.
func (d *Dict) Set(key, value interface{}) {
	entry := &DictEntry{Key: key, Value: value}
	*d = append(*d, entry)
}

// Get returns the value associated with key (keys are compared with
// reflect.DeepEqual) and whether the key was found.
func (d *Dict) Get(key interface{}) (interface{}, bool) {
	for _, entry := range *d {
		if reflect.DeepEqual(entry.Key, key) {
			return entry.Value, true
		}
	}
	return nil, false
}

// MustGet returns the value associated with key, panicking when absent.
func (d *Dict) MustGet(key interface{}) interface{} {
	if value, ok := d.Get(key); ok {
		return value
	}
	panic(fmt.Errorf("key not found in Dict: %#v", key))
}

// Len returns the number of key/value pairs in the Dict.
func (d *Dict) Len() int {
	return len(*d)
}
// Tuple:
// ======

// Tuple represents a Python "tuple" (builtin type), backed by a slice.
type Tuple []interface{}

// NewTupleFromSlice wraps the given slice (without copying) in a Tuple.
func NewTupleFromSlice(slice []interface{}) *Tuple {
	tup := Tuple(slice)
	return &tup
}

// Get returns the element at index i; it panics if i is out of range.
func (t *Tuple) Get(i int) interface{} {
	return (*t)[i]
}

// Len returns the number of elements in the Tuple.
func (t *Tuple) Len() int {
	return len(*t)
}
// OrderedDict:
// ============

// OrderedDictClass represent Python "collections.OrderedDict" class.
//
// This class allows the indirect creation of OrderedDict objects.
type OrderedDictClass struct{}

var _ Callable = &OrderedDictClass{}

// Call returns a new empty OrderedDict. It is equivalent to the Python
// constructor "collections.OrderedDict()".
//
// No arguments are supported.
func (*OrderedDictClass) Call(args ...interface{}) (interface{}, error) {
	if len(args) > 0 {
		return nil, fmt.Errorf(
			"OrderedDictClass.Call args not supported: %#v", args)
	}
	return NewOrderedDict(), nil
}
// OrderedDict is a minimal and trivial implementation of an ordered map,
// representing a Python "collections.OrderedDict" object.
//
// A plain (unordered) map gives direct key lookups, while a doubly-linked
// list preserves insertion order for iteration; both structures share the
// same entry pointers.
type OrderedDict struct {
	// Map associates a key of any type (interface{}) to OrderedDictEntry
	// pointer values. These values are shared with List.
	Map map[interface{}]*OrderedDictEntry
	// List is an ordered list of OrderedDictEntry pointers, which are
	// also shared with Map.
	List *list.List
	// PyDict represents Python "object.__dict__" dictionary of attributes.
	PyDict map[string]interface{}
}

var _ DictSetter = &OrderedDict{}
var _ PyDictSettable = &OrderedDict{}

// OrderedDictEntry is a single key/value pair stored in an OrderedDict.
//
// A pointer to an OrderedDictEntry is always shared between the
// OrderedDict's Map and List.
type OrderedDictEntry struct {
	// Key of a single OrderedDict's entry.
	Key interface{}
	// Value of a single OrderedDict's entry.
	Value interface{}
	// ListElement points back to the List element that contains this
	// very OrderedDictEntry.
	ListElement *list.Element
}

// NewOrderedDict makes and returns a new empty OrderedDict.
func NewOrderedDict() *OrderedDict {
	return &OrderedDict{
		Map:    map[interface{}]*OrderedDictEntry{},
		List:   list.New(),
		PyDict: map[string]interface{}{},
	}
}

// Set inserts the given key/value pair. A brand-new key goes to the back
// of the OrderedDict; an existing key keeps its position and only has its
// value replaced.
func (o *OrderedDict) Set(k, v interface{}) {
	if existing, found := o.Map[k]; found {
		existing.Value = v
		return
	}
	entry := &OrderedDictEntry{Key: k, Value: v}
	entry.ListElement = o.List.PushBack(entry)
	o.Map[k] = entry
}

// Get returns the value associated with the given key (if any), and whether
// the key is present or not.
func (o *OrderedDict) Get(k interface{}) (interface{}, bool) {
	if entry, found := o.Map[k]; found {
		return entry.Value, true
	}
	return nil, false
}

// MustGet returns the value associated with the given key if it exists,
// otherwise it panics.
func (o *OrderedDict) MustGet(key interface{}) interface{} {
	if value, found := o.Get(key); found {
		return value
	}
	panic(fmt.Errorf("key not found in OrderedDict: %#v", key))
}

// Len returns the number of key/value pairs in the OrderedDict.
func (o *OrderedDict) Len() int {
	return len(o.Map)
}

// PyDictSet mimics the setting of a key/value pair on Python "__dict__"
// attribute of the OrderedDict; only string keys are accepted.
func (o *OrderedDict) PyDictSet(key, value interface{}) error {
	name, isString := key.(string)
	if !isString {
		return fmt.Errorf(
			"OrderedDict.PyDictSet() requires string key: %#v", key)
	}
	o.PyDict[name] = value
	return nil
}
// List:
// =====

// ListAppender is implemented by any value that exhibits a list-like
// behaviour, allowing arbitrary values to be appended.
type ListAppender interface {
	Append(v interface{})
}

// List represents a Python "list" (builtin type).
type List []interface{}

var _ ListAppender = &List{}

// NewList makes and returns a new empty List.
func NewList() *List {
	l := make(List, 0)
	return &l
}

// NewListFromSlice makes and returns a new List backed by the given slice.
//
// The new List is a simple type cast of the input slice; the slice is
// _not_ copied.
func NewListFromSlice(slice []interface{}) *List {
	l := List(slice)
	return &l
}

// Append adds one element at the end of the List.
func (l *List) Append(v interface{}) {
	*l = append(*l, v)
}

// Get returns the element at index i; it panics if i is out of range.
func (l *List) Get(i int) interface{} {
	return (*l)[i]
}

// Len returns the number of elements in the List.
func (l *List) Len() int {
	return len(*l)
}
// Set:
// ====

// SetAdder is implemented by any value that exhibits a set-like behaviour,
// allowing arbitrary values to be added.
type SetAdder interface {
	Add(v interface{})
}

type setEmptyStruct struct{}

// Set represents a Python "set" (builtin type): a map whose keys are the
// members and whose (empty-struct) values carry no data.
type Set map[interface{}]setEmptyStruct

var _ SetAdder = &Set{}

// NewSet makes and returns a new empty Set.
func NewSet() *Set {
	s := make(Set)
	return &s
}

// NewSetFromSlice makes and returns a new Set initialized with the elements
// of the given slice.
func NewSetFromSlice(slice []interface{}) *Set {
	s := make(Set, len(slice))
	for _, v := range slice {
		s.Add(v)
	}
	return &s
}

// Len returns the number of members in the Set.
func (s *Set) Len() int {
	return len(*s)
}

// Add inserts one member into the Set.
func (s *Set) Add(v interface{}) {
	(*s)[v] = setEmptyStruct{}
}

// Has reports whether v is a member of the Set.
func (s *Set) Has(v interface{}) bool {
	_, ok := (*s)[v]
	return ok
}
// FrozenSet:
//===========

type frozenSetEmptyStruct struct{}

// FrozenSet represents a Python "frozenset" (builtin type): a map whose
// keys are the members and whose (empty-struct) values carry no data.
type FrozenSet map[interface{}]frozenSetEmptyStruct

// NewFrozenSetFromSlice makes and returns a new FrozenSet initialized
// with the elements of the given slice.
func NewFrozenSetFromSlice(slice []interface{}) *FrozenSet {
	fs := make(FrozenSet, len(slice))
	for _, v := range slice {
		fs[v] = frozenSetEmptyStruct{}
	}
	return &fs
}

// Len returns the number of members in the FrozenSet.
func (f *FrozenSet) Len() int {
	return len(*f)
}

// Has reports whether v is a member of the FrozenSet.
func (f *FrozenSet) Has(v interface{}) bool {
	_, ok := (*f)[v]
	return ok
}
// Object:
//========

// ObjectClass stands in for the Python "object" builtin class.
type ObjectClass struct{}

var _ PyNewable = &ObjectClass{}

// PyNew mimics "object.__new__(cls, ...)": the first argument must itself
// be a PyNewable class, whose PyNew is invoked with no arguments.
func (o *ObjectClass) PyNew(args ...interface{}) (interface{}, error) {
	if len(args) == 0 {
		return nil, fmt.Errorf("ObjectClass.PyNew called with no arguments")
	}
	if class, ok := args[0].(PyNewable); ok {
		return class.PyNew()
	}
	return nil, fmt.Errorf(
		"ObjectClass.PyNew unprocessable args: %#v", args)
}
// Reconstructor:
//===============

// Reconstructor builds an instance from pickled (class, base, ...) args
// by delegating to base.PyNew(class).
type Reconstructor struct{}

var _ Callable = &Reconstructor{}

// Call expects at least two arguments: the class being reconstructed and
// a PyNewable base class; extra arguments are ignored.
func (r *Reconstructor) Call(args ...interface{}) (interface{}, error) {
	if len(args) < 2 {
		return nil, fmt.Errorf("Reconstructor: invalid arguments: %#v", args)
	}
	class := args[0]
	if base, ok := args[1].(PyNewable); ok {
		return base.PyNew(class)
	}
	return nil, fmt.Errorf(
		"Reconstructor: unprocessable arguments: %#v", args)
}
// GenericClass:
//==============

// GenericClass is a generic placeholder for a Python class, identified by
// its module and name.
type GenericClass struct {
	Module string
	Name   string
}

var _ PyNewable = &GenericClass{}

// GenericObject is an instance of a GenericClass, retaining the
// constructor arguments it was created with.
type GenericObject struct {
	Class           *GenericClass
	ConstructorArgs []interface{}
}

// NewGenericClass makes and returns a new GenericClass.
func NewGenericClass(module, name string) *GenericClass {
	return &GenericClass{Module: module, Name: name}
}

// PyNew returns a new GenericObject bound to this class and the given
// constructor arguments.
func (g *GenericClass) PyNew(args ...interface{}) (interface{}, error) {
	obj := &GenericObject{Class: g, ConstructorArgs: args}
	return obj, nil
}
// getThnnFunctionBackend is kept only for deserialization compatibility
// with historical pickles; it is not used otherwise.
type getThnnFunctionBackend struct{}

var _ Callable = &getThnnFunctionBackend{}

// Call ignores its arguments and is a no-op returning (nil, nil).
func (getThnnFunctionBackend) Call(_ ...interface{}) (interface{}, error) {
	return nil, nil
}

112
pickle/util.go Normal file
View File

@ -0,0 +1,112 @@
package pickle
import "io"
// Converts the bits representation of a Half Float (16 bits) number to
// an IEEE 754 float representation (32 bits)
// From http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
func FloatBits16to32(u16 uint16) uint32 {
	return mantissaTable[offsetTable[u16>>10]+(uint32(u16)&0x3ff)] + exponentTable[u16>>10]
}

// Lookup tables for the table-based half->float conversion. mantissaTable
// holds subnormal conversions at indices 0..1023 and normalized mantissas
// at 1024..2047; exponentTable/offsetTable are indexed by the half's
// sign+exponent bits (u16>>10).
var mantissaTable [2048]uint32
var exponentTable [64]uint32
var offsetTable [64]uint32

func init() {
	initMantissaTable()
	initExponentTable()
	initOffsetTable()
}

func initMantissaTable() {
	mantissaTable[0] = 0
	// 1..1023: subnormal halves, which must be renormalized.
	for i := uint32(1); i < 1024; i++ {
		mantissaTable[i] = convertMantissa(i)
	}
	// 1024..2047: normalized halves; pad the 10-bit mantissa to 23 bits.
	for i := uint32(1024); i < 2048; i++ {
		mantissaTable[i] = 0x38000000 + ((i - 1024) << 13)
	}
}

func initExponentTable() {
	exponentTable[0] = 0
	exponentTable[31] = 0x47800000 // +Inf/NaN exponent
	exponentTable[32] = 0x80000000 // sign bit only (negative zero/subnormal)
	exponentTable[63] = 0xC7800000 // -Inf/NaN exponent
	for i := uint32(1); i < 31; i++ {
		exponentTable[i] = i << 23
	}
	for i := uint32(33); i < 63; i++ {
		exponentTable[i] = 0x80000000 + (i-32)<<23
	}
}

func initOffsetTable() {
	// Indices 0 and 32 (exponent bits all zero: +/- zero and subnormals)
	// select the subnormal half of mantissaTable (offset 0); every other
	// sign+exponent pattern selects the normalized half (offset 1024).
	offsetTable[0] = 0
	offsetTable[32] = 0
	// FIX: was `i < 31`, which left offsetTable[31] (the +Inf/NaN row) at 0
	// and so decoded half +Inf/NaN through the subnormal table.
	for i := uint32(1); i < 32; i++ {
		offsetTable[i] = 1024
	}
	// FIX: was starting at 32, which clobbered offsetTable[32] (set to 0
	// above) with 1024 and broke negative zero/subnormal conversion.
	for i := uint32(33); i < 64; i++ {
		offsetTable[i] = 1024
	}
}

func convertMantissa(i uint32) uint32 {
	var m uint32 = i << 13 // zero pad mantissa bits
	var e uint32 = 0       // zero exponent
	// FIX: was `m&0x00800000 != 0`, the inverse of the "while not
	// normalized" intent — subnormal mantissas were never shifted into
	// normalized position. Loop until the implicit leading bit (1<<23) is set.
	for m&0x00800000 == 0 {
		e -= 0x00800000 // decrement exponent (1 << 23)
		m <<= 1         // shift mantissa
	}
	m &= ^uint32(0x00800000) // clear leading 1 bit
	e += 0x38800000          // adjust bias ((127-14)<<23)
	return m | e             // return combined number
}
type LimitedBufferReader struct {
r io.Reader
scalarSize int
remainingBytes int
buf []byte
bufIndex int
}
func NewLimitedBufferReader(
r io.Reader,
dataSize, scalarSize, bufferSize int,
) *LimitedBufferReader {
size := bufferSize * scalarSize
return &LimitedBufferReader{
r: r,
scalarSize: scalarSize,
remainingBytes: scalarSize * dataSize,
buf: make([]byte, size),
bufIndex: size,
}
}
func (br *LimitedBufferReader) HasNext() bool {
return br.remainingBytes != 0
}
func (br *LimitedBufferReader) ReadNext() ([]byte, error) {
if br.remainingBytes == 0 {
return nil, io.EOF
}
if br.bufIndex == len(br.buf) {
br.bufIndex = 0
if br.remainingBytes < len(br.buf) {
br.buf = br.buf[0:br.remainingBytes]
}
_, err := br.r.Read(br.buf)
if err != nil {
return nil, err
}
}
result := br.buf[br.bufIndex : br.bufIndex+br.scalarSize]
br.bufIndex += br.scalarSize
br.remainingBytes -= br.scalarSize
return result, nil
}