added more pickle types
This commit is contained in:
parent
7aa4e50199
commit
36b61bf719
|
@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
## [Unreleased]
|
||||
- Fixed wrong `cacheDir` and switched off logging.
|
||||
- Added more pickle classes to handle unpickling
|
||||
|
||||
## [Nofix]
|
||||
- ctype `long` caused a compile error on macOS as noted in [#44]. Does not occur on a Linux box.
|
||||
|
|
|
@ -1,36 +1,32 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
"github.com/sugarme/gotch/pickle"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
func main() {
|
||||
device := gotch.CPU
|
||||
vs := nn.NewVarStore(device)
|
||||
net := vision.VGG16(vs.Root(), 1000)
|
||||
// modelName := "vgg16"
|
||||
// modelName := "mobilenet_v2"
|
||||
// modelName := "resnet18"
|
||||
// modelName := "alexnet"
|
||||
// modelName := "squeezenet1_1"
|
||||
// modelName := "inception_v3_google"
|
||||
modelName := "efficientnet_b4"
|
||||
|
||||
modelName := "vgg16"
|
||||
modelUrl, ok := gotch.ModelUrls[modelName]
|
||||
url, ok := gotch.ModelUrls[modelName]
|
||||
if !ok {
|
||||
log.Fatal("model name %q not found.", modelName)
|
||||
log.Fatalf("Unsupported model name %q\n", modelName)
|
||||
}
|
||||
modelFile, err := gotch.CachedPath(url)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
modelFile, err := gotch.CachedPath(modelUrl)
|
||||
err = pickle.LoadInfo(modelFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = pickle.LoadAll(vs, modelFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%v\n", net)
|
||||
vs.Summary()
|
||||
}
|
||||
|
|
4
init.go
4
init.go
|
@ -12,12 +12,10 @@ var (
|
|||
)
|
||||
|
||||
func init() {
|
||||
// default path: {$HOME}/.cache/gotch
|
||||
homeDir := os.Getenv("HOME")
|
||||
CacheDir = fmt.Sprintf("%s/.cache/gotch", homeDir)
|
||||
CacheDir = fmt.Sprintf("%s/.cache/gotch", homeDir) // default dir: "{$HOME}/.cache/gotch"
|
||||
|
||||
initEnv()
|
||||
|
||||
// log.Printf("INFO: CacheDir=%q\n", CacheDir)
|
||||
}
|
||||
|
||||
|
|
|
@ -71,7 +71,15 @@ func Decode(filename string) (map[string]*ts.Tensor, error) {
|
|||
stride := sx.Stride
|
||||
storageOffset := sx.StorageOffset
|
||||
|
||||
// fmt.Printf("%q - shape: %v - stride: %v - storageOffset: %v\n", sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
|
||||
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
|
||||
// log.Printf("data: %v\n", data)
|
||||
|
||||
// Dealing with Pytorch `..._tracked` variables.
|
||||
// TODO. should we just skip them?
|
||||
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
|
||||
size = []int64{1}
|
||||
stride = []int64{1}
|
||||
}
|
||||
|
||||
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
||||
if sx.RequiresGrad {
|
||||
|
@ -373,6 +381,19 @@ func makePickleFindClass(fallback func(module, name string) (interface{}, error)
|
|||
return &RebuildTensor{}, nil
|
||||
case "torch._utils._rebuild_tensor_v2":
|
||||
return &RebuildTensorV2{}, nil
|
||||
case "torch._utils._rebuild_parameter":
|
||||
return &RebuildParameter{}, nil
|
||||
case "torch._utils._sparse_tensor":
|
||||
return &RebuildSparseTensor{}, nil
|
||||
case "torch._utils._rebuild_sparse_csr_tensor":
|
||||
return &RebuildSparseCsrTensor{}, nil
|
||||
case "torch._utils._rebuild_device_tensor_from_numpy":
|
||||
return &RebuildDeviceTensorFromNumpy{}, nil
|
||||
case "torch._utils._rebuild_meta_tensor_no_storage":
|
||||
return &RebuildMetaTensorNoStorage{}, nil
|
||||
case "torch._utils._rebuild_qtensor":
|
||||
return &RebuildQtensor{}, nil
|
||||
|
||||
case "torch.FloatStorage":
|
||||
return &FloatStorageClass{}, nil
|
||||
case "torch.HalfStorage":
|
||||
|
|
|
@ -653,6 +653,28 @@ func (r *RebuildTensorV2) Call(args ...interface{}) (interface{}, error) {
|
|||
return tensor, nil
|
||||
}
|
||||
|
||||
// Rebuild Parameter:
|
||||
// ==================
|
||||
// RebuildTensor represents a struct to rebuild tensor back from pickle object.
|
||||
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L240
|
||||
type RebuildParameter struct{}
|
||||
|
||||
var _ Callable = &RebuildParameter{}
|
||||
|
||||
func (r *RebuildParameter) Call(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 3 { // data(*StorageTensor), requires_grad, backward_hooks
|
||||
return nil, fmt.Errorf("RebuildParameter unexpected 3 args, got %d: %#v", len(args), args)
|
||||
}
|
||||
|
||||
tensor, ok := args[0].(*StorageTensor)
|
||||
if !ok {
|
||||
err := fmt.Errorf("RebuildParameter.Call() failed: unexpected arg: %#v\n", args)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tensor, nil
|
||||
}
|
||||
|
||||
func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
|
||||
length := tuple.Len()
|
||||
slice := make([]int64, length)
|
||||
|
@ -665,3 +687,63 @@ func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
|
|||
}
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
// Rebuild Sparse Tensor:
|
||||
//=======================
|
||||
// ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L178
|
||||
type RebuildSparseTensor struct{}
|
||||
|
||||
var _ Callable = &RebuildSparseTensor{}
|
||||
|
||||
func (r *RebuildSparseTensor) Call(args ...interface{}) (interface{}, error) {
|
||||
// TODO.
|
||||
panic("RebuildSparseTensor.Call(): NotImplementedError")
|
||||
}
|
||||
|
||||
// Rebuild Sparse CSR Tensor:
|
||||
// ==========================
|
||||
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L187
|
||||
type RebuildSparseCsrTensor struct{}
|
||||
|
||||
var _ Callable = &RebuildSparseCsrTensor{}
|
||||
|
||||
func (r *RebuildSparseCsrTensor) Call(args ...interface{}) (interface{}, error) {
|
||||
// TODO.
|
||||
panic("RebuildSparseCsrTensor.Call(): NotImplementedError")
|
||||
}
|
||||
|
||||
// Rebuild Device Tensor From Numpy:
|
||||
// =================================
|
||||
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L197
|
||||
type RebuildDeviceTensorFromNumpy struct{}
|
||||
|
||||
var _ Callable = &RebuildDeviceTensorFromNumpy{}
|
||||
|
||||
func (r *RebuildDeviceTensorFromNumpy) Call(args ...interface{}) (interface{}, error) {
|
||||
// TODO.
|
||||
panic("RebuildDeviceTensorFromNumpy.Call(): NotImplementedError")
|
||||
}
|
||||
|
||||
// Rebuild Meta Tensor No Storage:
|
||||
// ===============================
|
||||
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L208
|
||||
type RebuildMetaTensorNoStorage struct{}
|
||||
|
||||
var _ Callable = &RebuildMetaTensorNoStorage{}
|
||||
|
||||
func (r *RebuildMetaTensorNoStorage) Call(args ...interface{}) (interface{}, error) {
|
||||
// TODO.
|
||||
panic("RebuildMetaTensorNoStorage.Call(): NotImplementedError")
|
||||
}
|
||||
|
||||
// Rebuild QTensor:
|
||||
// ================
|
||||
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L214
|
||||
type RebuildQtensor struct{}
|
||||
|
||||
var _ Callable = &RebuildQtensor{}
|
||||
|
||||
func (r *RebuildQtensor) Call(args ...interface{}) (interface{}, error) {
|
||||
// TODO.
|
||||
panic("RebuildQtensor.Call(): NotImplementedError")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user