added more pickle types

This commit is contained in:
sugarme 2022-02-24 15:25:15 +11:00
parent 7aa4e50199
commit 36b61bf719
5 changed files with 120 additions and 22 deletions

View File

@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
- Fixed wrong `cacheDir` and switch off logging.
- Added more pickle classes to handle unpickling
## [Nofix]
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.

View File

@ -1,36 +1,32 @@
package main
import (
"fmt"
"log"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
"github.com/sugarme/gotch/pickle"
"github.com/sugarme/gotch/vision"
)
func main() {
device := gotch.CPU
vs := nn.NewVarStore(device)
net := vision.VGG16(vs.Root(), 1000)
// modelName := "vgg16"
// modelName := "mobilenet_v2"
// modelName := "resnet18"
// modelName := "alexnet"
// modelName := "squeezenet1_1"
// modelName := "inception_v3_google"
modelName := "efficientnet_b4"
modelName := "vgg16"
modelUrl, ok := gotch.ModelUrls[modelName]
url, ok := gotch.ModelUrls[modelName]
if !ok {
log.Fatal("model name %q not found.", modelName)
log.Fatalf("Unsupported model name %q\n", modelName)
}
modelFile, err := gotch.CachedPath(url)
if err != nil {
panic(err)
}
modelFile, err := gotch.CachedPath(modelUrl)
err = pickle.LoadInfo(modelFile)
if err != nil {
log.Fatal(err)
}
err = pickle.LoadAll(vs, modelFile)
if err != nil {
log.Fatal(err)
}
fmt.Printf("%v\n", net)
vs.Summary()
}

View File

@ -12,12 +12,10 @@ var (
)
func init() {
// default path: {$HOME}/.cache/gotch
homeDir := os.Getenv("HOME")
CacheDir = fmt.Sprintf("%s/.cache/gotch", homeDir)
CacheDir = fmt.Sprintf("%s/.cache/gotch", homeDir) // default dir: "{$HOME}/.cache/gotch"
initEnv()
// log.Printf("INFO: CacheDir=%q\n", CacheDir)
}

View File

@ -71,7 +71,15 @@ func Decode(filename string) (map[string]*ts.Tensor, error) {
stride := sx.Stride
storageOffset := sx.StorageOffset
// fmt.Printf("%q - shape: %v - stride: %v - storageOffset: %v\n", sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
// log.Printf("data: %v\n", data)
// Dealing with Pytorch `..._tracked` variables.
// TODO. should we just skip them?
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
size = []int64{1}
stride = []int64{1}
}
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
if sx.RequiresGrad {
@ -373,6 +381,19 @@ func makePickleFindClass(fallback func(module, name string) (interface{}, error)
return &RebuildTensor{}, nil
case "torch._utils._rebuild_tensor_v2":
return &RebuildTensorV2{}, nil
case "torch._utils._rebuild_parameter":
return &RebuildParameter{}, nil
case "torch._utils._sparse_tensor":
return &RebuildSparseTensor{}, nil
case "torch._utils._rebuild_sparse_csr_tensor":
return &RebuildSparseCsrTensor{}, nil
case "torch._utils._rebuild_device_tensor_from_numpy":
return &RebuildDeviceTensorFromNumpy{}, nil
case "torch._utils._rebuild_meta_tensor_no_storage":
return &RebuildMetaTensorNoStorage{}, nil
case "torch._utils._rebuild_qtensor":
return &RebuildQtensor{}, nil
case "torch.FloatStorage":
return &FloatStorageClass{}, nil
case "torch.HalfStorage":

View File

@ -653,6 +653,28 @@ func (r *RebuildTensorV2) Call(args ...interface{}) (interface{}, error) {
return tensor, nil
}
// Rebuild Parameter:
// ==================
// RebuildTensor represents a struct to rebuild tensor back from pickle object.
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L240
type RebuildParameter struct{}
var _ Callable = &RebuildParameter{}
func (r *RebuildParameter) Call(args ...interface{}) (interface{}, error) {
if len(args) != 3 { // data(*StorageTensor), requires_grad, backward_hooks
return nil, fmt.Errorf("RebuildParameter unexpected 3 args, got %d: %#v", len(args), args)
}
tensor, ok := args[0].(*StorageTensor)
if !ok {
err := fmt.Errorf("RebuildParameter.Call() failed: unexpected arg: %#v\n", args)
return nil, err
}
return tensor, nil
}
func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
length := tuple.Len()
slice := make([]int64, length)
@ -665,3 +687,63 @@ func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
}
return slice, nil
}
// Rebuild Sparse Tensor:
//=======================
// ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L178
type RebuildSparseTensor struct{}
var _ Callable = &RebuildSparseTensor{}
func (r *RebuildSparseTensor) Call(args ...interface{}) (interface{}, error) {
// TODO.
panic("RebuildSparseTensor.Call(): NotImplementedError")
}
// Rebuild Sparse CSR Tensor:
// ==========================
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L187
type RebuildSparseCsrTensor struct{}
var _ Callable = &RebuildSparseCsrTensor{}
func (r *RebuildSparseCsrTensor) Call(args ...interface{}) (interface{}, error) {
// TODO.
panic("RebuildSparseCsrTensor.Call(): NotImplementedError")
}
// Rebuild Device Tensor From Numpy:
// =================================
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L197
type RebuildDeviceTensorFromNumpy struct{}
var _ Callable = &RebuildDeviceTensorFromNumpy{}
func (r *RebuildDeviceTensorFromNumpy) Call(args ...interface{}) (interface{}, error) {
// TODO.
panic("RebuildDeviceTensorFromNumpy.Call(): NotImplementedError")
}
// Rebuild Meta Tensor No Storage:
// ===============================
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L208
type RebuildMetaTensorNoStorage struct{}
var _ Callable = &RebuildMetaTensorNoStorage{}
func (r *RebuildMetaTensorNoStorage) Call(args ...interface{}) (interface{}, error) {
// TODO.
panic("RebuildMetaTensorNoStorage.Call(): NotImplementedError")
}
// Rebuild QTensor:
// ================
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L214
type RebuildQtensor struct{}
var _ Callable = &RebuildQtensor{}
func (r *RebuildQtensor) Call(args ...interface{}) (interface{}, error) {
// TODO.
panic("RebuildQtensor.Call(): NotImplementedError")
}