package pickle
// References:
// https://docs.python.org/3/library/pickle.html
// https://docs.python.org/3/library/pickletools.html
// https://github.com/python/cpython/blob/main/Lib/pickle.py (reference implementation)
// https://pytorch.org/tutorials/beginner/saving_loading_models.html
// https://github.com/pytorch/pytorch/blob/master/torch/serialization.py
import (
"archive/tar"
"archive/zip"
"errors"
"fmt"
"io"
"log"
"math/big"
"os"
"path"
"reflect"
"sort"
"git.andr3h3nriqu3s.com/andr3/gotch"
"git.andr3h3nriqu3s.com/andr3/gotch/nn"
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
)
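// hexMagicNumber and protocolVersion mirror MAGIC_NUMBER and
// PROTOCOL_VERSION in pytorch/torch/serialization.py; both are written at
// the head of legacy (non-zip) checkpoint files.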
const hexMagicNumber = "1950a86a20f9469cfc6c"
const protocolVersion = 1001
var ErrInvalidMagicNumber = errors.New("invalid pytorch magic number")
var ErrInvalidProtocolVersion = errors.New("invalid pytorch protocol version")
// Encode encodes a model using the pickling machinery.
// The pickled output can be loaded with Python PyTorch as `torch.load("pytorch_model.bin")`.
//
// TODO: implement the pickling part so that models can be exported and loaded with Python PyTorch.
// See https://github.com/python/cpython/blob/b0de6299a840a397d4fe3e6c98159d9f258d3295/Lib/pickle.py#L407
func Encode(model ts.Module, outputFile string) error {
panic("NotImplementedError")
}
// Decode decodes pickled data created by `torch.save()` with Python PyTorch
// and rebuilds named tensor weights.
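//
// Example (a minimal sketch; the file name is illustrative):
//
//	weights, err := pickle.Decode("pytorch_model.bin")
//	if err != nil {
//		log.Fatal(err)
//	}
//	for name, x := range weights {
//		fmt.Printf("%s: %v\n", name, x.MustSize())
//	}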
func Decode(filename string) (map[string]*ts.Tensor, error) {
newUnpickler := func(r io.Reader) Unpickler {
return NewUnpickler(r)
}
result, err := LoadWithUnpickler(filename, newUnpickler)
if err != nil {
err := fmt.Errorf("Decode() failed: %w", err)
return nil, err
}
typename := reflect.TypeOf(result).String()
// Rebuild tensors from Storage tensors
namedTensors := make(map[string]*ts.Tensor)
switch typename {
case "*pickle.Dict":
dictResult := *result.(*Dict)
for _, item := range dictResult {
name := item.Key
sx, isStorageTensor := item.Value.(*StorageTensor)
if !isStorageTensor {
err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
return nil, err
}
data := sx.Source.GetData()
size := sx.Size
dtype := sx.Source.DType()
device := sx.Source.Device()
stride := sx.Stride
storageOffset := sx.StorageOffset
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
// log.Printf("data: %v\n", data)
// Dealing with PyTorch `..._tracked` variables.
if reflect.ValueOf(data).Len() == 0 {
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
continue
}
// TODO. should we just skip them?
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
size = []int64{1}
stride = []int64{1}
}
x := ts.MustOfSlice(data, ts.WithDType(dtype)).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
if sx.RequiresGrad {
x.MustRequiresGrad_(sx.RequiresGrad)
}
namedTensors[name.(string)] = x
}
case "*pickle.OrderedDict":
dictResult := result.(*OrderedDict)
for name, item := range dictResult.Map {
sx, isStorageTensor := item.Value.(*StorageTensor)
if !isStorageTensor {
err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
return nil, err
}
data := sx.Source.GetData()
size := sx.Size
dtype := sx.Source.DType()
device := sx.Source.Device()
stride := sx.Stride
storageOffset := sx.StorageOffset
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
// log.Printf("data: %v\n", data)
// Dealing with PyTorch `..._tracked` variables.
if reflect.ValueOf(data).Len() == 0 {
log.Printf("INFO: skip weigth %q with zero data length.\n", name.(string))
continue
}
// TODO. should we just skip them?
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
size = []int64{1}
stride = []int64{1}
}
x := ts.MustOfSlice(data, ts.WithDType(dtype)).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
if sx.RequiresGrad {
x.MustRequiresGrad_(sx.RequiresGrad)
}
namedTensors[name.(string)] = x
}
default:
err := fmt.Errorf("Decode() failed: expected '*pickle.OrderedDict' or '*pickle.Dict' type, got %v\n", dtype)
return nil, err
}
return namedTensors, nil
}
// LoadWithUnpickler is like Load, but it accepts a newUnpickler function which
// is used to create new customized pickle.Unpickler instances.
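//
// Example (a minimal sketch; here the customized unpickler is just the default one):
//
//	newUnpickler := func(r io.Reader) pickle.Unpickler {
//		return pickle.NewUnpickler(r)
//	}
//	result, err := pickle.LoadWithUnpickler("pytorch_model.bin", newUnpickler)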
func LoadWithUnpickler(filename string, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) {
if !isZipFile(filename) {
return loadLegacyFile(filename, newUnpickler)
}
return loadZipFile(filename, newUnpickler)
}
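// loadZipFile loads a zip-based checkpoint, the format produced by
// `torch.save()` since PyTorch 1.6. The archive holds the pickled object
// graph in data.pkl plus one raw record per tensor storage, keyed by
// storage id.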
func loadZipFile(filename string, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) {
// Open a zip archive for reading.
r, err := zip.OpenReader(filename)
if err != nil {
return nil, err
}
defer r.Close()
fileRecords := make(map[string]*zip.File, len(r.File))
for _, f := range r.File {
_, recordName := path.Split(f.Name)
fileRecords[recordName] = f
}
if _, isTorchScript := fileRecords["constants.pkl"]; isTorchScript {
return nil, fmt.Errorf("TorchScript is not supported")
}
dataFile, hasDataFile := fileRecords["data.pkl"]
if !hasDataFile {
return nil, fmt.Errorf("data.pkl not found in zip file")
}
df, err := dataFile.Open()
if err != nil {
return nil, err
}
defer df.Close()
loadedStorages := make(map[string]Storage)
u := newUnpickler(df)
u.FindClass = makePickleFindClass(u.FindClass)
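// Persistent ids written by the zip format are tuples of the form
// ('storage', storage_type, key, location, size); the raw bytes for each
// storage live in the archive record named by key.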
u.PersistentLoad = func(savedId interface{}) (interface{}, error) {
tuple, tupleOk := savedId.(*Tuple)
if !tupleOk || tuple.Len() == 0 {
return nil, fmt.Errorf("PersistentLoad: non-empty tuple expected, got %#v", savedId)
}
typename, typenameOk := tuple.Get(0).(string)
if !typenameOk {
return nil, fmt.Errorf("PersistentLoad: cannot get typename")
}
if typename != "storage" {
return nil, fmt.Errorf("unknown typename for PersistentLoad, expected 'storage' but got '%s'", typename)
}
if tuple.Len() < 5 {
return nil, fmt.Errorf("PersistentLoad: unexpected storage data length")
}
dataType, dataTypeOk := tuple.Get(1).(StorageClass)
key, keyOk := tuple.Get(2).(string)
location, locationOk := tuple.Get(3).(string)
size, sizeOk := tuple.Get(4).(int)
if !dataTypeOk || !keyOk || !locationOk || !sizeOk {
return nil, fmt.Errorf("PersistentLoad: unexpected data types")
}
storage, storageExists := loadedStorages[key]
if !storageExists {
storage, err = loadTensor(dataType, size, location, key, fileRecords)
if err != nil {
return nil, err
}
loadedStorages[key] = storage
}
return storage, nil
}
return u.Load()
}
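// loadTensor reads the raw bytes of the zip record named key and wraps
// them in a Storage of the given data type.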
func loadTensor(
dataType StorageClass,
size int,
location, key string,
zipFileRecords map[string]*zip.File,
) (Storage, error) {
file, fileOk := zipFileRecords[key]
if !fileOk {
return nil, fmt.Errorf("cannot find zip record '%s'", key)
}
f, err := file.Open()
if err != nil {
return nil, err
}
defer f.Close()
storage := dataType.New(size, location)
err = storage.SetFromFileWithSize(f, size)
return storage, err
}
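// loadLegacyFile loads a pre-1.6 checkpoint. It first probes for the
// (unimplemented) legacy tar container and, if the file is not a tar
// archive, falls back to the flat legacy format handled by
// loadLegacyNoTar.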
func loadLegacyFile(filename string, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) {
f, err := os.Open(filename)
if err != nil {
return nil, err
}
defer f.Close()
tr := tar.NewReader(f)
for {
_, err := tr.Next()
switch err {
case nil:
// TODO: ...
panic("legacy load from tar not implemented")
case io.EOF:
// End of archive with no entries. A bare `break` here would only exit
// the switch, not the loop, so return explicitly.
return nil, fmt.Errorf("unexpected end of tar archive")
case tar.ErrHeader, io.ErrUnexpectedEOF:
_, err = f.Seek(0, io.SeekStart)
if err != nil {
return nil, err
}
return loadLegacyNoTar(f, newUnpickler)
default:
return nil, err
}
}
}
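// loadLegacyNoTar loads the flat legacy format: a magic number, a
// protocol version, a system-info dict, the pickled object graph, and
// finally the raw storage data in the order given by a trailing pickled
// list of storage keys.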
func loadLegacyNoTar(f *os.File, newUnpickler func(r io.Reader) Unpickler) (interface{}, error) {
if err := readAndCheckMagicNumber(f); err != nil {
return nil, err
}
if err := readAndCheckProtocolVersion(f); err != nil {
return nil, err
}
if _, err := unpickle(f); err != nil { // sys info
return nil, err
}
deserializedObjects := make(map[string]Storage)
u := newUnpickler(f)
u.FindClass = makePickleFindClass(u.FindClass)
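// Legacy persistent ids are either
// ('storage', storage_type, root_key, location, size, view_metadata) or
// ('module', module, source_file, source); for modules only the module
// object itself is recovered here.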
u.PersistentLoad = func(savedId interface{}) (interface{}, error) {
tuple, tupleOk := savedId.(*Tuple)
if !tupleOk || tuple.Len() == 0 {
return nil, fmt.Errorf("PersistentLoad: non-empty tuple expected, got %#v", savedId)
}
typename, typenameOk := tuple.Get(0).(string)
if !typenameOk {
return nil, fmt.Errorf("PersistentLoad: cannot get typename")
}
// fmt.Printf("typename: %s\n", typename)
switch typename {
case "storage":
if tuple.Len() < 6 {
return nil, fmt.Errorf(
"PersistentLoad: unexpected storage data length")
}
dataType, dataTypeOk := tuple.Get(1).(StorageClass)
rootKey, rootKeyOk := tuple.Get(2).(string)
location, locationOk := tuple.Get(3).(string)
size, sizeOk := tuple.Get(4).(int)
viewMetadata := tuple.Get(5)
if !dataTypeOk || !rootKeyOk || !locationOk || !sizeOk {
return nil, fmt.Errorf("PersistentLoad: unexpected data types")
}
// fmt.Printf("dtype: %v - rootKey: %v - device: %v - size %v - viewMetaData: %v\n", reflect.TypeOf(dataType), rootKey, location, size, viewMetadata)
storage, storageExists := deserializedObjects[rootKey]
if !storageExists {
storage = dataType.New(size, location)
deserializedObjects[rootKey] = storage
}
switch vm := viewMetadata.(type) {
case nil:
return storage, nil
case []interface{}:
if len(vm) != 3 {
return nil, fmt.Errorf(
"PersistentLoad: unexpected view metadata length")
}
panic("viewMetadata not implemented")
// TODO: ...
// view_key, offset, view_size = view_metadata
// if view_key not in deserialized_objects:
// deserialized_objects[view_key] = storage[offset:offset + view_size]
// return deserialized_objects[view_key]
default:
return nil, fmt.Errorf("PersistentLoad: unexpected view metadata type")
}
case "module":
if tuple.Len() < 2 {
return nil, fmt.Errorf("PersistentLoad: unexpected module data length")
}
return tuple.Get(1), nil
default:
return nil, fmt.Errorf("Unexpected saved ID type: %s", typename)
}
}
result, err := u.Load()
if err != nil {
return nil, err
}
rawStorageKeys, err := unpickle(f)
if err != nil {
return nil, err
}
storageKeys, err := makeStorageKeys(rawStorageKeys)
if err != nil {
return nil, err
}
for _, key := range storageKeys {
storageObj, ok := deserializedObjects[key]
if !ok {
return nil, fmt.Errorf("storage object not found for key '%s'", key)
}
err = storageObj.SetFromFile(f)
if err != nil {
return nil, err
}
}
return result, nil
}
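// makeStorageKeys converts the pickled list that trails the object graph
// into a slice of storage-key strings.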
func makeStorageKeys(obj interface{}) ([]string, error) {
list, ok := obj.(*List)
if !ok {
return nil, fmt.Errorf("invalid storage keys data")
}
keys := make([]string, len(*list))
for i, rawKey := range *list {
key, keyOk := rawKey.(string)
if !keyOk {
return nil, fmt.Errorf("invalid storage key")
}
keys[i] = key
}
return keys, nil
}
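// readAndCheckMagicNumber unpickles the first value in a legacy file and
// verifies it against PyTorch's magic number.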
func readAndCheckMagicNumber(r io.Reader) error {
obj, err := unpickle(r)
if err != nil {
return err
}
if n, ok := obj.(*big.Int); !ok || n.Text(16) != hexMagicNumber {
return ErrInvalidMagicNumber
}
return nil
}
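// readAndCheckProtocolVersion unpickles the second value in a legacy file
// and verifies it against PyTorch's serialization protocol version.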
func readAndCheckProtocolVersion(r io.Reader) error {
obj, err := unpickle(r)
if err != nil {
return err
}
if n, ok := obj.(int); !ok || n != protocolVersion {
return ErrInvalidProtocolVersion
}
return nil
}
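// unpickle reads a single pickled value from r using a default Unpickler.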
func unpickle(r io.Reader) (interface{}, error) {
u := NewUnpickler(r)
return u.Load()
}
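// isZipFile reports whether filename is a zip archive, i.e. a checkpoint
// saved with `_use_new_zipfile_serialization=True` (the default since
// PyTorch 1.6).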
func isZipFile(filename string) bool {
r, err := zip.OpenReader(filename)
if err != nil {
return false
}
r.Close()
return true
}
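// makePickleFindClass maps torch class names encountered during
// unpickling to their Go counterparts, deferring to fallback (if any)
// for unknown classes.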
func makePickleFindClass(fallback func(module, name string) (interface{}, error)) func(module, name string) (interface{}, error) {
return func(module, name string) (interface{}, error) {
switch module + "." + name {
case "torch._utils._rebuild_tensor":
return &RebuildTensor{}, nil
case "torch._utils._rebuild_tensor_v2":
return &RebuildTensorV2{}, nil
case "torch._utils._rebuild_parameter":
return &RebuildParameter{}, nil
case "torch._utils._sparse_tensor":
return &RebuildSparseTensor{}, nil
case "torch._utils._rebuild_sparse_csr_tensor":
return &RebuildSparseCsrTensor{}, nil
case "torch._utils._rebuild_device_tensor_from_numpy":
return &RebuildDeviceTensorFromNumpy{}, nil
case "torch._utils._rebuild_meta_tensor_no_storage":
return &RebuildMetaTensorNoStorage{}, nil
case "torch._utils._rebuild_qtensor":
return &RebuildQtensor{}, nil
case "torch.FloatStorage":
return &FloatStorageClass{}, nil
case "torch.HalfStorage":
return &HalfStorageClass{}, nil
case "torch.DoubleStorage":
return &DoubleStorageClass{}, nil
case "torch.CharStorage":
return &CharStorageClass{}, nil
case "torch.ShortStorage":
return &ShortStorageClass{}, nil
case "torch.IntStorage":
return &IntStorageClass{}, nil
case "torch.LongStorage":
return &LongStorageClass{}, nil
case "torch.ByteStorage":
return &ByteStorageClass{}, nil
case "torch.BoolStorage":
return &BoolStorageClass{}, nil
case "torch.nn.backends.thnn._get_thnn_function_backend":
// this is for historical pickle deserialization; it is not used otherwise
return getThnnFunctionBackend{}, nil
default:
if fallback == nil {
return nil, fmt.Errorf("class not found: %s %s", module, name)
}
return fallback(module, name)
}
}
}
// LoadAll finds and loads all weights from the varstore.
// It returns an error if any variable in the varstore cannot be found in the loaded pretrained model.
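//
// Example (a minimal sketch; the varstore setup and file name are illustrative):
//
//	vs := nn.NewVarStore(gotch.CPU)
//	// ... build the model on vs.Root() so variable names match the file ...
//	if err := pickle.LoadAll(vs, "pytorch_model.bin"); err != nil {
//		log.Fatal(err)
//	}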
func LoadAll(vs *nn.VarStore, modelFile string) error {
weights, err := Decode(modelFile)
if err != nil {
err = fmt.Errorf("LoadAll() failed: %w", err)
return err
}
var namedTensors []ts.NamedTensor
for n, x := range weights {
namedTs := ts.NamedTensor{
Name: n,
Tensor: x,
}
namedTensors = append(namedTensors, namedTs)
}
err = vs.LoadWeights(namedTensors)
if err != nil {
return err
}
for _, x := range weights {
x.MustDrop()
}
return nil
}
// LoadPartial finds and loads weights for the varstore.
// It returns the names of varstore variables for which no weight was found in the pretrained model.
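//
// Example (a minimal sketch):
//
//	missing, err := pickle.LoadPartial(vs, "pytorch_model.bin")
//	if err != nil {
//		log.Fatal(err)
//	}
//	for _, name := range missing {
//		log.Printf("not found in pretrained model: %s", name)
//	}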
func LoadPartial(vs *nn.VarStore, modelFile string) ([]string, error) {
weights, err := Decode(modelFile)
if err != nil {
err = fmt.Errorf("LoadPartial() failed: %w", err)
return nil, err
}
var namedTensors []ts.NamedTensor
for n, x := range weights {
namedTs := ts.NamedTensor{
Name: n,
Tensor: x,
}
namedTensors = append(namedTensors, namedTs)
}
var missingVariables []string
missingVariables, err = vs.LoadWeightsPartial(namedTensors)
if err != nil {
return nil, err
}
for _, x := range weights {
x.MustDrop()
}
return missingVariables, nil
}
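// ModelInfor holds summary information about a pretrained model file:
// the shape of each named weight and the dtype of its tensors.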
type ModelInfor struct {
weights map[string][]int64
dtype gotch.DType
}
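// NewModelInfor creates a new ModelInfor.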
func NewModelInfor(weights map[string][]int64, dtype gotch.DType) *ModelInfor {
return &ModelInfor{
weights: weights,
dtype: dtype,
}
}
func (m *ModelInfor) String() string {
var summary string
layers := make([]string, 0, len(m.weights))
for tsName := range m.weights {
layers = append(layers, tsName)
}
sort.Strings(layers)
for _, l := range layers {
summary += fmt.Sprintf("%s - %+v\n", l, m.weights[l])
}
summary += fmt.Sprintf("Num of variables: %v\n", len(m.weights))
summary += fmt.Sprintf("Tensor DType: %v\n", m.dtype)
return summary
}
func (m *ModelInfor) DType() gotch.DType {
return m.dtype
}
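// Parameters returns the number of weight tensors in the model (not the
// total element count).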
func (m *ModelInfor) Parameters() int {
return len(m.weights)
}
// LoadModelInfo loads pretrained weights and reports the name and shape of each weight.
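//
// Example (a minimal sketch; the file name is illustrative):
//
//	info, err := pickle.LoadModelInfo("pytorch_model.bin")
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(info)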
func LoadModelInfo(modelFile string) (*ModelInfor, error) {
weights, err := Decode(modelFile)
if err != nil {
err = fmt.Errorf("LoadInfo() failed: %w", err)
return nil, err
}
w := make(map[string][]int64)
var dtype gotch.DType
isFirst := true
for n, x := range weights {
w[n] = x.MustSize()
if isFirst {
dtype = x.DType()
isFirst = false
}
}
m := NewModelInfor(w, dtype)
ts.CleanUp()
return m, nil
}