added subpackage 'pickle' and file-util
This commit is contained in:
parent
47c6c60561
commit
dc4ab3047c
|
@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
|||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
- Added subpackage `pickle`. Now we can directly load a Python PyTorch pretrained model without any Python script conversion.
|
||||
- Added `gotch.CachePath()` and `gotch.ModelUrls`
|
||||
- Removed Travis CI for now.
|
||||
- Fixed `tensor.OfSlice()` throwing an error due to "Unsupported Go type" (e.g. []float32)
|
||||
- Added `nn.Path.Paths()` method
|
||||
|
|
36
example/pickle/main.go
Normal file
36
example/pickle/main.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
"github.com/sugarme/gotch/pickle"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
func main() {
|
||||
device := gotch.CPU
|
||||
vs := nn.NewVarStore(device)
|
||||
net := vision.VGG16(vs.Root(), 1000)
|
||||
|
||||
modelName := "vgg16"
|
||||
modelUrl, ok := gotch.ModelUrls[modelName]
|
||||
if !ok {
|
||||
log.Fatal("model name %q not found.", modelName)
|
||||
}
|
||||
|
||||
modelFile, err := gotch.CachedPath(modelUrl)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = pickle.LoadAll(vs, modelFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%v\n", net)
|
||||
vs.Summary()
|
||||
}
|
289
file-util.go
Normal file
289
file-util.go
Normal file
|
@ -0,0 +1,289 @@
|
|||
package gotch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// This file provides functions to work with local dataset cache, ...
|
||||
|
||||
// ModelUrls maps model name to its pretrained URL.
//
// These URLs are taken from separate models in the pytorch/vision repository
// https://github.com/pytorch/vision/tree/main/torchvision/models
//
// An empty string means no pretrained weights are published for that model.
var ModelUrls map[string]string = map[string]string{
	"alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",

	"convnext_tiny":  "https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
	"convnext_small": "https://download.pytorch.org/models/convnext_small-0c510722.pth",
	"convnext_base":  "https://download.pytorch.org/models/convnext_base-6075fbad.pth",
	"convnext_large": "https://download.pytorch.org/models/convnext_large-ea097f82.pth",

	"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
	"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
	"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
	"densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",

	// Weights ported from https://github.com/rwightman/pytorch-image-models/
	"efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
	"efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
	"efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
	"efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
	"efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
	// Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
	"efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
	"efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
	"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",

	// GoogLeNet ported from TensorFlow
	"googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",

	// Inception v3 ported from TensorFlow
	"inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",

	"mnasnet0_5":  "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
	"mnasnet0_75": "",
	"mnasnet1_0":  "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
	"mnasnet1_3":  "",

	"mobilenet_v2":       "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
	"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
	"mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",

	"regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
	"regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
	"regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
	"regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
	"regnet_y_8gf":   "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
	"regnet_y_16gf":  "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
	"regnet_y_32gf":  "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
	"regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
	"regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
	"regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
	"regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
	"regnet_x_8gf":   "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
	"regnet_x_16gf":  "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
	"regnet_x_32gf":  "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",

	"resnet18":          "https://download.pytorch.org/models/resnet18-f37072fd.pth",
	"resnet34":          "https://download.pytorch.org/models/resnet34-b627a593.pth",
	"resnet50":          "https://download.pytorch.org/models/resnet50-0676ba61.pth",
	"resnet101":         "https://download.pytorch.org/models/resnet101-63fe2227.pth",
	"resnet152":         "https://download.pytorch.org/models/resnet152-394f9c45.pth",
	"resnext50_32x4d":   "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
	"resnext101_32x8d":  "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
	"wide_resnet50_2":   "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
	"wide_resnet101_2":  "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",

	"shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
	"shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
	"shufflenetv2_x1.5": "",
	"shufflenetv2_x2.0": "",

	"squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
	"squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",

	"vgg11":    "https://download.pytorch.org/models/vgg11-8a719046.pth",
	"vgg13":    "https://download.pytorch.org/models/vgg13-19584684.pth",
	"vgg16":    "https://download.pytorch.org/models/vgg16-397923af.pth",
	"vgg19":    "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
	"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
	"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
	"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
	"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",

	"vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth",
	"vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
	"vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
	"vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth",
}
|
||||
|
||||
// CachedPath resolves and caches data based on input string, then returns fullpath to the cached data.
|
||||
//
|
||||
// Parameters:
|
||||
// - `filenameOrUrl`: full path to filename or url
|
||||
//
|
||||
// CachedPath does several things consequently:
|
||||
// 1. Resolves input string to a fullpath cached filename candidate.
|
||||
// 2. Check it at `CachePath`, if exists, then return the candidate. If not
|
||||
// 3. Retrieves and Caches data to `CachePath` and returns path to cached data
|
||||
func CachedPath(filenameOrUrl string) (resolvedPath string, err error) {
|
||||
filename := path.Base(filenameOrUrl)
|
||||
// Resolves to "candidate" filename at `CacheDir`
|
||||
cachedFileCandidate := fmt.Sprintf("%s/%s", CacheDir, filename)
|
||||
|
||||
// 1. Cached candidate file exists
|
||||
if _, err := os.Stat(cachedFileCandidate); err == nil {
|
||||
return cachedFileCandidate, nil
|
||||
}
|
||||
|
||||
// 2. If valid fullpath to local file, caches it and return cached filename
|
||||
if _, err := os.Stat(filenameOrUrl); err == nil {
|
||||
err := copyFile(filenameOrUrl, cachedFileCandidate)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return cachedFileCandidate, nil
|
||||
}
|
||||
|
||||
// 3. Cached candidate file NOT exist. Try to download it and save to `CacheDir`
|
||||
if isValidURL(filenameOrUrl) {
|
||||
if _, err := http.Get(filenameOrUrl); err == nil {
|
||||
err := downloadFile(filenameOrUrl, cachedFileCandidate)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return cachedFileCandidate, nil
|
||||
} else {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
err = fmt.Errorf("Unable to parse %q as a URL or as a local path.\n", filenameOrUrl)
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
// Not resolves
|
||||
err = fmt.Errorf("Unable to parse %q as a URL or as a local path.\n", filenameOrUrl)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// isValidURL reports whether the input looks like a downloadable http(s)
// URL, matching the scheme used by every entry in ModelUrls.
//
// BUG FIX: the original was a TODO stub that always returned true, so any
// unresolvable string triggered a pointless download attempt.
func isValidURL(url string) bool {
	// Empty entries in ModelUrls mean "no pretrained weights available".
	return strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://")
}
|
||||
|
||||
// downloadFile downloads file from URL and stores it in local filepath.
|
||||
// It writes to the destination file as it downloads it, without loading
|
||||
// the entire file into memory. An `io.TeeReader` is passed into Copy()
|
||||
// to report progress on the download.
|
||||
func downloadFile(url string, filepath string) error {
|
||||
// Create path if not existing
|
||||
dir := path.Dir(filepath)
|
||||
filename := path.Base(filepath)
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create the file with .tmp extension, so that we won't overwrite a
|
||||
// file until it's downloaded fully
|
||||
out, err := os.Create(filepath + ".tmp")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
// Get the data
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check server response
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
err := fmt.Errorf("bad status: %s(%v)", resp.Status, resp.StatusCode)
|
||||
if resp.StatusCode == 404 {
|
||||
err = fmt.Errorf("download file not found: %q for downloading", url)
|
||||
} else {
|
||||
err = fmt.Errorf("download file failed: %q", url)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// the total file size to download
|
||||
size, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
|
||||
downloadSize := uint64(size)
|
||||
|
||||
// Create our bytes counter and pass it to be used alongside our writer
|
||||
counter := &writeCounter{FileSize: downloadSize}
|
||||
_, err = io.Copy(out, io.TeeReader(resp.Body, counter))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("\r%s... %s/%s completed", filename, byteCountIEC(counter.Total), byteCountIEC(counter.FileSize))
|
||||
// The progress use the same line so print a new line once it's finished downloading
|
||||
fmt.Println()
|
||||
|
||||
// Rename the tmp file back to the original file
|
||||
err = os.Rename(filepath+".tmp", filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeCounter counts the number of bytes written to it. By implementing the Write method,
|
||||
// it is of the io.Writer interface and we can pass this into io.TeeReader()
|
||||
// Every write to this writer, will print the progress of the file write.
|
||||
type writeCounter struct {
|
||||
Total uint64
|
||||
FileSize uint64
|
||||
}
|
||||
|
||||
func (wc *writeCounter) Write(p []byte) (int, error) {
|
||||
n := len(p)
|
||||
wc.Total += uint64(n)
|
||||
wc.printProgress()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// PrintProgress prints the progress of a file write
|
||||
func (wc writeCounter) printProgress() {
|
||||
// Clear the line by using a character return to go back to the start and remove
|
||||
// the remaining characters by filling it with spaces
|
||||
fmt.Printf("\r%s", strings.Repeat(" ", 50))
|
||||
|
||||
// Return again and print current status of download
|
||||
fmt.Printf("\rDownloading... %s/%s", byteCountIEC(wc.Total), byteCountIEC(wc.FileSize))
|
||||
}
|
||||
|
||||
// byteCountIEC renders a byte count as a human-readable string using
// binary (IEC) units: B, KiB, MiB, GiB, ...
func byteCountIEC(b uint64) string {
	const unit = 1024
	if b < unit {
		return fmt.Sprintf("%d B", b)
	}
	// Find the largest power of 1024 that fits into b.
	exp := 0
	div := uint64(unit)
	for rest := b / unit; rest >= unit; rest /= unit {
		exp++
		div *= unit
	}
	value := float64(b) / float64(div)
	return fmt.Sprintf("%.1f %ciB", value, "KMGTPE"[exp])
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
sourceFileStat, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !sourceFileStat.Mode().IsRegular() {
|
||||
return fmt.Errorf("%s is not a regular file", src)
|
||||
}
|
||||
|
||||
source, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer source.Close()
|
||||
|
||||
destination, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer destination.Close()
|
||||
_, err = io.Copy(destination, source)
|
||||
return err
|
||||
}
|
35
init.go
Normal file
35
init.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package gotch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
var (
	// CacheDir is the local directory used to cache downloaded files.
	// It is initialized by init() and can be overridden with the
	// GOTCH_CACHE environment variable (see initEnv).
	CacheDir string = "NOT_SETTING"

	// gotchEnvKey is the environment variable that overrides CacheDir.
	gotchEnvKey string = "GOTCH_CACHE"
)
|
||||
|
||||
func init() {
|
||||
// default path: {$HOME}/.cache/gotch
|
||||
homeDir := os.Getenv("HOME")
|
||||
CacheDir = fmt.Sprintf("%s/.cache/transformer", homeDir)
|
||||
|
||||
initEnv()
|
||||
|
||||
log.Printf("INFO: CacheDir=%q\n", CacheDir)
|
||||
}
|
||||
|
||||
func initEnv() {
|
||||
val := os.Getenv(gotchEnvKey)
|
||||
if val != "" {
|
||||
CacheDir = val
|
||||
}
|
||||
|
||||
if _, err := os.Stat(CacheDir); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(CacheDir, 0755); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
1937
pickle/pickle.go
Normal file
1937
pickle/pickle.go
Normal file
File diff suppressed because it is too large
Load Diff
60
pickle/pickle_example_test.go
Normal file
60
pickle/pickle_example_test.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package pickle_test
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/pickle"
|
||||
)
|
||||
|
||||
// ExampleLoadInfo downloads (or reuses the cached copy of) the pretrained
// VGG16 checkpoint and prints every named variable with its shape.
// The `// Output:` block below is checked by `go test` and must not change.
func ExampleLoadInfo() {
	modelName := "vgg16"
	url, ok := gotch.ModelUrls[modelName]
	if !ok {
		log.Fatalf("Unsupported model name %q\n", modelName)
	}
	modelFile, err := gotch.CachedPath(url)
	if err != nil {
		panic(err)
	}

	err = pickle.LoadInfo(modelFile)
	if err != nil {
		log.Fatal(err)
	}

	// Output:
	// classifier.0.bias - [4096]
	// classifier.0.weight - [4096 25088]
	// classifier.3.bias - [4096]
	// classifier.3.weight - [4096 4096]
	// classifier.6.bias - [1000]
	// classifier.6.weight - [1000 4096]
	// features.0.bias - [64]
	// features.0.weight - [64 3 3 3]
	// features.10.bias - [256]
	// features.10.weight - [256 128 3 3]
	// features.12.bias - [256]
	// features.12.weight - [256 256 3 3]
	// features.14.bias - [256]
	// features.14.weight - [256 256 3 3]
	// features.17.bias - [512]
	// features.17.weight - [512 256 3 3]
	// features.19.bias - [512]
	// features.19.weight - [512 512 3 3]
	// features.2.bias - [64]
	// features.2.weight - [64 64 3 3]
	// features.21.bias - [512]
	// features.21.weight - [512 512 3 3]
	// features.24.bias - [512]
	// features.24.weight - [512 512 3 3]
	// features.26.bias - [512]
	// features.26.weight - [512 512 3 3]
	// features.28.bias - [512]
	// features.28.weight - [512 512 3 3]
	// features.5.bias - [128]
	// features.5.weight - [128 64 3 3]
	// features.7.bias - [128]
	// features.7.weight - [128 128 3 3]
	// Num of variables: 32
}
|
519
pickle/serialization.go
Normal file
519
pickle/serialization.go
Normal file
|
@ -0,0 +1,519 @@
|
|||
package pickle
|
||||
|
||||
// Ref.
|
||||
// https://docs.python.org/3/library/pickle.html
|
||||
// https://docs.python.org/3/library/pickletools.html
|
||||
// https://github.com/python/cpython/blob/main/Lib/pickle.py (****real code here****)
|
||||
// https://pytorch.org/tutorials/beginner/saving_loading_models.html
|
||||
// https://github.com/pytorch/pytorch/blob/master/torch/serialization.py
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"archive/zip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"sort"
|
||||
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// hexMagicNumber is the magic number written by torch.save in the legacy
// serialization format, encoded as a hex string.
const hexMagicNumber = "1950a86a20f9469cfc6c"

// protocolVersion is the legacy pytorch serialization protocol version.
const protocolVersion = 1001

// ErrInvalidMagicNumber reports that the file does not start with the
// pytorch magic number.
var ErrInvalidMagicNumber = errors.New("invalid pytorch magic number")

// ErrInvalidProtocolVersion reports an unsupported serialization protocol.
var ErrInvalidProtocolVersion = errors.New("invalid pytorch protocol version")
|
||||
|
||||
// Encode encodes model using pickling machinery.
// The output pickled model is meant to be loadable with Python PyTorch as
// `torch.load("pytorch_model.bin")`.
//
// NOTE: not implemented yet — calling Encode always panics.
//
// TODO. implement pickling part so that model can be exported and loaded with Python Pytorch.
// See https://github.com/python/cpython/blob/b0de6299a840a397d4fe3e6c98159d9f258d3295/Lib/pickle.py#L407
func Encode(model ts.Module, outputFile string) error {
	panic("NotImplementedError")
}
|
||||
|
||||
// Decode decodes pickled data created by 'torch.save()' with Python Pytorch
// and rebuilds named tensor weights.
//
// It expects the unpickled top-level object to be an *OrderedDict whose
// values are *StorageTensor entries; each is rebuilt into a *ts.Tensor
// keyed by its parameter name.
func Decode(filename string) (map[string]*ts.Tensor, error) {
	newUnpickler := func(r io.Reader) Unpickler {
		return NewUnpickler(r)
	}
	result, err := LoadWithUnpickler(filename, newUnpickler)
	if err != nil {
		err := fmt.Errorf("Decode() failed: %w", err)
		return nil, err
	}

	// Rebuild tensors from Storage tensors
	namedTensors := make(map[string]*ts.Tensor)
	dictResult, isOrderedDict := result.(*OrderedDict)
	if !isOrderedDict {
		err := fmt.Errorf("Decode() failed: expected 'OrderedDict' type, got %v\n", reflect.TypeOf(result).String())
		return nil, err
	}
	for name, item := range dictResult.Map {
		sx, isStorageTensor := item.Value.(*StorageTensor)
		if !isStorageTensor {
			err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
			return nil, err
		}

		// Pull the raw storage data plus the view metadata (shape/stride/
		// offset) needed to reconstruct the tensor.
		data := sx.Source.GetData()
		size := sx.Size
		dtype := sx.Source.DType()
		device := sx.Source.Device()
		stride := sx.Stride
		storageOffset := sx.StorageOffset

		// fmt.Printf("%q - shape: %v - stride: %v - storageOffset: %v\n", sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)

		// Rebuild: slice -> strided view -> dtype -> device. The chained
		// Must* calls consume (drop) their intermediate tensors.
		x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
		if sx.RequiresGrad {
			x.MustRequiresGrad_(sx.RequiresGrad)
		}

		namedTensors[name.(string)] = x
	}

	return namedTensors, nil
}
|
||||
|
||||
// LoadWithUnpickler is like Load, but it accepts a newUnpickler function which
|
||||
// is used to create new customized pickle.Unpickler instances.
|
||||
func LoadWithUnpickler(filename string, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) {
|
||||
if !isZipFile(filename) {
|
||||
return loadLegacyFile(filename, newUnpickler)
|
||||
}
|
||||
return loadZipFile(filename, newUnpickler)
|
||||
}
|
||||
|
||||
// loadZipFile unpickles "data.pkl" from a zip-container pytorch checkpoint.
// Persistent IDs of the form ("storage", class, key, location, size) are
// resolved by reading the matching zip record into a Storage; storages are
// memoized by key so shared storages are loaded only once.
func loadZipFile(filename string, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) {
	// Open a zip archive for reading.
	r, err := zip.OpenReader(filename)
	if err != nil {
		return nil, err
	}
	defer r.Close()

	// Index records by base name (paths inside the archive vary).
	fileRecords := make(map[string]*zip.File, len(r.File))
	for _, f := range r.File {
		_, recordName := path.Split(f.Name)
		fileRecords[recordName] = f
	}

	// "constants.pkl" marks a TorchScript archive, which is unsupported.
	if _, isTorchScript := fileRecords["constants.pkl"]; isTorchScript {
		return nil, fmt.Errorf("TorchScript is not supported")
	}

	dataFile, hasDataFile := fileRecords["data.pkl"]
	if !hasDataFile {
		return nil, fmt.Errorf("data.pkl not found in zip file")
	}
	df, err := dataFile.Open()
	if err != nil {
		return nil, err
	}
	defer df.Close()

	// Cache of storages already materialized, keyed by persistent-ID key.
	loadedStorages := make(map[string]Storage)

	u := newUnpickler(df)
	u.FindClass = makePickleFindClass(u.FindClass)
	u.PersistentLoad = func(savedId interface{}) (interface{}, error) {
		tuple, tupleOk := savedId.(*Tuple)
		if !tupleOk || tuple.Len() == 0 {
			return nil, fmt.Errorf("PersistentLoad: non-empty tuple expected, got %#v", savedId)
		}
		typename, typenameOk := tuple.Get(0).(string)
		if !typenameOk {
			return nil, fmt.Errorf("PersistentLoad: cannot get typename")
		}
		if typename != "storage" {
			return nil, fmt.Errorf("unknown typename for PersistentLoad, expected 'storage' but got '%s'", typename)
		}
		if tuple.Len() < 5 {
			return nil, fmt.Errorf("PersistentLoad: unexpected storage data length")
		}
		dataType, dataTypeOk := tuple.Get(1).(StorageClass)
		key, keyOk := tuple.Get(2).(string)
		location, locationOk := tuple.Get(3).(string)
		size, sizeOk := tuple.Get(4).(int)
		if !dataTypeOk || !keyOk || !locationOk || !sizeOk {
			return nil, fmt.Errorf("PersistentLoad: unexpected data types")
		}
		storage, storageExists := loadedStorages[key]
		if !storageExists {
			// NOTE: assigns the enclosing `err` deliberately (closure).
			storage, err = loadTensor(dataType, size, location, key, fileRecords)
			if err != nil {
				return nil, err
			}
			loadedStorages[key] = storage
		}
		return storage, nil
	}
	return u.Load()
}
|
||||
|
||||
func loadTensor(
|
||||
dataType StorageClass,
|
||||
size int,
|
||||
location, key string,
|
||||
zipFileRecords map[string]*zip.File,
|
||||
) (Storage, error) {
|
||||
file, fileOk := zipFileRecords[key]
|
||||
if !fileOk {
|
||||
return nil, fmt.Errorf("cannot find zip record '%s'", key)
|
||||
}
|
||||
f, err := file.Open()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
storage := dataType.New(size, location)
|
||||
err = storage.SetFromFileWithSize(f, size)
|
||||
return storage, err
|
||||
}
|
||||
|
||||
func loadLegacyFile(filename string, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
tr := tar.NewReader(f)
|
||||
for {
|
||||
_, err := tr.Next()
|
||||
switch err {
|
||||
case nil:
|
||||
// TODO: ...
|
||||
panic("legacy load from tar not implemented")
|
||||
case io.EOF:
|
||||
break // End of archive
|
||||
case tar.ErrHeader, io.ErrUnexpectedEOF:
|
||||
_, err = f.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return loadLegacyNoTar(f, newUnpickler)
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// loadLegacyNoTar loads a plain (non-tar) legacy pytorch checkpoint:
// magic number, protocol version and sys-info pickles first, then the
// pickled object graph, then the list of storage keys, and finally the
// raw storage payloads in key order — all read sequentially from f.
func loadLegacyNoTar(f *os.File, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) {
	if err := readAndCheckMagicNumber(f); err != nil {
		return nil, err
	}
	if err := readAndChecProtocolVersion(f); err != nil {
		return nil, err
	}
	if _, err := unpickle(f); err != nil { // sys info
		return nil, err
	}

	// Storages materialized so far, keyed by their root key.
	deserializedObjects := make(map[string]Storage)

	u := newUnpickler(f)
	u.FindClass = makePickleFindClass(u.FindClass)
	u.PersistentLoad = func(savedId interface{}) (interface{}, error) {
		tuple, tupleOk := savedId.(*Tuple)
		if !tupleOk || tuple.Len() == 0 {
			return nil, fmt.Errorf("PersistentLoad: non-empty tuple expected, got %#v", savedId)
		}
		typename, typenameOk := tuple.Get(0).(string)
		if !typenameOk {
			return nil, fmt.Errorf("PersistentLoad: cannot get typename")
		}

		// fmt.Printf("typename: %s\n", typename)

		switch typename {
		case "storage":
			// ("storage", class, rootKey, location, size, viewMetadata)
			if tuple.Len() < 6 {
				return nil, fmt.Errorf(
					"PersistentLoad: unexpected storage data length")
			}
			dataType, dataTypeOk := tuple.Get(1).(StorageClass)
			rootKey, rootKeyOk := tuple.Get(2).(string)
			location, locationOk := tuple.Get(3).(string)
			size, sizeOk := tuple.Get(4).(int)
			viewMetadata := tuple.Get(5)
			if !dataTypeOk || !rootKeyOk || !locationOk || !sizeOk {
				return nil, fmt.Errorf("PersistentLoad: unexpected data types")
			}

			// fmt.Printf("dtype: %v - rootKey: %v - device: %v - size %v - viewMetaData: %v\n", reflect.TypeOf(dataType), rootKey, location, size, viewMetadata)

			// The storage is created empty here; its payload is filled in
			// later by SetFromFile once the storage-key list is read.
			storage, storageExists := deserializedObjects[rootKey]
			if !storageExists {
				storage = dataType.New(size, location)
				deserializedObjects[rootKey] = storage
			}
			switch vm := viewMetadata.(type) {
			case nil:
				return storage, nil
			case []interface{}:
				if len(vm) != 3 {
					return nil, fmt.Errorf(
						"PersistentLoad: unexpected view metadata length")
				}
				panic("viewMetadata not implemented")
				// TODO: ...
				// view_key, offset, view_size = view_metadata
				// if view_key not in deserialized_objects:
				// deserialized_objects[view_key] = storage[offset:offset + view_size]
				// return deserialized_objects[view_key]
			default:
				return nil, fmt.Errorf("PersistentLoad: unexpected view metadata type")
			}
		case "module":
			if tuple.Len() < 2 {
				return nil, fmt.Errorf("PersistentLoad: unexpected module data length")
			}
			return tuple.Get(1), nil
		default:
			return nil, fmt.Errorf("Unexpected saved ID type: %s", typename)
		}
	}

	result, err := u.Load()
	if err != nil {
		return nil, err
	}

	// The pickled object graph is followed by the list of storage keys ...
	rawStorageKeys, err := unpickle(f)
	if err != nil {
		return nil, err
	}
	storageKeys, err := makeStorageKeys(rawStorageKeys)
	if err != nil {
		return nil, err
	}

	// ... and then by the raw payload of each storage, in key order.
	for _, key := range storageKeys {
		storageObj, ok := deserializedObjects[key]
		if !ok {
			return nil, fmt.Errorf("storage object not found for key '%s'", key)
		}

		err = storageObj.SetFromFile(f)
		if err != nil {
			return nil, err
		}
	}

	return result, nil
}
|
||||
|
||||
func makeStorageKeys(obj interface{}) ([]string, error) {
|
||||
list, ok := obj.(*List)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid storage keys data")
|
||||
}
|
||||
keys := make([]string, len(*list))
|
||||
for i, rawKey := range *list {
|
||||
key, keyOk := rawKey.(string)
|
||||
if !keyOk {
|
||||
return nil, fmt.Errorf("invalid storage key")
|
||||
}
|
||||
keys[i] = key
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func readAndCheckMagicNumber(r io.Reader) error {
|
||||
obj, err := unpickle(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n, ok := obj.(*big.Int); !ok || n.Text(16) != hexMagicNumber {
|
||||
return ErrInvalidMagicNumber
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readAndChecProtocolVersion unpickles the next object of a legacy
// checkpoint and verifies it equals the supported protocol version (1001).
//
// NOTE(review): the name is missing a 'k' ("Check..."); renaming is left
// for a dedicated change since its caller in this file must be updated too.
func readAndChecProtocolVersion(r io.Reader) error {
	obj, err := unpickle(r)
	if err != nil {
		return err
	}
	if n, ok := obj.(int); !ok || n != protocolVersion {
		return ErrInvalidProtocolVersion
	}
	return nil
}
|
||||
|
||||
func unpickle(r io.Reader) (interface{}, error) {
|
||||
u := NewUnpickler(r)
|
||||
return u.Load()
|
||||
}
|
||||
|
||||
func isZipFile(filename string) bool {
|
||||
r, err := zip.OpenReader(filename)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
r.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
// makePickleFindClass returns a FindClass function that maps pytorch
// pickle class names (torch.* storages and tensor-rebuild helpers) to
// their Go counterparts, delegating any unknown class to fallback.
func makePickleFindClass(fallback func(module, name string) (interface{}, error)) func(module, name string) (interface{}, error) {
	return func(module, name string) (interface{}, error) {
		switch module + "." + name {
		case "torch._utils._rebuild_tensor":
			return &RebuildTensor{}, nil
		case "torch._utils._rebuild_tensor_v2":
			return &RebuildTensorV2{}, nil
		case "torch.FloatStorage":
			return &FloatStorageClass{}, nil
		case "torch.HalfStorage":
			return &HalfStorageClass{}, nil
		case "torch.DoubleStorage":
			return &DoubleStorageClass{}, nil
		case "torch.CharStorage":
			return &CharStorageClass{}, nil
		case "torch.ShortStorage":
			return &ShortStorageClass{}, nil
		case "torch.IntStorage":
			return &IntStorageClass{}, nil
		case "torch.LongStorage":
			return &LongStorageClass{}, nil
		case "torch.ByteStorage":
			return &ByteStorageClass{}, nil
		case "torch.BoolStorage":
			return &BoolStorageClass{}, nil
		case "torch.nn.backends.thnn._get_thnn_function_backend":
			// this is for historical pickle deserialization, it is not used otherwise
			return getThnnFunctionBackend{}, nil
		default:
			if fallback == nil {
				return nil, fmt.Errorf("class not found: %s %s", module, name)
			}
			return fallback(module, name)
		}
	}
}
|
||||
|
||||
// LoadAll loads pretrained weights from modelFile into every variable of
// the var store. It returns an error if a var-store variable is missing
// from the loaded model or its shape does not match.
func LoadAll(vs *nn.VarStore, modelFile string) error {
	weights, err := Decode(modelFile)
	if err != nil {
		err = fmt.Errorf("LoadAll() failed: %w", err)
		return err
	}

	// for tsName, _ := range vs.Vars.NamedVariables {
	for tsName := range vs.Vars.NamedVariables {
		// missing variable
		currTs, ok := weights[tsName]
		if !ok {
			err = fmt.Errorf("LoadAll() failed: Cannot find tensor with name: %v in variable store. \n", tsName)
			return err
		}

		// mismatched shape
		sourceShape := currTs.MustSize()
		destShape := vs.Vars.NamedVariables[tsName].MustSize()
		if !reflect.DeepEqual(destShape, sourceShape) {
			err = fmt.Errorf("LoadAll() failed: Mismatched shape error for variable name: %v - At store: %v - At source %v\n", tsName, destShape, sourceShape)
			return err
		}

		// In-place copy of the matching pretrained weight, with gradient
		// tracking disabled.
		ts.NoGrad(func() {
			vs.Vars.NamedVariables[tsName].Copy_(currTs)
		})
	}

	// Free the decoded tensors now that their values were copied in.
	for _, x := range weights {
		x.MustDrop()
	}

	return nil
}
|
||||
|
||||
// LoadPartial finds and loads weights for varstore.
|
||||
// It returns list of unfound weight names.
|
||||
func LoadPartial(vs *nn.VarStore, modelFile string) ([]string, error) {
|
||||
weights, err := Decode(modelFile)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("LoadPartial() failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var missingVariables []string
|
||||
|
||||
// Match and in-place copy value (update) from newly loaded tensors
|
||||
// to existing named tensors if name is matched. Throw error otherwise.
|
||||
for tsName := range vs.Vars.NamedVariables {
|
||||
var currTs *ts.Tensor
|
||||
var ok bool
|
||||
|
||||
// missing variable
|
||||
if currTs, ok = weights[tsName]; !ok {
|
||||
missingVariables = append(missingVariables, tsName)
|
||||
continue
|
||||
}
|
||||
|
||||
// mismatched shape
|
||||
destShape := currTs.MustSize()
|
||||
sourceShape := vs.Vars.NamedVariables[tsName].MustSize()
|
||||
if !reflect.DeepEqual(destShape, sourceShape) {
|
||||
fmt.Printf("WARNING: Mismatched shape error for variable name: %v - At store: %v - At source %v. Skip loading this weight...\n", tsName, destShape, sourceShape)
|
||||
missingVariables = append(missingVariables, tsName)
|
||||
continue
|
||||
}
|
||||
|
||||
ts.NoGrad(func() {
|
||||
vs.Vars.NamedVariables[tsName].Copy_(currTs)
|
||||
})
|
||||
}
|
||||
|
||||
for _, x := range weights {
|
||||
x.MustDrop()
|
||||
}
|
||||
|
||||
return missingVariables, nil
|
||||
}
|
||||
|
||||
// LoadInfo loads pretrained weights and prints out name and shape of weights.
|
||||
func LoadInfo(modelFile string) error {
|
||||
weights, err := Decode(modelFile)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("LoadInfo() failed: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
layers := make([]string, 0, len(weights))
|
||||
for tsName := range weights {
|
||||
layers = append(layers, tsName)
|
||||
}
|
||||
sort.Strings(layers)
|
||||
for _, l := range layers {
|
||||
var x *ts.Tensor
|
||||
for tsName, tsVal := range weights {
|
||||
if tsName == l {
|
||||
x = tsVal
|
||||
break
|
||||
}
|
||||
}
|
||||
fmt.Printf("%s - %+v\n", l, x.MustSize())
|
||||
}
|
||||
|
||||
fmt.Printf("Num of variables: %v\n", len(weights))
|
||||
|
||||
for _, x := range weights {
|
||||
x.MustDrop()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
667
pickle/storage.go
Normal file
667
pickle/storage.go
Normal file
|
@ -0,0 +1,667 @@
|
|||
package pickle
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
)
|
||||
|
||||
// This file implements Pytorch storage data types.
|
||||
// ref: https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/storage.py
|
||||
/*
|
||||
torch.double: 'DoubleStorage',
|
||||
torch.float: 'FloatStorage',
|
||||
torch.half: 'HalfStorage',
|
||||
torch.long: 'LongStorage',
|
||||
torch.int: 'IntStorage',
|
||||
torch.int16: 'ShortStorage',
|
||||
torch.int8: 'CharStorage',
|
||||
torch.uint8: 'ByteStorage',
|
||||
torch.bool: 'BoolStorage',
|
||||
torch.bfloat16: 'BFloat16Storage',
|
||||
torch.cdouble: 'ComplexDoubleStorage',
|
||||
torch.cfloat: 'ComplexFloatStorage',
|
||||
torch.qint8: 'QInt8Storage',
|
||||
torch.qint32: 'QInt32Storage',
|
||||
torch.quint8: 'QUInt8Storage',
|
||||
torch.quint4x2: 'QUInt4x2Storage',
|
||||
torch.quint2x4: 'QUInt2x4Storage',
|
||||
*/
|
||||
|
||||
// StorageClass defines interface for types to be used in Storage.
//
// A StorageClass stands in for a Python torch storage class name found in a
// pickle stream (e.g. "torch.FloatStorage" maps to FloatStorageClass): New
// builds an empty Storage of the matching element type, with the given
// element count and location tag.
type StorageClass interface {
	New(size int, location string) Storage
}
|
||||
|
||||
// Storage define Storage interface.
//
// A Storage holds the decoded element data of one torch storage object:
//
//   - SetFromFile reads a little-endian int64 element count from r, then
//     reads that many elements (see setFromFile).
//   - SetFromFileWithSize reads exactly `size` elements from r.
//   - DType reports the gotch element type the data maps to.
//   - GetData returns the decoded elements as a typed Go slice
//     (e.g. []float32, []int64).
//   - Device reports the device derived from the storage location
//     ("cuda" selects CUDA if available; anything else means CPU).
type Storage interface {
	SetFromFile(r io.Reader) error
	SetFromFileWithSize(r io.Reader, size int) error
	DType() gotch.DType
	GetData() interface{}
	Device() gotch.Device
}
|
||||
|
||||
// BaseStorage represents a base storage.
//
// It holds the fields common to every concrete storage implementation:
// Size is the element count, and Location is the original storage location
// tag ("cuda" maps to a CUDA device; any other value falls back to CPU —
// see the Device() methods).
type BaseStorage struct {
	Size int
	Location string
}
|
||||
|
||||
// HalfStorage:
|
||||
// ============
|
||||
|
||||
type HalfStorageClass struct{}
|
||||
|
||||
var _ StorageClass = &HalfStorageClass{}
|
||||
|
||||
func (s *HalfStorageClass) New(size int, location string) Storage {
|
||||
return &HalfStorage{
|
||||
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||
Data: nil,
|
||||
}
|
||||
}
|
||||
|
||||
type HalfStorage struct {
|
||||
BaseStorage
|
||||
Data []float32
|
||||
}
|
||||
|
||||
var _ Storage = &HalfStorage{}
|
||||
|
||||
func (s *HalfStorage) SetFromFile(r io.Reader) error {
|
||||
return setFromFile(s, r)
|
||||
}
|
||||
|
||||
func (s *HalfStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||
data := make([]float32, size)
|
||||
br := NewLimitedBufferReader(r, size, 2, 512)
|
||||
for i := 0; i < size; i++ {
|
||||
bytes, err := br.ReadNext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u16 := binary.LittleEndian.Uint16(bytes)
|
||||
data[i] = math.Float32frombits(FloatBits16to32(u16))
|
||||
}
|
||||
s.Data = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *HalfStorage) GetData() interface{} {
|
||||
return s.Data
|
||||
}
|
||||
|
||||
func (s *HalfStorage) DType() gotch.DType {
|
||||
return gotch.Float
|
||||
}
|
||||
|
||||
func (s *HalfStorage) Device() gotch.Device {
|
||||
switch s.Location {
|
||||
case "cuda":
|
||||
return gotch.CudaIfAvailable()
|
||||
default:
|
||||
return gotch.CPU
|
||||
}
|
||||
}
|
||||
|
||||
// FloatStorage:
|
||||
// =============
|
||||
|
||||
type FloatStorageClass struct{}
|
||||
|
||||
var _ StorageClass = &FloatStorageClass{}
|
||||
|
||||
func (s *FloatStorageClass) New(size int, location string) Storage {
|
||||
return &FloatStorage{
|
||||
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||
Data: nil,
|
||||
}
|
||||
}
|
||||
|
||||
type FloatStorage struct {
|
||||
BaseStorage
|
||||
Data []float32
|
||||
}
|
||||
|
||||
var _ Storage = &FloatStorage{}
|
||||
|
||||
func (s *FloatStorage) SetFromFile(r io.Reader) error {
|
||||
return setFromFile(s, r)
|
||||
}
|
||||
|
||||
func (s *FloatStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||
data := make([]float32, size)
|
||||
br := NewLimitedBufferReader(r, size, 4, 512)
|
||||
for i := 0; i < size; i++ {
|
||||
bytes, err := br.ReadNext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data[i] = math.Float32frombits(binary.LittleEndian.Uint32(bytes))
|
||||
}
|
||||
s.Data = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FloatStorage) GetData() interface{} {
|
||||
return s.Data
|
||||
}
|
||||
|
||||
func (s *FloatStorage) DType() gotch.DType {
|
||||
return gotch.Float
|
||||
}
|
||||
|
||||
func (s *FloatStorage) Device() gotch.Device {
|
||||
switch s.Location {
|
||||
case "cuda":
|
||||
return gotch.CudaIfAvailable()
|
||||
default:
|
||||
return gotch.CPU
|
||||
}
|
||||
}
|
||||
|
||||
// DoubleStorage:
|
||||
// ==============
|
||||
|
||||
type DoubleStorageClass struct{}
|
||||
|
||||
var _ StorageClass = &DoubleStorageClass{}
|
||||
|
||||
func (s *DoubleStorageClass) New(size int, location string) Storage {
|
||||
return &DoubleStorage{
|
||||
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||
Data: nil,
|
||||
}
|
||||
}
|
||||
|
||||
type DoubleStorage struct {
|
||||
BaseStorage
|
||||
Data []float64
|
||||
}
|
||||
|
||||
var _ Storage = &DoubleStorage{}
|
||||
|
||||
func (s *DoubleStorage) SetFromFile(r io.Reader) error {
|
||||
return setFromFile(s, r)
|
||||
}
|
||||
|
||||
func (s *DoubleStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||
data := make([]float64, size)
|
||||
br := NewLimitedBufferReader(r, size, 8, 512)
|
||||
for i := 0; i < size; i++ {
|
||||
bytes, err := br.ReadNext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data[i] = math.Float64frombits(binary.LittleEndian.Uint64(bytes))
|
||||
}
|
||||
s.Data = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DoubleStorage) GetData() interface{} {
|
||||
return s.Data
|
||||
}
|
||||
|
||||
func (s *DoubleStorage) DType() gotch.DType {
|
||||
return gotch.Double
|
||||
}
|
||||
|
||||
func (s *DoubleStorage) Device() gotch.Device {
|
||||
switch s.Location {
|
||||
case "cuda":
|
||||
return gotch.CudaIfAvailable()
|
||||
default:
|
||||
return gotch.CPU
|
||||
}
|
||||
}
|
||||
|
||||
// CharStorage:
|
||||
// ============
|
||||
|
||||
type CharStorageClass struct{}
|
||||
|
||||
var _ StorageClass = &CharStorageClass{}
|
||||
|
||||
func (s *CharStorageClass) New(size int, location string) Storage {
|
||||
return &CharStorage{
|
||||
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||
Data: nil,
|
||||
}
|
||||
}
|
||||
|
||||
type CharStorage struct {
|
||||
BaseStorage
|
||||
Data []int8
|
||||
}
|
||||
|
||||
var _ Storage = &CharStorage{}
|
||||
|
||||
func (s *CharStorage) SetFromFile(r io.Reader) error {
|
||||
return setFromFile(s, r)
|
||||
}
|
||||
|
||||
func (s *CharStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||
data := make([]int8, size)
|
||||
br := NewLimitedBufferReader(r, size, 1, 512)
|
||||
for i := 0; i < size; i++ {
|
||||
bytes, err := br.ReadNext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data[i] = int8(bytes[0])
|
||||
}
|
||||
s.Data = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *CharStorage) GetData() interface{} {
|
||||
return s.Data
|
||||
}
|
||||
|
||||
func (s *CharStorage) DType() gotch.DType {
|
||||
return gotch.Int8
|
||||
}
|
||||
|
||||
func (s *CharStorage) Device() gotch.Device {
|
||||
switch s.Location {
|
||||
case "cuda":
|
||||
return gotch.CudaIfAvailable()
|
||||
default:
|
||||
return gotch.CPU
|
||||
}
|
||||
}
|
||||
|
||||
// ShortStorage:
|
||||
// =============
|
||||
|
||||
type ShortStorageClass struct{}
|
||||
|
||||
var _ StorageClass = &ShortStorageClass{}
|
||||
|
||||
func (s *ShortStorageClass) New(size int, location string) Storage {
|
||||
return &ShortStorage{
|
||||
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||
Data: nil,
|
||||
}
|
||||
}
|
||||
|
||||
type ShortStorage struct {
|
||||
BaseStorage
|
||||
Data []int16
|
||||
}
|
||||
|
||||
var _ Storage = &ShortStorage{}
|
||||
|
||||
func (s *ShortStorage) SetFromFile(r io.Reader) error {
|
||||
return setFromFile(s, r)
|
||||
}
|
||||
|
||||
func (s *ShortStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||
data := make([]int16, size)
|
||||
br := NewLimitedBufferReader(r, size, 2, 512)
|
||||
for i := 0; i < size; i++ {
|
||||
bytes, err := br.ReadNext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data[i] = int16(binary.LittleEndian.Uint16(bytes))
|
||||
}
|
||||
s.Data = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ShortStorage) GetData() interface{} {
|
||||
return s.Data
|
||||
}
|
||||
|
||||
func (s *ShortStorage) DType() gotch.DType {
|
||||
return gotch.Int16
|
||||
}
|
||||
|
||||
func (s *ShortStorage) Device() gotch.Device {
|
||||
switch s.Location {
|
||||
case "cuda":
|
||||
return gotch.CudaIfAvailable()
|
||||
default:
|
||||
return gotch.CPU
|
||||
}
|
||||
}
|
||||
|
||||
// IntStorage:
|
||||
// ===========
|
||||
|
||||
type IntStorageClass struct{}
|
||||
|
||||
var _ StorageClass = &IntStorageClass{}
|
||||
|
||||
func (s *IntStorageClass) New(size int, location string) Storage {
|
||||
return &IntStorage{
|
||||
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||
Data: nil,
|
||||
}
|
||||
}
|
||||
|
||||
type IntStorage struct {
|
||||
BaseStorage
|
||||
Data []int32
|
||||
}
|
||||
|
||||
var _ Storage = &IntStorage{}
|
||||
|
||||
func (s *IntStorage) SetFromFile(r io.Reader) error {
|
||||
return setFromFile(s, r)
|
||||
}
|
||||
|
||||
func (s *IntStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||
data := make([]int32, size)
|
||||
br := NewLimitedBufferReader(r, size, 4, 512)
|
||||
for i := 0; i < size; i++ {
|
||||
bytes, err := br.ReadNext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data[i] = int32(binary.LittleEndian.Uint32(bytes))
|
||||
}
|
||||
s.Data = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *IntStorage) GetData() interface{} {
|
||||
return s.Data
|
||||
}
|
||||
|
||||
func (s *IntStorage) DType() gotch.DType {
|
||||
return gotch.Int
|
||||
}
|
||||
|
||||
func (s *IntStorage) Device() gotch.Device {
|
||||
switch s.Location {
|
||||
case "cuda":
|
||||
return gotch.CudaIfAvailable()
|
||||
default:
|
||||
return gotch.CPU
|
||||
}
|
||||
}
|
||||
|
||||
// LongStorage:
|
||||
// ============
|
||||
|
||||
type LongStorageClass struct{}
|
||||
|
||||
var _ StorageClass = &LongStorageClass{}
|
||||
|
||||
func (s *LongStorageClass) New(size int, location string) Storage {
|
||||
return &LongStorage{
|
||||
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||
Data: nil,
|
||||
}
|
||||
}
|
||||
|
||||
type LongStorage struct {
|
||||
BaseStorage
|
||||
Data []int64
|
||||
}
|
||||
|
||||
var _ Storage = &LongStorage{}
|
||||
|
||||
func (s *LongStorage) SetFromFile(r io.Reader) error {
|
||||
return setFromFile(s, r)
|
||||
}
|
||||
|
||||
func (s *LongStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||
data := make([]int64, size)
|
||||
br := NewLimitedBufferReader(r, size, 8, 512)
|
||||
for i := 0; i < size; i++ {
|
||||
bytes, err := br.ReadNext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data[i] = int64(binary.LittleEndian.Uint64(bytes))
|
||||
}
|
||||
s.Data = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *LongStorage) GetData() interface{} {
|
||||
return s.Data
|
||||
}
|
||||
|
||||
func (s *LongStorage) DType() gotch.DType {
|
||||
return gotch.Int64
|
||||
}
|
||||
|
||||
func (s *LongStorage) Device() gotch.Device {
|
||||
switch s.Location {
|
||||
case "cuda":
|
||||
return gotch.CudaIfAvailable()
|
||||
default:
|
||||
return gotch.CPU
|
||||
}
|
||||
}
|
||||
|
||||
// ByteStorage:
|
||||
// ============
|
||||
|
||||
type ByteStorageClass struct{}
|
||||
|
||||
var _ StorageClass = &ByteStorageClass{}
|
||||
|
||||
func (s *ByteStorageClass) New(size int, location string) Storage {
|
||||
return &ByteStorage{
|
||||
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||
Data: nil,
|
||||
}
|
||||
}
|
||||
|
||||
type ByteStorage struct {
|
||||
BaseStorage
|
||||
Data []uint8
|
||||
}
|
||||
|
||||
var _ Storage = &ByteStorage{}
|
||||
|
||||
func (s *ByteStorage) SetFromFile(r io.Reader) error {
|
||||
return setFromFile(s, r)
|
||||
}
|
||||
|
||||
func (s *ByteStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||
data := make([]uint8, size)
|
||||
br := NewLimitedBufferReader(r, size, 1, 512)
|
||||
for i := 0; i < size; i++ {
|
||||
bytes, err := br.ReadNext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data[i] = bytes[0]
|
||||
}
|
||||
s.Data = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ByteStorage) GetData() interface{} {
|
||||
return s.Data
|
||||
}
|
||||
|
||||
func (s *ByteStorage) DType() gotch.DType {
|
||||
return gotch.Uint8
|
||||
}
|
||||
|
||||
func (s *ByteStorage) Device() gotch.Device {
|
||||
switch s.Location {
|
||||
case "cuda":
|
||||
return gotch.CudaIfAvailable()
|
||||
default:
|
||||
return gotch.CPU
|
||||
}
|
||||
}
|
||||
|
||||
// BoolStorage:
|
||||
// ============
|
||||
|
||||
type BoolStorageClass struct{}
|
||||
|
||||
var _ StorageClass = &BoolStorageClass{}
|
||||
|
||||
func (s *BoolStorageClass) New(size int, location string) Storage {
|
||||
return &BoolStorage{
|
||||
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||
Data: nil,
|
||||
}
|
||||
}
|
||||
|
||||
type BoolStorage struct {
|
||||
BaseStorage
|
||||
Data []bool
|
||||
}
|
||||
|
||||
var _ Storage = &BoolStorage{}
|
||||
|
||||
func (s *BoolStorage) SetFromFile(r io.Reader) error {
|
||||
return setFromFile(s, r)
|
||||
}
|
||||
|
||||
func (s *BoolStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||
data := make([]bool, size)
|
||||
br := NewLimitedBufferReader(r, size, 1, 512)
|
||||
for i := 0; i < size; i++ {
|
||||
bytes, err := br.ReadNext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data[i] = bytes[0] == 1
|
||||
}
|
||||
s.Data = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BoolStorage) GetData() interface{} {
|
||||
return s.Data
|
||||
}
|
||||
|
||||
func (s *BoolStorage) DType() gotch.DType {
|
||||
return gotch.Float
|
||||
}
|
||||
|
||||
func (s *BoolStorage) Device() gotch.Device {
|
||||
switch s.Location {
|
||||
case "cuda":
|
||||
return gotch.CudaIfAvailable()
|
||||
default:
|
||||
return gotch.CPU
|
||||
}
|
||||
}
|
||||
|
||||
func setFromFile(s Storage, r io.Reader) error {
|
||||
sizeBuf := make([]byte, 8)
|
||||
_, err := r.Read(sizeBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
size := int(binary.LittleEndian.Uint64(sizeBuf))
|
||||
return s.SetFromFileWithSize(r, size)
|
||||
}
|
||||
|
||||
// StorageTensor:
//===============

// StorageTensor bundles a decoded Storage with the view metadata needed to
// rebuild the original tensor on top of it (as produced by the
// torch._utils._rebuild_tensor / _rebuild_tensor_v2 callables below).
type StorageTensor struct {
	Source Storage // underlying element storage
	StorageOffset int64 // pickled "storage_offset" argument
	Size []int64 // tensor shape
	Stride []int64 // pickled strides
	RequiresGrad bool // whether the rebuilt tensor tracks gradients (v2 only; false for v1)
}
|
||||
|
||||
// Rebuild Tensor:
|
||||
// ===============
|
||||
// ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L132
|
||||
// ref. def _rebuild_tensor(storage, storage_offset, size, stride):
|
||||
|
||||
type RebuildTensor struct{}
|
||||
|
||||
var _ Callable = &RebuildTensor{}
|
||||
|
||||
func (r *RebuildTensor) Call(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 4 {
|
||||
return nil, fmt.Errorf("RebuildTensor.Call() failed. Expected 4 args, got %d: %#v", len(args), args)
|
||||
}
|
||||
|
||||
storage, storageOk := args[0].(Storage)
|
||||
storageOffset, storageOffsetOk := args[1].(int)
|
||||
size, sizeOk := args[2].(*Tuple)
|
||||
stride, strideOk := args[3].(*Tuple)
|
||||
if !storageOk || !storageOffsetOk || !sizeOk || !strideOk {
|
||||
return nil, fmt.Errorf("RebuildTensor.Call() unexpected args: %#v", args)
|
||||
}
|
||||
|
||||
tensor := &StorageTensor{
|
||||
Source: storage,
|
||||
StorageOffset: int64(storageOffset),
|
||||
RequiresGrad: false,
|
||||
}
|
||||
var err error
|
||||
tensor.Size, err = tupleToInt64Slice(size)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tensor.Stride, err = tupleToInt64Slice(stride)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tensor, nil
|
||||
}
|
||||
|
||||
// RebuildTensorV2 represents a struct to rebuild tensor back from pickle object.
|
||||
type RebuildTensorV2 struct{}
|
||||
|
||||
var _ Callable = &RebuildTensorV2{}
|
||||
|
||||
func (r *RebuildTensorV2) Call(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 6 {
|
||||
return nil, fmt.Errorf("RebuildTensorV2 unexpected args: %#v", args)
|
||||
}
|
||||
|
||||
storage, storageOk := args[0].(Storage)
|
||||
storageOffset, storageOffsetOk := args[1].(int)
|
||||
size, sizeOk := args[2].(*Tuple)
|
||||
stride, strideOk := args[3].(*Tuple)
|
||||
requiresGrad, requiresGradOk := args[4].(bool)
|
||||
// arg[5] "backward hooks" is unused
|
||||
if !storageOk || !storageOffsetOk || !sizeOk || !strideOk ||
|
||||
!requiresGradOk {
|
||||
return nil, fmt.Errorf("RebuildTensorV2 unexpected args: %#v", args)
|
||||
}
|
||||
|
||||
tensor := &StorageTensor{
|
||||
Source: storage,
|
||||
StorageOffset: int64(storageOffset),
|
||||
RequiresGrad: requiresGrad,
|
||||
}
|
||||
var err error
|
||||
tensor.Size, err = tupleToInt64Slice(size)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tensor.Stride, err = tupleToInt64Slice(stride)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tensor, nil
|
||||
}
|
||||
|
||||
func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
|
||||
length := tuple.Len()
|
||||
slice := make([]int64, length)
|
||||
for i := 0; i < length; i++ {
|
||||
value, ok := tuple.Get(i).(int)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tuple of ints expected: %#v", tuple)
|
||||
}
|
||||
slice[i] = int64(value)
|
||||
}
|
||||
return slice, nil
|
||||
}
|
519
pickle/type.go
Normal file
519
pickle/type.go
Normal file
|
@ -0,0 +1,519 @@
|
|||
package pickle
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// This file contains custom types and interfaces for casting Python data types from Pickle.
|
||||
//
|
||||
// Custom Types:
|
||||
// =============
|
||||
// 1. ByteArray
|
||||
// 2. Dict
|
||||
// 3. Tuple
|
||||
// 4. OrderedDict
|
||||
// 5. List
|
||||
// 6. Set
|
||||
// 7. FrozenSet
|
||||
// 8. Object
|
||||
// 9. Reconstructor
|
||||
// 10. GenericClass
|
||||
|
||||
// Interfaces:
|
||||
// ===========
|
||||
|
||||
// Callable is implemented by any value that can be directly called to get a
|
||||
// new value.
|
||||
//
|
||||
// It is usually implemented by Python-like functions (returning a value
|
||||
// given some arguments), or classes (typically returning an instance given
|
||||
// some constructor arguments).
|
||||
type Callable interface {
|
||||
// Call mimics a direct invocation on a Python value, such as a function
|
||||
// or class (constructor).
|
||||
Call(args ...interface{}) (interface{}, error)
|
||||
}
|
||||
|
||||
// PyNewable is implemented by any value that has a Python-like
|
||||
// "__new__" method.
|
||||
//
|
||||
// It is usually implemented by values representing Python classes.
|
||||
type PyNewable interface {
|
||||
// PyNew mimics Python invocation of the "__new__" method, usually
|
||||
// provided by classes.
|
||||
//
|
||||
// See: https://docs.python.org/3/reference/datamodel.html#object.__new__
|
||||
PyNew(args ...interface{}) (interface{}, error)
|
||||
}
|
||||
|
||||
// PyStateSettable is implemented by any value that has a Python-like
|
||||
// "__setstate__" method.
|
||||
type PyStateSettable interface {
|
||||
// PySetState mimics Python invocation of the "__setstate__" method.
|
||||
//
|
||||
// See: https://docs.python.org/3/library/pickle.html#object.__setstate__
|
||||
PySetState(state interface{}) error
|
||||
}
|
||||
|
||||
// PyDictSettable is implemented by any value that can store dictionary-like
|
||||
// key/value pairs. It reflects Python behavior of setting a key/value pair on
|
||||
// an object's "__dict__" attribute.
|
||||
type PyDictSettable interface {
|
||||
// PyDictSet mimics the setting of a key/value pair on an object's
|
||||
//"__dict__" attribute.
|
||||
//
|
||||
// See: https://docs.python.org/3/library/stdtypes.html#object.__dict__
|
||||
PyDictSet(key, value interface{}) error
|
||||
}
|
||||
|
||||
// PyAttrSettable is implemented by any value on which an existing or new
|
||||
// Python-like attribute can be set. In Python this is done with "setattr"
|
||||
// builtin function.
|
||||
type PyAttrSettable interface {
|
||||
// PySetAttr mimics the setting of an arbitrary value to an object's
|
||||
// attribute.
|
||||
//
|
||||
// In Python this is done with "setattr" function, to which object,
|
||||
// attribute name, and value are passed. For an easy and clear
|
||||
// implementation, here instead we require this method to be implemented
|
||||
// on the "object" itself.
|
||||
//
|
||||
// See: https://docs.python.org/3/library/functions.html#setattr
|
||||
PySetAttr(key string, value interface{}) error
|
||||
}
|
||||
|
||||
// ByteArray:
//===========

// ByteArray simulates Python bytearray.
type ByteArray []byte

// NewByteArray makes and returns a new empty ByteArray.
func NewByteArray() *ByteArray {
	arr := ByteArray{}
	return &arr
}

// NewByteArrayFromSlice wraps the given slice as a ByteArray; the slice is
// not copied.
func NewByteArrayFromSlice(slice []byte) *ByteArray {
	arr := ByteArray(slice)
	return &arr
}

// Get returns the byte at index i; it panics if i is out of range.
func (a *ByteArray) Get(i int) byte {
	raw := *a
	return raw[i]
}

// Len returns the number of bytes held by the ByteArray.
func (a *ByteArray) Len() int {
	return len(*a)
}
|
||||
|
||||
// Dict:
//======

// DictSetter is implemented by any value that exhibits a dict-like behaviour,
// allowing arbitrary key/value pairs to be set.
type DictSetter interface {
	Set(key, value interface{})
}

// Dict represents a Python "dict" (builtin type).
//
// It is implemented as a slice, instead of a map, because in Go not
// all types can be map's keys (e.g. slices).
type Dict []*DictEntry

// DictEntry is a single key/value pair held by a Dict.
type DictEntry struct {
	Key interface{}
	Value interface{}
}

// NewDict makes and returns a new empty Dict.
func NewDict() *Dict {
	d := Dict{}
	return &d
}

// Set sets into the Dict the given key/value pair.
//
// Note: entries are simply appended; Get returns the first entry whose key
// matches.
func (d *Dict) Set(key, value interface{}) {
	entry := &DictEntry{Key: key, Value: value}
	*d = append(*d, entry)
}

// Get returns the value associated with the given key (if any), and whether
// the key is present or not. Keys are compared with reflect.DeepEqual, so
// non-hashable values such as slices also work as keys.
func (d *Dict) Get(key interface{}) (interface{}, bool) {
	for _, entry := range *d {
		if reflect.DeepEqual(entry.Key, key) {
			return entry.Value, true
		}
	}
	return nil, false
}

// MustGet returns the value associated with the given key if it exists,
// otherwise it panics.
func (d *Dict) MustGet(key interface{}) interface{} {
	value, ok := d.Get(key)
	if !ok {
		panic(fmt.Errorf("key not found in Dict: %#v", key))
	}
	return value
}

// Len returns the length of the Dict, that is, the amount of key/value pairs
// contained by the Dict.
func (d *Dict) Len() int {
	return len(*d)
}

var _ DictSetter = &Dict{}
|
||||
|
||||
// Tuple:
// ======

// Tuple represents a Python "tuple" (builtin type) as a generic slice.
type Tuple []interface{}

// NewTupleFromSlice wraps the given slice as a Tuple; the slice is not
// copied.
func NewTupleFromSlice(slice []interface{}) *Tuple {
	tup := Tuple(slice)
	return &tup
}

// Get returns the element at index i; it panics if i is out of range.
func (t *Tuple) Get(i int) interface{} {
	items := *t
	return items[i]
}

// Len returns the number of elements in the Tuple.
func (t *Tuple) Len() int {
	return len(*t)
}
|
||||
|
||||
// OrderedDict:
|
||||
// ============
|
||||
|
||||
// OrderedDictClass represent Python "collections.OrderedDict" class.
|
||||
//
|
||||
// This class allows the indirect creation of OrderedDict objects.
|
||||
type OrderedDictClass struct{}
|
||||
|
||||
var _ Callable = &OrderedDictClass{}
|
||||
|
||||
// Call returns a new empty OrderedDict. It is equivalent to Python
|
||||
// constructor "collections.OrderedDict()".
|
||||
//
|
||||
// No arguments are supported.
|
||||
func (*OrderedDictClass) Call(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 0 {
|
||||
return nil, fmt.Errorf(
|
||||
"OrderedDictClass.Call args not supported: %#v", args)
|
||||
}
|
||||
return NewOrderedDict(), nil
|
||||
}
|
||||
|
||||
// OrderedDict is a minimal and trivial implementation of an ordered map,
|
||||
// which represent a Python "collections.OrderedDict" object.
|
||||
//
|
||||
// It is composed by a simple unordered Map, and a List to keep the order of
|
||||
// the entries. The former is useful for direct key lookups, the latter for
|
||||
// iteration.
|
||||
type OrderedDict struct {
|
||||
// Map associates a key of any type (interface{}) to OrderedDictEntry
|
||||
// pointer values. These values are shared with List.
|
||||
Map map[interface{}]*OrderedDictEntry
|
||||
// List is an ordered list of OrderedDictEntry pointers, which are
|
||||
// also shared with Map.
|
||||
List *list.List
|
||||
// PyDict represents Python "object.__dict__" dictionary of attributes.
|
||||
PyDict map[string]interface{}
|
||||
}
|
||||
|
||||
var _ DictSetter = &OrderedDict{}
|
||||
var _ PyDictSettable = &OrderedDict{}
|
||||
|
||||
// OrderedDictEntry is a single key/value pair stored in an OrderedDict.
|
||||
//
|
||||
// A pointer to an OrderedDictEntry is always shared between OrderedDict's Map
|
||||
// and List.
|
||||
type OrderedDictEntry struct {
|
||||
// Key of a single OrderedDict's entry.
|
||||
Key interface{}
|
||||
// Value of a single OrderedDict's entry.
|
||||
Value interface{}
|
||||
// ListElement is a pointer to the OrderedDict's List Element which
|
||||
// contains this very OrderedDictEntry.
|
||||
ListElement *list.Element
|
||||
}
|
||||
|
||||
// NewOrderedDict makes and returns a new empty OrderedDict.
|
||||
func NewOrderedDict() *OrderedDict {
|
||||
return &OrderedDict{
|
||||
Map: make(map[interface{}]*OrderedDictEntry),
|
||||
List: list.New(),
|
||||
PyDict: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets into the OrderedDict the given key/value pair. If the key does not
|
||||
// exist yet, the new pair is positioned at the end (back) of the OrderedDict.
|
||||
// If the key already exists, the existing associated value is replaced with the
|
||||
// new one, and the original position is maintained.
|
||||
func (o *OrderedDict) Set(k, v interface{}) {
|
||||
if entry, ok := o.Map[k]; ok {
|
||||
entry.Value = v
|
||||
return
|
||||
}
|
||||
|
||||
entry := &OrderedDictEntry{
|
||||
Key: k,
|
||||
Value: v,
|
||||
}
|
||||
entry.ListElement = o.List.PushBack(entry)
|
||||
o.Map[k] = entry
|
||||
}
|
||||
|
||||
// Get returns the value associated with the given key (if any), and whether
|
||||
// the key is present or not.
|
||||
func (o *OrderedDict) Get(k interface{}) (interface{}, bool) {
|
||||
entry, ok := o.Map[k]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return entry.Value, true
|
||||
}
|
||||
|
||||
// MustGet returns the value associated with the given key, if if it exists,
|
||||
// otherwise it panics.
|
||||
func (o *OrderedDict) MustGet(key interface{}) interface{} {
|
||||
value, ok := o.Get(key)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("key not found in OrderedDict: %#v", key))
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// Len returns the length of the OrderedDict, that is, the amount of key/value
|
||||
// pairs contained by the OrderedDict.
|
||||
func (o *OrderedDict) Len() int {
|
||||
return len(o.Map)
|
||||
}
|
||||
|
||||
// PyDictSet mimics the setting of a key/value pair on Python "__dict__"
|
||||
// attribute of the OrderedDict.
|
||||
func (o *OrderedDict) PyDictSet(key, value interface{}) error {
|
||||
sKey, keyOk := key.(string)
|
||||
if !keyOk {
|
||||
return fmt.Errorf(
|
||||
"OrderedDict.PyDictSet() requires string key: %#v", key)
|
||||
}
|
||||
o.PyDict[sKey] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
// List:
// =====

// ListAppender is implemented by any value that exhibits a list-like
// behaviour, allowing arbitrary values to be appended.
type ListAppender interface {
	Append(v interface{})
}

// List represents a Python "list" (builtin type).
type List []interface{}

var _ ListAppender = &List{}

// NewList makes and returns a new empty List.
func NewList() *List {
	l := List{}
	return &l
}

// NewListFromSlice makes and returns a new List initialized with the elements
// of the given slice.
//
// The new List is a simple type cast of the input slice; the slice is _not_
// copied.
func NewListFromSlice(slice []interface{}) *List {
	l := List(slice)
	return &l
}

// Append appends one element to the end of the List.
func (l *List) Append(v interface{}) {
	updated := append(*l, v)
	*l = updated
}

// Get returns the element of the List at the given index.
//
// It panics if the index is out of range.
func (l *List) Get(i int) interface{} {
	items := *l
	return items[i]
}

// Len returns the length of the List.
func (l *List) Len() int {
	return len(*l)
}
|
||||
|
||||
// Set:
|
||||
// ====
|
||||
|
||||
// SetAdder is implemented by any value that exhibits a set-like behaviour,
// allowing arbitrary values to be added.
type SetAdder interface {
	Add(v interface{})
}

// Set represents a Python "set" (builtin type).
//
// It is implemented in Go as a map with empty struct values; the actual set
// of generic "interface{}" items is thus represented by all the keys.
type Set map[interface{}]setEmptyStruct

var _ SetAdder = &Set{}

type setEmptyStruct struct{}

// NewSet makes and returns a new empty Set.
func NewSet() *Set {
	set := Set{}
	return &set
}

// NewSetFromSlice makes and returns a new Set initialized with the elements
// of the given slice. Duplicate elements collapse into a single entry.
func NewSetFromSlice(slice []interface{}) *Set {
	set := make(Set, len(slice))
	for _, v := range slice {
		set.Add(v)
	}
	return &set
}

// Len returns the length of the Set.
func (s *Set) Len() int {
	return len(*s)
}

// Add adds one element to the Set.
func (s *Set) Add(v interface{}) {
	(*s)[v] = setEmptyStruct{}
}

// Has returns whether the given value is present in the Set (true)
// or not (false).
func (s *Set) Has(v interface{}) bool {
	_, present := (*s)[v]
	return present
}
|
||||
|
||||
// FrozenSet:
|
||||
//===========
|
||||
|
||||
// FrozenSet represents a Python "frozenset" (builtin type).
//
// It is implemented in Go as a map with empty struct values; the actual set
// of generic "interface{}" items is thus represented by all the keys.
type FrozenSet map[interface{}]frozenSetEmptyStruct

type frozenSetEmptyStruct struct{}

// NewFrozenSetFromSlice makes and returns a new FrozenSet initialized
// with the elements of the given slice. Duplicate elements collapse into
// a single entry.
func NewFrozenSetFromSlice(slice []interface{}) *FrozenSet {
	fs := make(FrozenSet, len(slice))
	for _, v := range slice {
		fs[v] = frozenSetEmptyStruct{}
	}
	return &fs
}

// Len returns the length of the FrozenSet.
func (f *FrozenSet) Len() int {
	return len(*f)
}

// Has returns whether the given value is present in the FrozenSet (true)
// or not (false).
func (f *FrozenSet) Has(v interface{}) bool {
	_, present := (*f)[v]
	return present
}
|
||||
|
||||
// Object:
|
||||
//========
|
||||
|
||||
type ObjectClass struct{}
|
||||
|
||||
var _ PyNewable = &ObjectClass{}
|
||||
|
||||
func (o *ObjectClass) PyNew(args ...interface{}) (interface{}, error) {
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("ObjectClass.PyNew called with no arguments")
|
||||
}
|
||||
switch class := args[0].(type) {
|
||||
case PyNewable:
|
||||
return class.PyNew()
|
||||
default:
|
||||
return nil, fmt.Errorf(
|
||||
"ObjectClass.PyNew unprocessable args: %#v", args)
|
||||
}
|
||||
}
|
||||
|
||||
// Reconstructor:
|
||||
//===============
|
||||
|
||||
type Reconstructor struct{}
|
||||
|
||||
var _ Callable = &Reconstructor{}
|
||||
|
||||
func (r *Reconstructor) Call(args ...interface{}) (interface{}, error) {
|
||||
if len(args) < 2 {
|
||||
return nil, fmt.Errorf("Reconstructor: invalid arguments: %#v", args)
|
||||
}
|
||||
class := args[0]
|
||||
switch base := args[1].(type) {
|
||||
case PyNewable:
|
||||
return base.PyNew(class)
|
||||
default:
|
||||
return nil, fmt.Errorf(
|
||||
"Reconstructor: unprocessable arguments: %#v", args)
|
||||
}
|
||||
}
|
||||
|
||||
// GenericClass:
|
||||
//==============
|
||||
|
||||
type GenericClass struct {
|
||||
Module string
|
||||
Name string
|
||||
}
|
||||
|
||||
var _ PyNewable = &GenericClass{}
|
||||
|
||||
type GenericObject struct {
|
||||
Class *GenericClass
|
||||
ConstructorArgs []interface{}
|
||||
}
|
||||
|
||||
func NewGenericClass(module, name string) *GenericClass {
|
||||
return &GenericClass{Module: module, Name: name}
|
||||
}
|
||||
|
||||
func (g *GenericClass) PyNew(args ...interface{}) (interface{}, error) {
|
||||
return &GenericObject{
|
||||
Class: g,
|
||||
ConstructorArgs: args,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getThnnFunctionBackend is for historical pickle deserialization, it is not used otherwise
type getThnnFunctionBackend struct{}

var _ Callable = &getThnnFunctionBackend{}

// Call ignores its arguments and always returns (nil, nil); the type only
// needs to satisfy Callable so that legacy pickles referencing the THNN
// backend can still be loaded.
func (getThnnFunctionBackend) Call(_ ...interface{}) (interface{}, error) {
	return nil, nil
}
|
112
pickle/util.go
Normal file
112
pickle/util.go
Normal file
|
@ -0,0 +1,112 @@
|
|||
package pickle
|
||||
|
||||
import "io"
|
||||
|
||||
// FloatBits16to32 converts the bits representation of a half-precision
// float (IEEE 754 binary16) to the bits of a single-precision float
// (IEEE 754 binary32), using the table-based method from
// http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
func FloatBits16to32(u16 uint16) uint32 {
	return mantissaTable[offsetTable[u16>>10]+(uint32(u16)&0x3ff)] + exponentTable[u16>>10]
}

// Lookup tables for the conversion: mantissaTable is indexed by
// (offset + 10-bit mantissa); exponentTable and offsetTable are indexed
// by the 6-bit sign+exponent field (u16 >> 10).
var mantissaTable [2048]uint32
var exponentTable [64]uint32
var offsetTable [64]uint32

func init() {
	initMantissaTable()
	initExponentTable()
	initOffsetTable()
}

func initMantissaTable() {
	mantissaTable[0] = 0
	// Entries 1..1023: normalized forms of the subnormal half-float mantissas.
	for i := uint32(1); i < 1024; i++ {
		mantissaTable[i] = convertMantissa(i)
	}
	// Entries 1024..2047: normalized half floats — re-biased exponent base
	// plus the zero-padded mantissa.
	for i := uint32(1024); i < 2048; i++ {
		mantissaTable[i] = 0x38000000 + ((i - 1024) << 13)
	}
}

func initExponentTable() {
	exponentTable[0] = 0           // +0 and positive subnormals
	exponentTable[31] = 0x47800000 // +Inf / NaN
	exponentTable[32] = 0x80000000 // -0 and negative subnormals
	exponentTable[63] = 0xC7800000 // -Inf / NaN
	for i := uint32(1); i < 31; i++ {
		exponentTable[i] = i << 23
	}
	for i := uint32(33); i < 63; i++ {
		exponentTable[i] = 0x80000000 + (i-32)<<23
	}
}

func initOffsetTable() {
	// Exponent fields 0 and 32 (±zero and subnormals) index the first half
	// of mantissaTable; every other field indexes the second half (1024).
	//
	// BUG FIX: the previous version's second loop started at 32, writing
	// 1024 over offsetTable[32] and breaking conversion of -0 and negative
	// subnormals, and its first loop stopped at 31, leaving
	// offsetTable[31] (+Inf/NaN) as 0 instead of 1024.
	offsetTable[0] = 0
	offsetTable[32] = 0
	for i := uint32(1); i < 32; i++ {
		offsetTable[i] = 1024
	}
	for i := uint32(33); i < 64; i++ {
		offsetTable[i] = 1024
	}
}

// convertMantissa normalizes the mantissa of a subnormal half float
// (i in 1..1023) and returns the corresponding single-precision bit
// pattern.
func convertMantissa(i uint32) uint32 {
	if i == 0 {
		// Guard: a zero mantissa would never normalize; the conversion
		// of ±0 is handled by mantissaTable[0] directly.
		return 0
	}
	var m uint32 = i << 13 // zero pad mantissa bits
	var e uint32 = 0       // zero exponent
	// Shift left until the implicit leading 1 reaches bit 23.
	// BUG FIX: the condition must be "== 0" (loop while NOT normalized);
	// the previous "!= 0" meant the loop never executed, so every
	// subnormal half float was converted incorrectly.
	for m&0x00800000 == 0 {
		e -= 0x00800000 // decrement exponent (1 << 23)
		m <<= 1         // shift mantissa
	}
	m &= ^uint32(0x00800000) // clear leading 1 bit
	e += 0x38800000          // adjust bias ((127-14)<<23)
	return m | e             // return combined number
}
|
||||
|
||||
type LimitedBufferReader struct {
|
||||
r io.Reader
|
||||
scalarSize int
|
||||
remainingBytes int
|
||||
buf []byte
|
||||
bufIndex int
|
||||
}
|
||||
|
||||
func NewLimitedBufferReader(
|
||||
r io.Reader,
|
||||
dataSize, scalarSize, bufferSize int,
|
||||
) *LimitedBufferReader {
|
||||
size := bufferSize * scalarSize
|
||||
return &LimitedBufferReader{
|
||||
r: r,
|
||||
scalarSize: scalarSize,
|
||||
remainingBytes: scalarSize * dataSize,
|
||||
buf: make([]byte, size),
|
||||
bufIndex: size,
|
||||
}
|
||||
}
|
||||
|
||||
func (br *LimitedBufferReader) HasNext() bool {
|
||||
return br.remainingBytes != 0
|
||||
}
|
||||
|
||||
func (br *LimitedBufferReader) ReadNext() ([]byte, error) {
|
||||
if br.remainingBytes == 0 {
|
||||
return nil, io.EOF
|
||||
}
|
||||
if br.bufIndex == len(br.buf) {
|
||||
br.bufIndex = 0
|
||||
if br.remainingBytes < len(br.buf) {
|
||||
br.buf = br.buf[0:br.remainingBytes]
|
||||
}
|
||||
_, err := br.r.Read(br.buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
result := br.buf[br.bufIndex : br.bufIndex+br.scalarSize]
|
||||
br.bufIndex += br.scalarSize
|
||||
br.remainingBytes -= br.scalarSize
|
||||
return result, nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user