added subpackage 'pickle' and file-util
This commit is contained in:
parent
47c6c60561
commit
dc4ab3047c
|
@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
- Added subpackage `pickle`. We can now load a Python PyTorch pretrained model directly, without any Python conversion script.
|
||||||
|
- Added `gotch.CachedPath()` and `gotch.ModelUrls`
|
||||||
- Remove Travis CI for now.
|
- Remove Travis CI for now.
|
||||||
- fixed `tensor.OfSlice()` throw error due to "Unsupported Go type" (e.g. []float32)
|
- fixed `tensor.OfSlice()` throw error due to "Unsupported Go type" (e.g. []float32)
|
||||||
- added `nn.Path.Paths()` method
|
- added `nn.Path.Paths()` method
|
||||||
|
|
36
example/pickle/main.go
Normal file
36
example/pickle/main.go
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
"github.com/sugarme/gotch/nn"
|
||||||
|
"github.com/sugarme/gotch/pickle"
|
||||||
|
"github.com/sugarme/gotch/vision"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
device := gotch.CPU
|
||||||
|
vs := nn.NewVarStore(device)
|
||||||
|
net := vision.VGG16(vs.Root(), 1000)
|
||||||
|
|
||||||
|
modelName := "vgg16"
|
||||||
|
modelUrl, ok := gotch.ModelUrls[modelName]
|
||||||
|
if !ok {
|
||||||
|
log.Fatal("model name %q not found.", modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
modelFile, err := gotch.CachedPath(modelUrl)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pickle.LoadAll(vs, modelFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("%v\n", net)
|
||||||
|
vs.Summary()
|
||||||
|
}
|
289
file-util.go
Normal file
289
file-util.go
Normal file
|
@ -0,0 +1,289 @@
|
||||||
|
package gotch
|
||||||
|
|
||||||
|
import (
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"path"
	"strconv"
	"strings"
)
|
||||||
|
|
||||||
|
// This file provides functions to work with local dataset cache, ...
|
||||||
|
|
||||||
|
// ModelUrls maps a model name to the URL of its PyTorch pretrained weights.
//
// These URLs are taken from the individual model files in the pytorch/vision
// repository: https://github.com/pytorch/vision/tree/main/torchvision/models
// An empty string means no pretrained weights are published for that variant.
var ModelUrls map[string]string = map[string]string{
	"alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",

	"convnext_tiny":  "https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
	"convnext_small": "https://download.pytorch.org/models/convnext_small-0c510722.pth",
	"convnext_base":  "https://download.pytorch.org/models/convnext_base-6075fbad.pth",
	"convnext_large": "https://download.pytorch.org/models/convnext_large-ea097f82.pth",

	"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
	"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
	"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
	"densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",

	// Weights ported from https://github.com/rwightman/pytorch-image-models/
	"efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
	"efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
	"efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
	"efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
	"efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
	// Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
	"efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
	"efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
	"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",

	// GoogLeNet ported from TensorFlow
	"googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",

	// Inception v3 ported from TensorFlow
	"inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",

	"mnasnet0_5":  "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
	"mnasnet0_75": "",
	"mnasnet1_0":  "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
	"mnasnet1_3":  "",

	"mobilenet_v2":       "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
	"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
	"mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",

	"regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
	"regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
	"regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
	"regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
	"regnet_y_8gf":   "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
	"regnet_y_16gf":  "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
	"regnet_y_32gf":  "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
	"regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
	"regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
	"regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
	"regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
	"regnet_x_8gf":   "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
	"regnet_x_16gf":  "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
	"regnet_x_32gf":  "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",

	"resnet18":         "https://download.pytorch.org/models/resnet18-f37072fd.pth",
	"resnet34":         "https://download.pytorch.org/models/resnet34-b627a593.pth",
	"resnet50":         "https://download.pytorch.org/models/resnet50-0676ba61.pth",
	"resnet101":        "https://download.pytorch.org/models/resnet101-63fe2227.pth",
	"resnet152":        "https://download.pytorch.org/models/resnet152-394f9c45.pth",
	"resnext50_32x4d":  "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
	"resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
	"wide_resnet50_2":  "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
	"wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",

	"shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
	"shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
	"shufflenetv2_x1.5": "",
	"shufflenetv2_x2.0": "",

	"squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
	"squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",

	"vgg11":    "https://download.pytorch.org/models/vgg11-8a719046.pth",
	"vgg13":    "https://download.pytorch.org/models/vgg13-19584684.pth",
	"vgg16":    "https://download.pytorch.org/models/vgg16-397923af.pth",
	"vgg19":    "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
	"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
	"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
	"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
	"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",

	"vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth",
	"vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
	"vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
	"vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth",
}
|
||||||
|
|
||||||
|
// CachedPath resolves and caches data based on input string, then returns fullpath to the cached data.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - `filenameOrUrl`: full path to filename or url
|
||||||
|
//
|
||||||
|
// CachedPath does several things consequently:
|
||||||
|
// 1. Resolves input string to a fullpath cached filename candidate.
|
||||||
|
// 2. Check it at `CachePath`, if exists, then return the candidate. If not
|
||||||
|
// 3. Retrieves and Caches data to `CachePath` and returns path to cached data
|
||||||
|
func CachedPath(filenameOrUrl string) (resolvedPath string, err error) {
|
||||||
|
filename := path.Base(filenameOrUrl)
|
||||||
|
// Resolves to "candidate" filename at `CacheDir`
|
||||||
|
cachedFileCandidate := fmt.Sprintf("%s/%s", CacheDir, filename)
|
||||||
|
|
||||||
|
// 1. Cached candidate file exists
|
||||||
|
if _, err := os.Stat(cachedFileCandidate); err == nil {
|
||||||
|
return cachedFileCandidate, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. If valid fullpath to local file, caches it and return cached filename
|
||||||
|
if _, err := os.Stat(filenameOrUrl); err == nil {
|
||||||
|
err := copyFile(filenameOrUrl, cachedFileCandidate)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return cachedFileCandidate, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Cached candidate file NOT exist. Try to download it and save to `CacheDir`
|
||||||
|
if isValidURL(filenameOrUrl) {
|
||||||
|
if _, err := http.Get(filenameOrUrl); err == nil {
|
||||||
|
err := downloadFile(filenameOrUrl, cachedFileCandidate)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return cachedFileCandidate, nil
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Error: %v\n", err)
|
||||||
|
err = fmt.Errorf("Unable to parse %q as a URL or as a local path.\n", filenameOrUrl)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not resolves
|
||||||
|
err = fmt.Errorf("Unable to parse %q as a URL or as a local path.\n", filenameOrUrl)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// isValidURL reports whether s is an absolute http(s) URL that downloadFile
// could plausibly fetch. Previously this was a stub that always returned
// true, which sent arbitrary local paths to the download branch.
func isValidURL(s string) bool {
	u, err := url.ParseRequestURI(s)
	if err != nil {
		return false
	}
	// Require an http(s) scheme and a host; bare absolute paths like
	// "/tmp/x.pth" parse successfully but have neither.
	return (u.Scheme == "http" || u.Scheme == "https") && u.Host != ""
}
|
||||||
|
|
||||||
|
// downloadFile downloads file from URL and stores it in local filepath.
|
||||||
|
// It writes to the destination file as it downloads it, without loading
|
||||||
|
// the entire file into memory. An `io.TeeReader` is passed into Copy()
|
||||||
|
// to report progress on the download.
|
||||||
|
func downloadFile(url string, filepath string) error {
|
||||||
|
// Create path if not existing
|
||||||
|
dir := path.Dir(filepath)
|
||||||
|
filename := path.Base(filepath)
|
||||||
|
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||||
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the file with .tmp extension, so that we won't overwrite a
|
||||||
|
// file until it's downloaded fully
|
||||||
|
out, err := os.Create(filepath + ".tmp")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer out.Close()
|
||||||
|
|
||||||
|
// Get the data
|
||||||
|
resp, err := http.Get(url)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Check server response
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
err := fmt.Errorf("bad status: %s(%v)", resp.Status, resp.StatusCode)
|
||||||
|
if resp.StatusCode == 404 {
|
||||||
|
err = fmt.Errorf("download file not found: %q for downloading", url)
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("download file failed: %q", url)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// the total file size to download
|
||||||
|
size, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
|
||||||
|
downloadSize := uint64(size)
|
||||||
|
|
||||||
|
// Create our bytes counter and pass it to be used alongside our writer
|
||||||
|
counter := &writeCounter{FileSize: downloadSize}
|
||||||
|
_, err = io.Copy(out, io.TeeReader(resp.Body, counter))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("\r%s... %s/%s completed", filename, byteCountIEC(counter.Total), byteCountIEC(counter.FileSize))
|
||||||
|
// The progress use the same line so print a new line once it's finished downloading
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
// Rename the tmp file back to the original file
|
||||||
|
err = os.Rename(filepath+".tmp", filepath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeCounter counts the number of bytes written to it. By implementing the
// Write method it satisfies io.Writer and can be passed into io.TeeReader();
// every write prints the current progress of the download.
type writeCounter struct {
	Total    uint64 // bytes written so far
	FileSize uint64 // expected total size in bytes (0 when unknown)
}

// Write implements io.Writer. It never fails: it accumulates the byte count
// and reports progress.
func (wc *writeCounter) Write(p []byte) (int, error) {
	n := len(p)
	wc.Total += uint64(n)
	wc.printProgress()
	return n, nil
}

// printProgress prints the progress of a file write.
// Fixed: pointer receiver for consistency with Write — mixed value/pointer
// receivers on one type copy the struct on every call and trip up linters.
func (wc *writeCounter) printProgress() {
	// Clear the line: carriage-return to the start and blank the remaining
	// characters with spaces.
	fmt.Printf("\r%s", strings.Repeat(" ", 50))

	// Return again and print the current status of the download.
	fmt.Printf("\rDownloading... %s/%s", byteCountIEC(wc.Total), byteCountIEC(wc.FileSize))
}

// byteCountIEC converts bytes to a human-readable string in binary (IEC)
// units, e.g. 1536 -> "1.5 KiB".
func byteCountIEC(b uint64) string {
	const unit = 1024
	if b < unit {
		return fmt.Sprintf("%d B", b)
	}
	div, exp := uint64(unit), 0
	for n := b / unit; n >= unit; n /= unit {
		div *= unit
		exp++
	}
	return fmt.Sprintf("%.1f %ciB",
		float64(b)/float64(div), "KMGTPE"[exp])
}
|
||||||
|
|
||||||
|
func copyFile(src, dst string) error {
|
||||||
|
sourceFileStat, err := os.Stat(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !sourceFileStat.Mode().IsRegular() {
|
||||||
|
return fmt.Errorf("%s is not a regular file", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
source, err := os.Open(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer source.Close()
|
||||||
|
|
||||||
|
destination, err := os.Create(dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer destination.Close()
|
||||||
|
_, err = io.Copy(destination, source)
|
||||||
|
return err
|
||||||
|
}
|
35
init.go
Normal file
35
init.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
package gotch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
	// CacheDir is the directory where downloaded/copied data is cached.
	// The "NOT_SETTING" placeholder is replaced in init(); it should never
	// be observed by callers after package initialization.
	CacheDir string = "NOT_SETTING"

	// gotchEnvKey is the environment variable that overrides CacheDir.
	gotchEnvKey string = "GOTCH_CACHE"
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// default path: {$HOME}/.cache/gotch
|
||||||
|
homeDir := os.Getenv("HOME")
|
||||||
|
CacheDir = fmt.Sprintf("%s/.cache/transformer", homeDir)
|
||||||
|
|
||||||
|
initEnv()
|
||||||
|
|
||||||
|
log.Printf("INFO: CacheDir=%q\n", CacheDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func initEnv() {
|
||||||
|
val := os.Getenv(gotchEnvKey)
|
||||||
|
if val != "" {
|
||||||
|
CacheDir = val
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(CacheDir); os.IsNotExist(err) {
|
||||||
|
if err := os.MkdirAll(CacheDir, 0755); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
1937
pickle/pickle.go
Normal file
1937
pickle/pickle.go
Normal file
File diff suppressed because it is too large
Load Diff
60
pickle/pickle_example_test.go
Normal file
60
pickle/pickle_example_test.go
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
package pickle_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
"github.com/sugarme/gotch/pickle"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExampleLoadInfo downloads (or reuses the cached) VGG16 checkpoint and
// prints every variable name with its shape. The "// Output:" comment below
// is checked by `go test` and must stay exactly as the program prints it.
func ExampleLoadInfo() {
	modelName := "vgg16"
	url, ok := gotch.ModelUrls[modelName]
	if !ok {
		log.Fatalf("Unsupported model name %q\n", modelName)
	}
	modelFile, err := gotch.CachedPath(url)
	if err != nil {
		panic(err)
	}

	err = pickle.LoadInfo(modelFile)
	if err != nil {
		log.Fatal(err)
	}

	// Output:
	// classifier.0.bias - [4096]
	// classifier.0.weight - [4096 25088]
	// classifier.3.bias - [4096]
	// classifier.3.weight - [4096 4096]
	// classifier.6.bias - [1000]
	// classifier.6.weight - [1000 4096]
	// features.0.bias - [64]
	// features.0.weight - [64 3 3 3]
	// features.10.bias - [256]
	// features.10.weight - [256 128 3 3]
	// features.12.bias - [256]
	// features.12.weight - [256 256 3 3]
	// features.14.bias - [256]
	// features.14.weight - [256 256 3 3]
	// features.17.bias - [512]
	// features.17.weight - [512 256 3 3]
	// features.19.bias - [512]
	// features.19.weight - [512 512 3 3]
	// features.2.bias - [64]
	// features.2.weight - [64 64 3 3]
	// features.21.bias - [512]
	// features.21.weight - [512 512 3 3]
	// features.24.bias - [512]
	// features.24.weight - [512 512 3 3]
	// features.26.bias - [512]
	// features.26.weight - [512 512 3 3]
	// features.28.bias - [512]
	// features.28.weight - [512 512 3 3]
	// features.5.bias - [128]
	// features.5.weight - [128 64 3 3]
	// features.7.bias - [128]
	// features.7.weight - [128 128 3 3]
	// Num of variables: 32
}
|
519
pickle/serialization.go
Normal file
519
pickle/serialization.go
Normal file
|
@ -0,0 +1,519 @@
|
||||||
|
package pickle
|
||||||
|
|
||||||
|
// Ref.
|
||||||
|
// https://docs.python.org/3/library/pickle.html
|
||||||
|
// https://docs.python.org/3/library/pickletools.html
|
||||||
|
// https://github.com/python/cpython/blob/main/Lib/pickle.py (****real code here****)
|
||||||
|
// https://pytorch.org/tutorials/beginner/saving_loading_models.html
|
||||||
|
// https://github.com/pytorch/pytorch/blob/master/torch/serialization.py
|
||||||
|
|
||||||
|
import (
|
||||||
|
"archive/tar"
|
||||||
|
"archive/zip"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/big"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch/nn"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// hexMagicNumber is the lowercase-hex rendering of the magic integer that
// prefixes PyTorch's legacy (non-zip) serialization format.
const hexMagicNumber = "1950a86a20f9469cfc6c"

// protocolVersion is the legacy torch.save serialization protocol version
// expected right after the magic number.
const protocolVersion = 1001

// ErrInvalidMagicNumber is returned when the file does not start with the
// expected PyTorch magic number.
var ErrInvalidMagicNumber = errors.New("invalid pytorch magic number")

// ErrInvalidProtocolVersion is returned when the serialized protocol version
// does not match protocolVersion.
var ErrInvalidProtocolVersion = errors.New("invalid pytorch protocol version")
||||||
|
|
||||||
|
// Encode encodes model using pickling machinery.
// The output pickled model is intended to be loadable from Python PyTorch as
// `torch.load("pytorch_model.bin")`.
//
// TODO. implement the pickling part so that a model can be exported and
// loaded with Python Pytorch.
// See https://github.com/python/cpython/blob/b0de6299a840a397d4fe3e6c98159d9f258d3295/Lib/pickle.py#L407
//
// NOTE: currently unimplemented — calling Encode always panics.
func Encode(model ts.Module, outputFile string) error {
	panic("NotImplementedError")
}
|
||||||
|
|
||||||
|
// Decode decodes pickled data created by 'torch.save()' with Python Pytorch
// and rebuilds named tensor weights.
//
// It expects the unpickled top-level object to be an OrderedDict mapping
// parameter names to StorageTensor values (a typical state_dict); any other
// shape is an error.
func Decode(filename string) (map[string]*ts.Tensor, error) {
	newUnpickler := func(r io.Reader) Unpickler {
		return NewUnpickler(r)
	}
	result, err := LoadWithUnpickler(filename, newUnpickler)
	if err != nil {
		err := fmt.Errorf("Decode() failed: %w", err)
		return nil, err
	}

	// Rebuild tensors from Storage tensors
	namedTensors := make(map[string]*ts.Tensor)
	dictResult, isOrderedDict := result.(*OrderedDict)
	if !isOrderedDict {
		err := fmt.Errorf("Decode() failed: expected 'OrderedDict' type, got %v\n", reflect.TypeOf(result).String())
		return nil, err
	}
	for name, item := range dictResult.Map {
		sx, isStorageTensor := item.Value.(*StorageTensor)
		if !isStorageTensor {
			err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
			return nil, err
		}

		// Pull raw data plus the view metadata needed to reconstruct the
		// tensor exactly as PyTorch saved it.
		data := sx.Source.GetData()
		size := sx.Size
		dtype := sx.Source.DType()
		device := sx.Source.Device()
		stride := sx.Stride
		storageOffset := sx.StorageOffset

		// fmt.Printf("%q - shape: %v - stride: %v - storageOffset: %v\n", sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)

		// Build the tensor as a strided view over the flat storage, then cast
		// to the stored dtype and move to the stored device.
		x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
		if sx.RequiresGrad {
			x.MustRequiresGrad_(sx.RequiresGrad)
		}

		namedTensors[name.(string)] = x
	}

	return namedTensors, nil
}
|
||||||
|
|
||||||
|
// LoadWithUnpickler is like Load, but it accepts a newUnpickler function which
|
||||||
|
// is used to create new customized pickle.Unpickler instances.
|
||||||
|
func LoadWithUnpickler(filename string, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) {
|
||||||
|
if !isZipFile(filename) {
|
||||||
|
return loadLegacyFile(filename, newUnpickler)
|
||||||
|
}
|
||||||
|
return loadZipFile(filename, newUnpickler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadZipFile unpickles a zip-container checkpoint (PyTorch >= 1.6 format).
// The archive holds "data.pkl" (the pickled object graph) plus one record per
// storage; storages are resolved lazily through the PersistentLoad hook.
func loadZipFile(filename string, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) {
	// Open a zip archive for reading.
	r, err := zip.OpenReader(filename)
	if err != nil {
		return nil, err
	}
	defer r.Close()

	// Index archive entries by base name so records can be looked up by key.
	fileRecords := make(map[string]*zip.File, len(r.File))
	for _, f := range r.File {
		_, recordName := path.Split(f.Name)
		fileRecords[recordName] = f
	}

	// "constants.pkl" marks a TorchScript archive, which this loader does
	// not handle.
	if _, isTorchScript := fileRecords["constants.pkl"]; isTorchScript {
		return nil, fmt.Errorf("TorchScript is not supported")
	}

	dataFile, hasDataFile := fileRecords["data.pkl"]
	if !hasDataFile {
		return nil, fmt.Errorf("data.pkl not found in zip file")
	}
	df, err := dataFile.Open()
	if err != nil {
		return nil, err
	}
	defer df.Close()

	// Cache of storages already materialized, keyed by persistent-id key,
	// so each record is read at most once.
	loadedStorages := make(map[string]Storage)

	u := newUnpickler(df)
	u.FindClass = makePickleFindClass(u.FindClass)
	// PersistentLoad resolves pickle persistent ids of the form
	// ("storage", StorageClass, key, location, size) to Storage values.
	u.PersistentLoad = func(savedId interface{}) (interface{}, error) {
		tuple, tupleOk := savedId.(*Tuple)
		if !tupleOk || tuple.Len() == 0 {
			return nil, fmt.Errorf("PersistentLoad: non-empty tuple expected, got %#v", savedId)
		}
		typename, typenameOk := tuple.Get(0).(string)
		if !typenameOk {
			return nil, fmt.Errorf("PersistentLoad: cannot get typename")
		}
		if typename != "storage" {
			return nil, fmt.Errorf("unknown typename for PersistentLoad, expected 'storage' but got '%s'", typename)
		}
		if tuple.Len() < 5 {
			return nil, fmt.Errorf("PersistentLoad: unexpected storage data length")
		}
		dataType, dataTypeOk := tuple.Get(1).(StorageClass)
		key, keyOk := tuple.Get(2).(string)
		location, locationOk := tuple.Get(3).(string)
		size, sizeOk := tuple.Get(4).(int)
		if !dataTypeOk || !keyOk || !locationOk || !sizeOk {
			return nil, fmt.Errorf("PersistentLoad: unexpected data types")
		}
		storage, storageExists := loadedStorages[key]
		if !storageExists {
			storage, err = loadTensor(dataType, size, location, key, fileRecords)
			if err != nil {
				return nil, err
			}
			loadedStorages[key] = storage
		}
		return storage, nil
	}
	return u.Load()
}
|
||||||
|
|
||||||
|
// loadTensor materializes a single storage from the zip record named by key:
// it allocates a Storage of the given class/size/location and fills it from
// the record's bytes.
func loadTensor(
	dataType StorageClass,
	size int,
	location, key string,
	zipFileRecords map[string]*zip.File,
) (Storage, error) {
	file, fileOk := zipFileRecords[key]
	if !fileOk {
		return nil, fmt.Errorf("cannot find zip record '%s'", key)
	}
	f, err := file.Open()
	if err != nil {
		return nil, err
	}
	defer f.Close()

	// Allocate the storage, then stream exactly `size` elements into it.
	storage := dataType.New(size, location)
	err = storage.SetFromFileWithSize(f, size)
	return storage, err
}
|
||||||
|
|
||||||
|
func loadLegacyFile(filename string, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) {
|
||||||
|
f, err := os.Open(filename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
tr := tar.NewReader(f)
|
||||||
|
for {
|
||||||
|
_, err := tr.Next()
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
// TODO: ...
|
||||||
|
panic("legacy load from tar not implemented")
|
||||||
|
case io.EOF:
|
||||||
|
break // End of archive
|
||||||
|
case tar.ErrHeader, io.ErrUnexpectedEOF:
|
||||||
|
_, err = f.Seek(0, io.SeekStart)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return loadLegacyNoTar(f, newUnpickler)
|
||||||
|
default:
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadLegacyNoTar parses the bare (non-tar) legacy torch.save format:
// magic number, protocol version, sys-info dict, the pickled object graph
// (with storages referenced by persistent id), then the list of storage keys
// followed by each storage's raw bytes in that order.
func loadLegacyNoTar(f *os.File, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) {
	if err := readAndCheckMagicNumber(f); err != nil {
		return nil, err
	}
	if err := readAndChecProtocolVersion(f); err != nil {
		return nil, err
	}
	if _, err := unpickle(f); err != nil { // sys info
		return nil, err
	}

	// Storages materialized so far, keyed by their root persistent-id key.
	deserializedObjects := make(map[string]Storage)

	u := newUnpickler(f)
	u.FindClass = makePickleFindClass(u.FindClass)
	// PersistentLoad resolves legacy persistent ids. "storage" tuples carry
	// (class, rootKey, location, size, viewMetadata); "module" tuples carry
	// the module object directly at index 1.
	u.PersistentLoad = func(savedId interface{}) (interface{}, error) {
		tuple, tupleOk := savedId.(*Tuple)
		if !tupleOk || tuple.Len() == 0 {
			return nil, fmt.Errorf("PersistentLoad: non-empty tuple expected, got %#v", savedId)
		}
		typename, typenameOk := tuple.Get(0).(string)
		if !typenameOk {
			return nil, fmt.Errorf("PersistentLoad: cannot get typename")
		}

		// fmt.Printf("typename: %s\n", typename)

		switch typename {
		case "storage":
			if tuple.Len() < 6 {
				return nil, fmt.Errorf(
					"PersistentLoad: unexpected storage data length")
			}
			dataType, dataTypeOk := tuple.Get(1).(StorageClass)
			rootKey, rootKeyOk := tuple.Get(2).(string)
			location, locationOk := tuple.Get(3).(string)
			size, sizeOk := tuple.Get(4).(int)
			viewMetadata := tuple.Get(5)
			if !dataTypeOk || !rootKeyOk || !locationOk || !sizeOk {
				return nil, fmt.Errorf("PersistentLoad: unexpected data types")
			}

			// fmt.Printf("dtype: %v - rootKey: %v - device: %v - size %v - viewMetaData: %v\n", reflect.TypeOf(dataType), rootKey, location, size, viewMetadata)

			// Allocate each root storage once; its bytes are filled in later
			// from the stream, after the main pickle program finishes.
			storage, storageExists := deserializedObjects[rootKey]
			if !storageExists {
				storage = dataType.New(size, location)
				deserializedObjects[rootKey] = storage
			}
			switch vm := viewMetadata.(type) {
			case nil:
				return storage, nil
			case []interface{}:
				if len(vm) != 3 {
					return nil, fmt.Errorf(
						"PersistentLoad: unexpected view metadata length")
				}
				panic("viewMetadata not implemented")
				// TODO: ...
				// view_key, offset, view_size = view_metadata
				// if view_key not in deserialized_objects:
				//     deserialized_objects[view_key] = storage[offset:offset + view_size]
				// return deserialized_objects[view_key]
			default:
				return nil, fmt.Errorf("PersistentLoad: unexpected view metadata type")
			}
		case "module":
			if tuple.Len() < 2 {
				return nil, fmt.Errorf("PersistentLoad: unexpected module data length")
			}
			return tuple.Get(1), nil
		default:
			return nil, fmt.Errorf("Unexpected saved ID type: %s", typename)
		}
	}

	result, err := u.Load()
	if err != nil {
		return nil, err
	}

	// After the object graph, the file lists the storage keys in the order
	// their raw bytes follow.
	rawStorageKeys, err := unpickle(f)
	if err != nil {
		return nil, err
	}
	storageKeys, err := makeStorageKeys(rawStorageKeys)
	if err != nil {
		return nil, err
	}

	// Fill each pre-allocated storage from the stream, in key order.
	for _, key := range storageKeys {
		storageObj, ok := deserializedObjects[key]
		if !ok {
			return nil, fmt.Errorf("storage object not found for key '%s'", key)
		}

		err = storageObj.SetFromFile(f)
		if err != nil {
			return nil, err
		}
	}

	return result, nil
}
|
||||||
|
|
||||||
|
func makeStorageKeys(obj interface{}) ([]string, error) {
|
||||||
|
list, ok := obj.(*List)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid storage keys data")
|
||||||
|
}
|
||||||
|
keys := make([]string, len(*list))
|
||||||
|
for i, rawKey := range *list {
|
||||||
|
key, keyOk := rawKey.(string)
|
||||||
|
if !keyOk {
|
||||||
|
return nil, fmt.Errorf("invalid storage key")
|
||||||
|
}
|
||||||
|
keys[i] = key
|
||||||
|
}
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readAndCheckMagicNumber(r io.Reader) error {
|
||||||
|
obj, err := unpickle(r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if n, ok := obj.(*big.Int); !ok || n.Text(16) != hexMagicNumber {
|
||||||
|
return ErrInvalidMagicNumber
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readAndChecProtocolVersion(r io.Reader) error {
|
||||||
|
obj, err := unpickle(r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if n, ok := obj.(int); !ok || n != protocolVersion {
|
||||||
|
return ErrInvalidProtocolVersion
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func unpickle(r io.Reader) (interface{}, error) {
|
||||||
|
u := NewUnpickler(r)
|
||||||
|
return u.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func isZipFile(filename string) bool {
|
||||||
|
r, err := zip.OpenReader(filename)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
r.Close()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func makePickleFindClass(fallback func(module, name string) (interface{}, error)) func(module, name string) (interface{}, error) {
|
||||||
|
return func(module, name string) (interface{}, error) {
|
||||||
|
switch module + "." + name {
|
||||||
|
case "torch._utils._rebuild_tensor":
|
||||||
|
return &RebuildTensor{}, nil
|
||||||
|
case "torch._utils._rebuild_tensor_v2":
|
||||||
|
return &RebuildTensorV2{}, nil
|
||||||
|
case "torch.FloatStorage":
|
||||||
|
return &FloatStorageClass{}, nil
|
||||||
|
case "torch.HalfStorage":
|
||||||
|
return &HalfStorageClass{}, nil
|
||||||
|
case "torch.DoubleStorage":
|
||||||
|
return &DoubleStorageClass{}, nil
|
||||||
|
case "torch.CharStorage":
|
||||||
|
return &CharStorageClass{}, nil
|
||||||
|
case "torch.ShortStorage":
|
||||||
|
return &ShortStorageClass{}, nil
|
||||||
|
case "torch.IntStorage":
|
||||||
|
return &IntStorageClass{}, nil
|
||||||
|
case "torch.LongStorage":
|
||||||
|
return &LongStorageClass{}, nil
|
||||||
|
case "torch.ByteStorage":
|
||||||
|
return &ByteStorageClass{}, nil
|
||||||
|
case "torch.BoolStorage":
|
||||||
|
return &BoolStorageClass{}, nil
|
||||||
|
case "torch.nn.backends.thnn._get_thnn_function_backend":
|
||||||
|
// this is for historical pickle deserilaization, it is not used otherwise
|
||||||
|
return getThnnFunctionBackend{}, nil
|
||||||
|
default:
|
||||||
|
if fallback == nil {
|
||||||
|
return nil, fmt.Errorf("class not found: %s %s", module, name)
|
||||||
|
}
|
||||||
|
return fallback(module, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadAll finds and loads all weights from varstore.
|
||||||
|
// It will throw err if one of weights from varstore cannot find from loaded pretrained model.
|
||||||
|
func LoadAll(vs *nn.VarStore, modelFile string) error {
|
||||||
|
weights, err := Decode(modelFile)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("LoadAll() failed: %w", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// for tsName, _ := range vs.Vars.NamedVariables {
|
||||||
|
for tsName := range vs.Vars.NamedVariables {
|
||||||
|
// missing variable
|
||||||
|
currTs, ok := weights[tsName]
|
||||||
|
if !ok {
|
||||||
|
err = fmt.Errorf("LoadAll() failed: Cannot find tensor with name: %v in variable store. \n", tsName)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// mismatched shape
|
||||||
|
sourceShape := currTs.MustSize()
|
||||||
|
destShape := vs.Vars.NamedVariables[tsName].MustSize()
|
||||||
|
if !reflect.DeepEqual(destShape, sourceShape) {
|
||||||
|
err = fmt.Errorf("LoadAll() failed: Mismatched shape error for variable name: %v - At store: %v - At source %v\n", tsName, destShape, sourceShape)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ts.NoGrad(func() {
|
||||||
|
vs.Vars.NamedVariables[tsName].Copy_(currTs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, x := range weights {
|
||||||
|
x.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadPartial finds and loads weights for varstore.
|
||||||
|
// It returns list of unfound weight names.
|
||||||
|
func LoadPartial(vs *nn.VarStore, modelFile string) ([]string, error) {
|
||||||
|
weights, err := Decode(modelFile)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("LoadPartial() failed: %w", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var missingVariables []string
|
||||||
|
|
||||||
|
// Match and in-place copy value (update) from newly loaded tensors
|
||||||
|
// to existing named tensors if name is matched. Throw error otherwise.
|
||||||
|
for tsName := range vs.Vars.NamedVariables {
|
||||||
|
var currTs *ts.Tensor
|
||||||
|
var ok bool
|
||||||
|
|
||||||
|
// missing variable
|
||||||
|
if currTs, ok = weights[tsName]; !ok {
|
||||||
|
missingVariables = append(missingVariables, tsName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// mismatched shape
|
||||||
|
destShape := currTs.MustSize()
|
||||||
|
sourceShape := vs.Vars.NamedVariables[tsName].MustSize()
|
||||||
|
if !reflect.DeepEqual(destShape, sourceShape) {
|
||||||
|
fmt.Printf("WARNING: Mismatched shape error for variable name: %v - At store: %v - At source %v. Skip loading this weight...\n", tsName, destShape, sourceShape)
|
||||||
|
missingVariables = append(missingVariables, tsName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ts.NoGrad(func() {
|
||||||
|
vs.Vars.NamedVariables[tsName].Copy_(currTs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, x := range weights {
|
||||||
|
x.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
|
return missingVariables, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadInfo loads pretrained weights and prints out name and shape of weights.
|
||||||
|
func LoadInfo(modelFile string) error {
|
||||||
|
weights, err := Decode(modelFile)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("LoadInfo() failed: %w", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
layers := make([]string, 0, len(weights))
|
||||||
|
for tsName := range weights {
|
||||||
|
layers = append(layers, tsName)
|
||||||
|
}
|
||||||
|
sort.Strings(layers)
|
||||||
|
for _, l := range layers {
|
||||||
|
var x *ts.Tensor
|
||||||
|
for tsName, tsVal := range weights {
|
||||||
|
if tsName == l {
|
||||||
|
x = tsVal
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Printf("%s - %+v\n", l, x.MustSize())
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Num of variables: %v\n", len(weights))
|
||||||
|
|
||||||
|
for _, x := range weights {
|
||||||
|
x.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
667
pickle/storage.go
Normal file
667
pickle/storage.go
Normal file
|
@ -0,0 +1,667 @@
|
||||||
|
package pickle
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This file implements Pytorch storage data types.
|
||||||
|
// ref: https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/storage.py
|
||||||
|
/*
|
||||||
|
torch.double: 'DoubleStorage',
|
||||||
|
torch.float: 'FloatStorage',
|
||||||
|
torch.half: 'HalfStorage',
|
||||||
|
torch.long: 'LongStorage',
|
||||||
|
torch.int: 'IntStorage',
|
||||||
|
torch.int16: 'ShortStorage',
|
||||||
|
torch.int8: 'CharStorage',
|
||||||
|
torch.uint8: 'ByteStorage',
|
||||||
|
torch.bool: 'BoolStorage',
|
||||||
|
torch.bfloat16: 'BFloat16Storage',
|
||||||
|
torch.cdouble: 'ComplexDoubleStorage',
|
||||||
|
torch.cfloat: 'ComplexFloatStorage',
|
||||||
|
torch.qint8: 'QInt8Storage',
|
||||||
|
torch.qint32: 'QInt32Storage',
|
||||||
|
torch.quint8: 'QUInt8Storage',
|
||||||
|
torch.quint4x2: 'QUInt4x2Storage',
|
||||||
|
torch.quint2x4: 'QUInt2x4Storage',
|
||||||
|
*/
|
||||||
|
|
||||||
|
// StorageClass defines the interface for factory types that allocate a
// concrete Storage of a given element count at a given location
// (e.g. "cpu" or "cuda").
type StorageClass interface {
	New(size int, location string) Storage
}

// Storage defines the interface of a deserialized Pytorch storage buffer.
type Storage interface {
	// SetFromFile reads the element count prefix and then the raw elements
	// from r.
	SetFromFile(r io.Reader) error
	// SetFromFileWithSize reads exactly size elements from r.
	SetFromFileWithSize(r io.Reader, size int) error
	// DType returns the gotch data type this storage maps to.
	DType() gotch.DType
	// GetData returns the decoded elements as a typed Go slice.
	GetData() interface{}
	// Device returns the device where the storage should live.
	Device() gotch.Device
}

// BaseStorage holds fields common to all storage implementations.
type BaseStorage struct {
	Size     int    // number of elements
	Location string // pickle "location" string, e.g. "cpu" or "cuda"
}
|
||||||
|
|
||||||
|
// HalfStorage:
|
||||||
|
// ============
|
||||||
|
|
||||||
|
type HalfStorageClass struct{}
|
||||||
|
|
||||||
|
var _ StorageClass = &HalfStorageClass{}
|
||||||
|
|
||||||
|
func (s *HalfStorageClass) New(size int, location string) Storage {
|
||||||
|
return &HalfStorage{
|
||||||
|
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||||
|
Data: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type HalfStorage struct {
|
||||||
|
BaseStorage
|
||||||
|
Data []float32
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Storage = &HalfStorage{}
|
||||||
|
|
||||||
|
func (s *HalfStorage) SetFromFile(r io.Reader) error {
|
||||||
|
return setFromFile(s, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HalfStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||||
|
data := make([]float32, size)
|
||||||
|
br := NewLimitedBufferReader(r, size, 2, 512)
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
bytes, err := br.ReadNext()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
u16 := binary.LittleEndian.Uint16(bytes)
|
||||||
|
data[i] = math.Float32frombits(FloatBits16to32(u16))
|
||||||
|
}
|
||||||
|
s.Data = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HalfStorage) GetData() interface{} {
|
||||||
|
return s.Data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HalfStorage) DType() gotch.DType {
|
||||||
|
return gotch.Float
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HalfStorage) Device() gotch.Device {
|
||||||
|
switch s.Location {
|
||||||
|
case "cuda":
|
||||||
|
return gotch.CudaIfAvailable()
|
||||||
|
default:
|
||||||
|
return gotch.CPU
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FloatStorage:
|
||||||
|
// =============
|
||||||
|
|
||||||
|
type FloatStorageClass struct{}
|
||||||
|
|
||||||
|
var _ StorageClass = &FloatStorageClass{}
|
||||||
|
|
||||||
|
func (s *FloatStorageClass) New(size int, location string) Storage {
|
||||||
|
return &FloatStorage{
|
||||||
|
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||||
|
Data: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type FloatStorage struct {
|
||||||
|
BaseStorage
|
||||||
|
Data []float32
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Storage = &FloatStorage{}
|
||||||
|
|
||||||
|
func (s *FloatStorage) SetFromFile(r io.Reader) error {
|
||||||
|
return setFromFile(s, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FloatStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||||
|
data := make([]float32, size)
|
||||||
|
br := NewLimitedBufferReader(r, size, 4, 512)
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
bytes, err := br.ReadNext()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data[i] = math.Float32frombits(binary.LittleEndian.Uint32(bytes))
|
||||||
|
}
|
||||||
|
s.Data = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FloatStorage) GetData() interface{} {
|
||||||
|
return s.Data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FloatStorage) DType() gotch.DType {
|
||||||
|
return gotch.Float
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FloatStorage) Device() gotch.Device {
|
||||||
|
switch s.Location {
|
||||||
|
case "cuda":
|
||||||
|
return gotch.CudaIfAvailable()
|
||||||
|
default:
|
||||||
|
return gotch.CPU
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoubleStorage:
|
||||||
|
// ==============
|
||||||
|
|
||||||
|
type DoubleStorageClass struct{}
|
||||||
|
|
||||||
|
var _ StorageClass = &DoubleStorageClass{}
|
||||||
|
|
||||||
|
func (s *DoubleStorageClass) New(size int, location string) Storage {
|
||||||
|
return &DoubleStorage{
|
||||||
|
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||||
|
Data: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type DoubleStorage struct {
|
||||||
|
BaseStorage
|
||||||
|
Data []float64
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Storage = &DoubleStorage{}
|
||||||
|
|
||||||
|
func (s *DoubleStorage) SetFromFile(r io.Reader) error {
|
||||||
|
return setFromFile(s, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DoubleStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||||
|
data := make([]float64, size)
|
||||||
|
br := NewLimitedBufferReader(r, size, 8, 512)
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
bytes, err := br.ReadNext()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data[i] = math.Float64frombits(binary.LittleEndian.Uint64(bytes))
|
||||||
|
}
|
||||||
|
s.Data = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DoubleStorage) GetData() interface{} {
|
||||||
|
return s.Data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DoubleStorage) DType() gotch.DType {
|
||||||
|
return gotch.Double
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DoubleStorage) Device() gotch.Device {
|
||||||
|
switch s.Location {
|
||||||
|
case "cuda":
|
||||||
|
return gotch.CudaIfAvailable()
|
||||||
|
default:
|
||||||
|
return gotch.CPU
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CharStorage:
|
||||||
|
// ============
|
||||||
|
|
||||||
|
type CharStorageClass struct{}
|
||||||
|
|
||||||
|
var _ StorageClass = &CharStorageClass{}
|
||||||
|
|
||||||
|
func (s *CharStorageClass) New(size int, location string) Storage {
|
||||||
|
return &CharStorage{
|
||||||
|
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||||
|
Data: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type CharStorage struct {
|
||||||
|
BaseStorage
|
||||||
|
Data []int8
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Storage = &CharStorage{}
|
||||||
|
|
||||||
|
func (s *CharStorage) SetFromFile(r io.Reader) error {
|
||||||
|
return setFromFile(s, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CharStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||||
|
data := make([]int8, size)
|
||||||
|
br := NewLimitedBufferReader(r, size, 1, 512)
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
bytes, err := br.ReadNext()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data[i] = int8(bytes[0])
|
||||||
|
}
|
||||||
|
s.Data = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CharStorage) GetData() interface{} {
|
||||||
|
return s.Data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CharStorage) DType() gotch.DType {
|
||||||
|
return gotch.Int8
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CharStorage) Device() gotch.Device {
|
||||||
|
switch s.Location {
|
||||||
|
case "cuda":
|
||||||
|
return gotch.CudaIfAvailable()
|
||||||
|
default:
|
||||||
|
return gotch.CPU
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShortStorage:
|
||||||
|
// =============
|
||||||
|
|
||||||
|
type ShortStorageClass struct{}
|
||||||
|
|
||||||
|
var _ StorageClass = &ShortStorageClass{}
|
||||||
|
|
||||||
|
func (s *ShortStorageClass) New(size int, location string) Storage {
|
||||||
|
return &ShortStorage{
|
||||||
|
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||||
|
Data: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ShortStorage struct {
|
||||||
|
BaseStorage
|
||||||
|
Data []int16
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Storage = &ShortStorage{}
|
||||||
|
|
||||||
|
func (s *ShortStorage) SetFromFile(r io.Reader) error {
|
||||||
|
return setFromFile(s, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShortStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||||
|
data := make([]int16, size)
|
||||||
|
br := NewLimitedBufferReader(r, size, 2, 512)
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
bytes, err := br.ReadNext()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data[i] = int16(binary.LittleEndian.Uint16(bytes))
|
||||||
|
}
|
||||||
|
s.Data = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShortStorage) GetData() interface{} {
|
||||||
|
return s.Data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShortStorage) DType() gotch.DType {
|
||||||
|
return gotch.Int16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShortStorage) Device() gotch.Device {
|
||||||
|
switch s.Location {
|
||||||
|
case "cuda":
|
||||||
|
return gotch.CudaIfAvailable()
|
||||||
|
default:
|
||||||
|
return gotch.CPU
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntStorage:
|
||||||
|
// ===========
|
||||||
|
|
||||||
|
type IntStorageClass struct{}
|
||||||
|
|
||||||
|
var _ StorageClass = &IntStorageClass{}
|
||||||
|
|
||||||
|
func (s *IntStorageClass) New(size int, location string) Storage {
|
||||||
|
return &IntStorage{
|
||||||
|
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||||
|
Data: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type IntStorage struct {
|
||||||
|
BaseStorage
|
||||||
|
Data []int32
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Storage = &IntStorage{}
|
||||||
|
|
||||||
|
func (s *IntStorage) SetFromFile(r io.Reader) error {
|
||||||
|
return setFromFile(s, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *IntStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||||
|
data := make([]int32, size)
|
||||||
|
br := NewLimitedBufferReader(r, size, 4, 512)
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
bytes, err := br.ReadNext()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data[i] = int32(binary.LittleEndian.Uint32(bytes))
|
||||||
|
}
|
||||||
|
s.Data = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *IntStorage) GetData() interface{} {
|
||||||
|
return s.Data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *IntStorage) DType() gotch.DType {
|
||||||
|
return gotch.Int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *IntStorage) Device() gotch.Device {
|
||||||
|
switch s.Location {
|
||||||
|
case "cuda":
|
||||||
|
return gotch.CudaIfAvailable()
|
||||||
|
default:
|
||||||
|
return gotch.CPU
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LongStorage:
|
||||||
|
// ============
|
||||||
|
|
||||||
|
type LongStorageClass struct{}
|
||||||
|
|
||||||
|
var _ StorageClass = &LongStorageClass{}
|
||||||
|
|
||||||
|
func (s *LongStorageClass) New(size int, location string) Storage {
|
||||||
|
return &LongStorage{
|
||||||
|
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||||
|
Data: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type LongStorage struct {
|
||||||
|
BaseStorage
|
||||||
|
Data []int64
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Storage = &LongStorage{}
|
||||||
|
|
||||||
|
func (s *LongStorage) SetFromFile(r io.Reader) error {
|
||||||
|
return setFromFile(s, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *LongStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||||
|
data := make([]int64, size)
|
||||||
|
br := NewLimitedBufferReader(r, size, 8, 512)
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
bytes, err := br.ReadNext()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data[i] = int64(binary.LittleEndian.Uint64(bytes))
|
||||||
|
}
|
||||||
|
s.Data = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *LongStorage) GetData() interface{} {
|
||||||
|
return s.Data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *LongStorage) DType() gotch.DType {
|
||||||
|
return gotch.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *LongStorage) Device() gotch.Device {
|
||||||
|
switch s.Location {
|
||||||
|
case "cuda":
|
||||||
|
return gotch.CudaIfAvailable()
|
||||||
|
default:
|
||||||
|
return gotch.CPU
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByteStorage:
|
||||||
|
// ============
|
||||||
|
|
||||||
|
type ByteStorageClass struct{}
|
||||||
|
|
||||||
|
var _ StorageClass = &ByteStorageClass{}
|
||||||
|
|
||||||
|
func (s *ByteStorageClass) New(size int, location string) Storage {
|
||||||
|
return &ByteStorage{
|
||||||
|
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||||
|
Data: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ByteStorage struct {
|
||||||
|
BaseStorage
|
||||||
|
Data []uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Storage = &ByteStorage{}
|
||||||
|
|
||||||
|
func (s *ByteStorage) SetFromFile(r io.Reader) error {
|
||||||
|
return setFromFile(s, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ByteStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||||
|
data := make([]uint8, size)
|
||||||
|
br := NewLimitedBufferReader(r, size, 1, 512)
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
bytes, err := br.ReadNext()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data[i] = bytes[0]
|
||||||
|
}
|
||||||
|
s.Data = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ByteStorage) GetData() interface{} {
|
||||||
|
return s.Data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ByteStorage) DType() gotch.DType {
|
||||||
|
return gotch.Uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ByteStorage) Device() gotch.Device {
|
||||||
|
switch s.Location {
|
||||||
|
case "cuda":
|
||||||
|
return gotch.CudaIfAvailable()
|
||||||
|
default:
|
||||||
|
return gotch.CPU
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BoolStorage:
|
||||||
|
// ============
|
||||||
|
|
||||||
|
type BoolStorageClass struct{}
|
||||||
|
|
||||||
|
var _ StorageClass = &BoolStorageClass{}
|
||||||
|
|
||||||
|
func (s *BoolStorageClass) New(size int, location string) Storage {
|
||||||
|
return &BoolStorage{
|
||||||
|
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||||
|
Data: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type BoolStorage struct {
|
||||||
|
BaseStorage
|
||||||
|
Data []bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Storage = &BoolStorage{}
|
||||||
|
|
||||||
|
func (s *BoolStorage) SetFromFile(r io.Reader) error {
|
||||||
|
return setFromFile(s, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BoolStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||||
|
data := make([]bool, size)
|
||||||
|
br := NewLimitedBufferReader(r, size, 1, 512)
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
bytes, err := br.ReadNext()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data[i] = bytes[0] == 1
|
||||||
|
}
|
||||||
|
s.Data = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BoolStorage) GetData() interface{} {
|
||||||
|
return s.Data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BoolStorage) DType() gotch.DType {
|
||||||
|
return gotch.Float
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BoolStorage) Device() gotch.Device {
|
||||||
|
switch s.Location {
|
||||||
|
case "cuda":
|
||||||
|
return gotch.CudaIfAvailable()
|
||||||
|
default:
|
||||||
|
return gotch.CPU
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setFromFile(s Storage, r io.Reader) error {
|
||||||
|
sizeBuf := make([]byte, 8)
|
||||||
|
_, err := r.Read(sizeBuf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
size := int(binary.LittleEndian.Uint64(sizeBuf))
|
||||||
|
return s.SetFromFileWithSize(r, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StorageTensor:
//===============

// StorageTensor carries the pieces needed to rebuild a Pytorch tensor: the
// backing storage plus the layout metadata recovered from the pickle (see
// RebuildTensor / RebuildTensorV2, which populate it).
type StorageTensor struct {
	Source        Storage // backing storage holding the raw elements
	StorageOffset int64   // the unpickled storage_offset value
	Size          []int64 // tensor shape
	Stride        []int64 // the unpickled stride tuple, one entry per dimension
	RequiresGrad  bool    // whether the original tensor tracked gradients
}
|
||||||
|
|
||||||
|
// Rebuild Tensor:
|
||||||
|
// ===============
|
||||||
|
// ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L132
|
||||||
|
// ref. def _rebuild_tensor(storage, storage_offset, size, stride):
|
||||||
|
|
||||||
|
type RebuildTensor struct{}
|
||||||
|
|
||||||
|
var _ Callable = &RebuildTensor{}
|
||||||
|
|
||||||
|
func (r *RebuildTensor) Call(args ...interface{}) (interface{}, error) {
|
||||||
|
if len(args) != 4 {
|
||||||
|
return nil, fmt.Errorf("RebuildTensor.Call() failed. Expected 4 args, got %d: %#v", len(args), args)
|
||||||
|
}
|
||||||
|
|
||||||
|
storage, storageOk := args[0].(Storage)
|
||||||
|
storageOffset, storageOffsetOk := args[1].(int)
|
||||||
|
size, sizeOk := args[2].(*Tuple)
|
||||||
|
stride, strideOk := args[3].(*Tuple)
|
||||||
|
if !storageOk || !storageOffsetOk || !sizeOk || !strideOk {
|
||||||
|
return nil, fmt.Errorf("RebuildTensor.Call() unexpected args: %#v", args)
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor := &StorageTensor{
|
||||||
|
Source: storage,
|
||||||
|
StorageOffset: int64(storageOffset),
|
||||||
|
RequiresGrad: false,
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
tensor.Size, err = tupleToInt64Slice(size)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tensor.Stride, err = tupleToInt64Slice(stride)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return tensor, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RebuildTensorV2 represents a struct to rebuild tensor back from pickle object.
|
||||||
|
type RebuildTensorV2 struct{}
|
||||||
|
|
||||||
|
var _ Callable = &RebuildTensorV2{}
|
||||||
|
|
||||||
|
func (r *RebuildTensorV2) Call(args ...interface{}) (interface{}, error) {
|
||||||
|
if len(args) != 6 {
|
||||||
|
return nil, fmt.Errorf("RebuildTensorV2 unexpected args: %#v", args)
|
||||||
|
}
|
||||||
|
|
||||||
|
storage, storageOk := args[0].(Storage)
|
||||||
|
storageOffset, storageOffsetOk := args[1].(int)
|
||||||
|
size, sizeOk := args[2].(*Tuple)
|
||||||
|
stride, strideOk := args[3].(*Tuple)
|
||||||
|
requiresGrad, requiresGradOk := args[4].(bool)
|
||||||
|
// arg[5] "backward hooks" is unused
|
||||||
|
if !storageOk || !storageOffsetOk || !sizeOk || !strideOk ||
|
||||||
|
!requiresGradOk {
|
||||||
|
return nil, fmt.Errorf("RebuildTensorV2 unexpected args: %#v", args)
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor := &StorageTensor{
|
||||||
|
Source: storage,
|
||||||
|
StorageOffset: int64(storageOffset),
|
||||||
|
RequiresGrad: requiresGrad,
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
tensor.Size, err = tupleToInt64Slice(size)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tensor.Stride, err = tupleToInt64Slice(stride)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return tensor, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
|
||||||
|
length := tuple.Len()
|
||||||
|
slice := make([]int64, length)
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
value, ok := tuple.Get(i).(int)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("tuple of ints expected: %#v", tuple)
|
||||||
|
}
|
||||||
|
slice[i] = int64(value)
|
||||||
|
}
|
||||||
|
return slice, nil
|
||||||
|
}
|
519
pickle/type.go
Normal file
519
pickle/type.go
Normal file
|
@ -0,0 +1,519 @@
|
||||||
|
package pickle
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/list"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This file contains custom types and interfaces for casting Python data types from Pickle.
|
||||||
|
//
|
||||||
|
// Custom Types:
|
||||||
|
// =============
|
||||||
|
// 1. ByteArray
|
||||||
|
// 2. Dict
|
||||||
|
// 3. Tuple
|
||||||
|
// 4. OrderedDict
|
||||||
|
// 5. List
|
||||||
|
// 6. Set
|
||||||
|
// 7. FrozenSet
|
||||||
|
// 8. Object
|
||||||
|
// 9. Reconstructor
|
||||||
|
// 10. GenericClass
|
||||||
|
|
||||||
|
// Interfaces:
|
||||||
|
// ===========
|
||||||
|
|
||||||
|
// Callable is implemented by any value that can be directly called to get a
// new value.
//
// It is usually implemented by Python-like functions (returning a value
// given some arguments), or classes (typically returning an instance given
// some constructor arguments).
type Callable interface {
	// Call mimics a direct invocation on a Python value, such as a function
	// or class (constructor).
	Call(args ...interface{}) (interface{}, error)
}

// PyNewable is implemented by any value that has a Python-like
// "__new__" method.
//
// It is usually implemented by values representing Python classes.
type PyNewable interface {
	// PyNew mimics Python invocation of the "__new__" method, usually
	// provided by classes.
	//
	// See: https://docs.python.org/3/reference/datamodel.html#object.__new__
	PyNew(args ...interface{}) (interface{}, error)
}

// PyStateSettable is implemented by any value that has a Python-like
// "__setstate__" method.
type PyStateSettable interface {
	// PySetState mimics Python invocation of the "__setstate__" method.
	//
	// See: https://docs.python.org/3/library/pickle.html#object.__setstate__
	PySetState(state interface{}) error
}

// PyDictSettable is implemented by any value that can store dictionary-like
// key/value pairs. It reflects Python behavior of setting a key/value pair on
// an object's "__dict__" attribute.
type PyDictSettable interface {
	// PyDictSet mimics the setting of a key/value pair on an object's
	// "__dict__" attribute.
	//
	// See: https://docs.python.org/3/library/stdtypes.html#object.__dict__
	PyDictSet(key, value interface{}) error
}

// PyAttrSettable is implemented by any value on which an existing or new
// Python-like attribute can be set. In Python this is done with the "setattr"
// builtin function.
type PyAttrSettable interface {
	// PySetAttr mimics the setting of an arbitrary value to an object's
	// attribute.
	//
	// In Python this is done with the "setattr" function, to which object,
	// attribute name, and value are passed. For an easy and clear
	// implementation, here instead we require this method to be implemented
	// on the "object" itself.
	//
	// See: https://docs.python.org/3/library/functions.html#setattr
	PySetAttr(key string, value interface{}) error
}
|
||||||
|
|
||||||
|
// ByteArray:
//===========

// ByteArray simulates Python bytearray.
type ByteArray []byte

// NewByteArray returns a new, empty ByteArray.
func NewByteArray() *ByteArray {
	a := ByteArray{}
	return &a
}

// NewByteArrayFromSlice wraps slice in a ByteArray (no copy is made).
func NewByteArrayFromSlice(slice []byte) *ByteArray {
	a := ByteArray(slice)
	return &a
}

// Get returns the byte at index i.
func (a *ByteArray) Get(i int) byte {
	return (*a)[i]
}

// Len returns the number of bytes in the ByteArray.
func (a *ByteArray) Len() int {
	return len(*a)
}
|
||||||
|
|
||||||
|
// Dict:
//======

// DictSetter is implemented by any value that exhibits a dict-like behaviour,
// allowing arbitrary key/value pairs to be set.
type DictSetter interface {
	Set(key, value interface{})
}

// Dict represents a Python "dict" (builtin type).
//
// It is implemented as a slice of entries, instead of a map, because in Go
// not all types can be map keys (e.g. slices).
type Dict []*DictEntry

// DictEntry is a single key/value pair of a Dict.
type DictEntry struct {
	Key   interface{}
	Value interface{}
}

// NewDict makes and returns a new empty Dict.
func NewDict() *Dict {
	d := Dict{}
	return &d
}

// Set appends the given key/value pair to the Dict.
func (d *Dict) Set(key, value interface{}) {
	entry := &DictEntry{Key: key, Value: value}
	*d = append(*d, entry)
}

// Get returns the value associated with the given key (if any), and whether
// the key is present or not. Keys are compared with reflect.DeepEqual.
func (d *Dict) Get(key interface{}) (interface{}, bool) {
	for _, entry := range *d {
		if reflect.DeepEqual(entry.Key, key) {
			return entry.Value, true
		}
	}
	return nil, false
}

// MustGet returns the value associated with the given key if it exists,
// otherwise it panics.
func (d *Dict) MustGet(key interface{}) interface{} {
	if value, ok := d.Get(key); ok {
		return value
	}
	panic(fmt.Errorf("key not found in Dict: %#v", key))
}
|
||||||
|
|
||||||
|
// Len returns the length of the Dict, that is, the amount of key/value pairs
|
||||||
|
// contained by the Dict.
|
||||||
|
func (d *Dict) Len() int {
|
||||||
|
return len(*d)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ DictSetter = &Dict{}
|
||||||
|
|
||||||
|
// Tuple:
// ======

// Tuple mirrors the Python builtin "tuple" type.
type Tuple []interface{}

// NewTupleFromSlice wraps the given slice as a Tuple.
// The slice is type-cast, not copied.
func NewTupleFromSlice(slice []interface{}) *Tuple {
	tup := Tuple(slice)
	return &tup
}

// Get returns the i-th element; it panics if i is out of range.
func (t *Tuple) Get(i int) interface{} {
	return (*t)[i]
}

// Len returns the number of elements in the Tuple.
func (t *Tuple) Len() int {
	return len(*t)
}
|
||||||
|
|
||||||
|
// OrderedDict:
|
||||||
|
// ============
|
||||||
|
|
||||||
|
// OrderedDictClass represent Python "collections.OrderedDict" class.
|
||||||
|
//
|
||||||
|
// This class allows the indirect creation of OrderedDict objects.
|
||||||
|
type OrderedDictClass struct{}
|
||||||
|
|
||||||
|
var _ Callable = &OrderedDictClass{}
|
||||||
|
|
||||||
|
// Call returns a new empty OrderedDict. It is equivalent to Python
|
||||||
|
// constructor "collections.OrderedDict()".
|
||||||
|
//
|
||||||
|
// No arguments are supported.
|
||||||
|
func (*OrderedDictClass) Call(args ...interface{}) (interface{}, error) {
|
||||||
|
if len(args) != 0 {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"OrderedDictClass.Call args not supported: %#v", args)
|
||||||
|
}
|
||||||
|
return NewOrderedDict(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OrderedDict is a minimal and trivial implementation of an ordered map,
|
||||||
|
// which represent a Python "collections.OrderedDict" object.
|
||||||
|
//
|
||||||
|
// It is composed by a simple unordered Map, and a List to keep the order of
|
||||||
|
// the entries. The former is useful for direct key lookups, the latter for
|
||||||
|
// iteration.
|
||||||
|
type OrderedDict struct {
|
||||||
|
// Map associates a key of any type (interface{}) to OrderedDictEntry
|
||||||
|
// pointer values. These values are shared with List.
|
||||||
|
Map map[interface{}]*OrderedDictEntry
|
||||||
|
// List is an ordered list of OrderedDictEntry pointers, which are
|
||||||
|
// also shared with Map.
|
||||||
|
List *list.List
|
||||||
|
// PyDict represents Python "object.__dict__" dictionary of attributes.
|
||||||
|
PyDict map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ DictSetter = &OrderedDict{}
|
||||||
|
var _ PyDictSettable = &OrderedDict{}
|
||||||
|
|
||||||
|
// OrderedDictEntry is a single key/value pair stored in an OrderedDict.
|
||||||
|
//
|
||||||
|
// A pointer to an OrderedDictEntry is always shared between OrderedDict's Map
|
||||||
|
// and List.
|
||||||
|
type OrderedDictEntry struct {
|
||||||
|
// Key of a single OrderedDict's entry.
|
||||||
|
Key interface{}
|
||||||
|
// Value of a single OrderedDict's entry.
|
||||||
|
Value interface{}
|
||||||
|
// ListElement is a pointer to the OrderedDict's List Element which
|
||||||
|
// contains this very OrderedDictEntry.
|
||||||
|
ListElement *list.Element
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOrderedDict makes and returns a new empty OrderedDict.
|
||||||
|
func NewOrderedDict() *OrderedDict {
|
||||||
|
return &OrderedDict{
|
||||||
|
Map: make(map[interface{}]*OrderedDictEntry),
|
||||||
|
List: list.New(),
|
||||||
|
PyDict: make(map[string]interface{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets into the OrderedDict the given key/value pair. If the key does not
|
||||||
|
// exist yet, the new pair is positioned at the end (back) of the OrderedDict.
|
||||||
|
// If the key already exists, the existing associated value is replaced with the
|
||||||
|
// new one, and the original position is maintained.
|
||||||
|
func (o *OrderedDict) Set(k, v interface{}) {
|
||||||
|
if entry, ok := o.Map[k]; ok {
|
||||||
|
entry.Value = v
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := &OrderedDictEntry{
|
||||||
|
Key: k,
|
||||||
|
Value: v,
|
||||||
|
}
|
||||||
|
entry.ListElement = o.List.PushBack(entry)
|
||||||
|
o.Map[k] = entry
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the value associated with the given key (if any), and whether
|
||||||
|
// the key is present or not.
|
||||||
|
func (o *OrderedDict) Get(k interface{}) (interface{}, bool) {
|
||||||
|
entry, ok := o.Map[k]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return entry.Value, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustGet returns the value associated with the given key, if if it exists,
|
||||||
|
// otherwise it panics.
|
||||||
|
func (o *OrderedDict) MustGet(key interface{}) interface{} {
|
||||||
|
value, ok := o.Get(key)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Errorf("key not found in OrderedDict: %#v", key))
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the length of the OrderedDict, that is, the amount of key/value
|
||||||
|
// pairs contained by the OrderedDict.
|
||||||
|
func (o *OrderedDict) Len() int {
|
||||||
|
return len(o.Map)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PyDictSet mimics the setting of a key/value pair on Python "__dict__"
|
||||||
|
// attribute of the OrderedDict.
|
||||||
|
func (o *OrderedDict) PyDictSet(key, value interface{}) error {
|
||||||
|
sKey, keyOk := key.(string)
|
||||||
|
if !keyOk {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"OrderedDict.PyDictSet() requires string key: %#v", key)
|
||||||
|
}
|
||||||
|
o.PyDict[sKey] = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List:
// =====

// ListAppender is the interface of list-like values that accept
// arbitrary appended elements.
type ListAppender interface {
	Append(v interface{})
}

// List mirrors the Python builtin "list" type.
type List []interface{}

var _ ListAppender = &List{}

// NewList makes and returns a new empty List.
func NewList() *List {
	lst := make(List, 0)
	return &lst
}

// NewListFromSlice wraps the given slice as a List.
//
// The slice is type-cast, not copied: the List and the slice share the
// same backing array.
func NewListFromSlice(slice []interface{}) *List {
	lst := List(slice)
	return &lst
}

// Append adds one element at the end of the List.
func (l *List) Append(v interface{}) {
	*l = append(*l, v)
}

// Get returns the i-th element; it panics if i is out of range.
func (l *List) Get(i int) interface{} {
	return (*l)[i]
}

// Len returns the number of elements in the List.
func (l *List) Len() int {
	return len(*l)
}
|
||||||
|
|
||||||
|
// Set:
// ====

// SetAdder is the interface of set-like values that accept arbitrary
// added elements.
type SetAdder interface {
	Add(v interface{})
}

// Set mirrors the Python builtin "set" type.
//
// It is a Go map whose keys are the members of the set; the values are
// zero-size placeholders.
type Set map[interface{}]setEmptyStruct

var _ SetAdder = &Set{}

type setEmptyStruct struct{}

// NewSet makes and returns a new empty Set.
func NewSet() *Set {
	set := make(Set)
	return &set
}

// NewSetFromSlice makes and returns a new Set containing the elements of
// the given slice (duplicates collapse to a single member).
func NewSetFromSlice(slice []interface{}) *Set {
	set := make(Set, len(slice))
	for _, v := range slice {
		set[v] = setEmptyStruct{}
	}
	return &set
}

// Len returns the number of members of the Set.
func (s *Set) Len() int {
	return len(*s)
}

// Add inserts one element into the Set.
func (s *Set) Add(v interface{}) {
	(*s)[v] = setEmptyStruct{}
}

// Has reports whether the given value is a member of the Set.
func (s *Set) Has(v interface{}) bool {
	_, found := (*s)[v]
	return found
}
|
||||||
|
|
||||||
|
// FrozenSet:
//===========

// FrozenSet mirrors the Python builtin "frozenset" type.
//
// It is a Go map whose keys are the members of the set; the values are
// zero-size placeholders. No mutators are provided.
type FrozenSet map[interface{}]frozenSetEmptyStruct

type frozenSetEmptyStruct struct{}

// NewFrozenSetFromSlice makes and returns a new FrozenSet containing the
// elements of the given slice (duplicates collapse to a single member).
func NewFrozenSetFromSlice(slice []interface{}) *FrozenSet {
	fs := make(FrozenSet, len(slice))
	for _, v := range slice {
		fs[v] = frozenSetEmptyStruct{}
	}
	return &fs
}

// Len returns the number of members of the FrozenSet.
func (f *FrozenSet) Len() int {
	return len(*f)
}

// Has reports whether the given value is a member of the FrozenSet.
func (f *FrozenSet) Has(v interface{}) bool {
	_, found := (*f)[v]
	return found
}
|
||||||
|
|
||||||
|
// Object:
|
||||||
|
//========
|
||||||
|
|
||||||
|
type ObjectClass struct{}
|
||||||
|
|
||||||
|
var _ PyNewable = &ObjectClass{}
|
||||||
|
|
||||||
|
func (o *ObjectClass) PyNew(args ...interface{}) (interface{}, error) {
|
||||||
|
if len(args) == 0 {
|
||||||
|
return nil, fmt.Errorf("ObjectClass.PyNew called with no arguments")
|
||||||
|
}
|
||||||
|
switch class := args[0].(type) {
|
||||||
|
case PyNewable:
|
||||||
|
return class.PyNew()
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"ObjectClass.PyNew unprocessable args: %#v", args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reconstructor:
|
||||||
|
//===============
|
||||||
|
|
||||||
|
type Reconstructor struct{}
|
||||||
|
|
||||||
|
var _ Callable = &Reconstructor{}
|
||||||
|
|
||||||
|
func (r *Reconstructor) Call(args ...interface{}) (interface{}, error) {
|
||||||
|
if len(args) < 2 {
|
||||||
|
return nil, fmt.Errorf("Reconstructor: invalid arguments: %#v", args)
|
||||||
|
}
|
||||||
|
class := args[0]
|
||||||
|
switch base := args[1].(type) {
|
||||||
|
case PyNewable:
|
||||||
|
return base.PyNew(class)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"Reconstructor: unprocessable arguments: %#v", args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenericClass:
|
||||||
|
//==============
|
||||||
|
|
||||||
|
type GenericClass struct {
|
||||||
|
Module string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ PyNewable = &GenericClass{}
|
||||||
|
|
||||||
|
type GenericObject struct {
|
||||||
|
Class *GenericClass
|
||||||
|
ConstructorArgs []interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGenericClass(module, name string) *GenericClass {
|
||||||
|
return &GenericClass{Module: module, Name: name}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GenericClass) PyNew(args ...interface{}) (interface{}, error) {
|
||||||
|
return &GenericObject{
|
||||||
|
Class: g,
|
||||||
|
ConstructorArgs: args,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getThnnFunctionBackend is for historical pickle deserialization; it is
// not used otherwise.
type getThnnFunctionBackend struct{}

var _ Callable = &getThnnFunctionBackend{}

// Call ignores its arguments and returns (nil, nil), acting as a no-op.
func (getThnnFunctionBackend) Call(_ ...interface{}) (interface{}, error) {
	return nil, nil
}
|
112
pickle/util.go
Normal file
112
pickle/util.go
Normal file
|
@ -0,0 +1,112 @@
|
||||||
|
package pickle
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
|
||||||
|
// FloatBits16to32 converts the bits representation of a Half Float (16 bits)
// number to an IEEE 754 float representation (32 bits), using the
// table-based method from
// http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
func FloatBits16to32(u16 uint16) uint32 {
	return mantissaTable[offsetTable[u16>>10]+(uint32(u16)&0x3ff)] + exponentTable[u16>>10]
}

// Lookup tables indexed by (parts of) the half-float bit pattern; filled
// once at package initialization.
var mantissaTable [2048]uint32
var exponentTable [64]uint32
var offsetTable [64]uint32

func init() {
	initMantissaTable()
	initExponentTable()
	initOffsetTable()
}

// initMantissaTable fills mantissaTable: entries 1..1023 hold normalized
// subnormal mantissas (via convertMantissa), entries 1024..2047 hold the
// mantissas of normal numbers, shifted into float32 position.
func initMantissaTable() {
	mantissaTable[0] = 0
	for i := uint32(1); i < 1024; i++ {
		mantissaTable[i] = convertMantissa(i)
	}
	for i := uint32(1024); i < 2048; i++ {
		mantissaTable[i] = 0x38000000 + ((i - 1024) << 13)
	}
}

// initExponentTable maps the half-float sign+exponent bits (u16>>10) to
// float32 sign+exponent bits. Indexes 0 and 32 are zero/subnormal,
// 31 and 63 are Inf/NaN; indexes >= 32 carry the sign bit.
func initExponentTable() {
	exponentTable[0] = 0
	exponentTable[31] = 0x47800000
	exponentTable[32] = 0x80000000
	exponentTable[63] = 0xC7800000
	for i := uint32(1); i < 31; i++ {
		exponentTable[i] = i << 23
	}
	for i := uint32(33); i < 63; i++ {
		exponentTable[i] = 0x80000000 + (i-32)<<23
	}
}

// initOffsetTable selects which half of mantissaTable a half-float maps
// to: offset 0 (subnormal entries) only for zero-exponent patterns
// (indexes 0 and 32), offset 1024 (normal entries) for everything else.
//
// BUG FIX: the loops previously ran over 1..30 and 32..63, which left
// offsetTable[31] at 0 (breaking +Inf/NaN) and clobbered offsetTable[32]
// with 1024 (breaking -0 and negative subnormals). They must cover
// 1..31 and 33..63, skipping exactly indexes 0 and 32.
func initOffsetTable() {
	offsetTable[0] = 0
	offsetTable[32] = 0
	for i := uint32(1); i < 32; i++ {
		offsetTable[i] = 1024
	}
	for i := uint32(33); i < 64; i++ {
		offsetTable[i] = 1024
	}
}

// convertMantissa normalizes a subnormal half-float mantissa i (1..1023)
// and returns the corresponding float32 bit pattern.
//
// BUG FIX: the normalization loop condition was inverted
// (`m&0x00800000 != 0`), so it never executed and every subnormal entry
// was wrong; per the algorithm it must loop while the number is NOT yet
// normalized, i.e. while bit 23 is still clear.
func convertMantissa(i uint32) uint32 {
	var m uint32 = i << 13  // zero pad mantissa bits
	var e uint32 = 0        // zero exponent
	for m&0x00800000 == 0 { // while not normalized
		e -= 0x00800000 // decrement exponent (1 << 23)
		m <<= 1         // shift mantissa
	}
	m &= ^uint32(0x00800000) // clear leading 1 bit
	e += 0x38800000          // adjust bias ((127-14)<<23)
	return m | e             // return combined number
}
|
||||||
|
|
||||||
|
type LimitedBufferReader struct {
|
||||||
|
r io.Reader
|
||||||
|
scalarSize int
|
||||||
|
remainingBytes int
|
||||||
|
buf []byte
|
||||||
|
bufIndex int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLimitedBufferReader(
|
||||||
|
r io.Reader,
|
||||||
|
dataSize, scalarSize, bufferSize int,
|
||||||
|
) *LimitedBufferReader {
|
||||||
|
size := bufferSize * scalarSize
|
||||||
|
return &LimitedBufferReader{
|
||||||
|
r: r,
|
||||||
|
scalarSize: scalarSize,
|
||||||
|
remainingBytes: scalarSize * dataSize,
|
||||||
|
buf: make([]byte, size),
|
||||||
|
bufIndex: size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (br *LimitedBufferReader) HasNext() bool {
|
||||||
|
return br.remainingBytes != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (br *LimitedBufferReader) ReadNext() ([]byte, error) {
|
||||||
|
if br.remainingBytes == 0 {
|
||||||
|
return nil, io.EOF
|
||||||
|
}
|
||||||
|
if br.bufIndex == len(br.buf) {
|
||||||
|
br.bufIndex = 0
|
||||||
|
if br.remainingBytes < len(br.buf) {
|
||||||
|
br.buf = br.buf[0:br.remainingBytes]
|
||||||
|
}
|
||||||
|
_, err := br.r.Read(br.buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result := br.buf[br.bufIndex : br.bufIndex+br.scalarSize]
|
||||||
|
br.bufIndex += br.scalarSize
|
||||||
|
br.remainingBytes -= br.scalarSize
|
||||||
|
return result, nil
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user