From 8a08a1d6abf4dbd8efd04bfc1aef4dca096b5eb0 Mon Sep 17 00:00:00 2001 From: Tim Cassidy Date: Fri, 10 Feb 2023 19:30:09 -0800 Subject: [PATCH 1/2] add Progress() method to TextDataIter --- example/char-rnn/main.go | 1 + ts/data.go | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/example/char-rnn/main.go b/example/char-rnn/main.go index 9ed957c..7d187ee 100644 --- a/example/char-rnn/main.go +++ b/example/char-rnn/main.go @@ -121,6 +121,7 @@ func main() { if batchCount%500 == 0 { fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount) } + fmt.Printf("dataIter: progress: %v\n", dataIter.Progress()) } // infinite for-loop sampleStr := sample(data, lstm, linear, device) diff --git a/ts/data.go b/ts/data.go index 493a6d3..16dd97e 100644 --- a/ts/data.go +++ b/ts/data.go @@ -226,6 +226,12 @@ func NewTextData(filename string) (*TextData, error) { }, nil } +func (tdi *TextDataIter) Progress() float32 { + startIndex := (tdi.BatchIndex * tdi.BatchSize) + availableIndices := tdi.IndexesLen + progress := float32(startIndex) / float32(availableIndices) + return progress +} // Labels returns the number of different `character` (rune) used by the dataset. func (td *TextData) Labels() (retVal int64) { return int64(len(td.CharForLabel)) From 568f78930f2e640bfe4eaad2c84adf31b80a7f25 Mon Sep 17 00:00:00 2001 From: sugarme Date: Mon, 13 Feb 2023 12:17:29 +1100 Subject: [PATCH 2/2] pickle added more cases for unpickling *pickle.Dict --- pickle/pickle.go | 13 +++- pickle/serialization.go | 165 ++++++++++++++++++++++++++++++---------- pickle/storage.go | 10 ++- 3 files changed, 144 insertions(+), 44 deletions(-) diff --git a/pickle/pickle.go b/pickle/pickle.go index 4bd5725..4b174a3 100644 --- a/pickle/pickle.go +++ b/pickle/pickle.go @@ -10,6 +10,9 @@ import ( "os" "strconv" "strings" + + "reflect" + "runtime" ) // This file implements Python Pickle Machinery. @@ -552,6 +555,7 @@ func loadBinInt1(up *Unpickler) error { err = fmt.Errorf("loadBinInt1() failed: %w", err) return err } + up.append(int(b)) return nil @@ -1007,10 +1011,9 @@ func loadTuple3(up *Unpickler) error { objects := make([]interface{}, 3) var ( err error - n int = 0 ) for i := 2; i >= 0; i-- { - objects[n], err = up.stackPop() + objects[i], err = up.stackPop() if err != nil { err = fmt.Errorf("loadTuple3() failed: %w", err) } @@ -1251,6 +1254,7 @@ func loadGlobal(up *Unpickler) error { err = fmt.Errorf("loadGlobal() failed: %w", err) return err } + up.append(class) return nil } @@ -1378,6 +1382,7 @@ func loadReduce(up *Unpickler) error { err = fmt.Errorf("loadReduce() failed: %w", err) return err } + callable, callableOk := function.(Callable) if !callableOk { err := fmt.Errorf("REDUCE requires a Callable object: %#v", function) @@ -1935,3 +1940,7 @@ func Loads(s string) (interface{}, error) { return up.Load() } + +func GetFunctionName(i interface{}) string { + return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() +} diff --git a/pickle/serialization.go b/pickle/serialization.go index eb84606..c5d9617 100644 --- a/pickle/serialization.go +++ b/pickle/serialization.go @@ -51,49 +51,135 @@ func Decode(filename string) (map[string]*ts.Tensor, error) { return nil, err } + dtype := reflect.TypeOf(result).String() + // Rebuild tensors from Storage tensors namedTensors := make(map[string]*ts.Tensor) - dictResult, isOrderedDict := result.(*OrderedDict) - if !isOrderedDict { - err := fmt.Errorf("Decode() failed: expected 'OrderedDict' type, got %v\n", reflect.TypeOf(result).String()) + switch dtype { + case "*pickle.Dict": + dictResult := *result.(*Dict) + for _, item := range dictResult { + name := item.Key + itemTyp := reflect.TypeOf(item.Value).String() + switch itemTyp { + case "*pickle.Dict": // Nested *pickle.Dict case + subResult := *item.Value.(*Dict) + for _, subItem := range subResult { + subName := subItem.Key + x, ok := subItem.Value.(*StorageTensor) + if !ok { + log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v. Skip decoding parameter %q ...\n", reflect.TypeOf(subItem.Value).String(), subName) + continue + } + data := x.Source.GetData() + size := x.Size + dtype := x.Source.DType() + device := x.Source.Device() + stride := x.Stride + storageOffset := x.StorageOffset + if reflect.ValueOf(data).Len() == 0 { + log.Printf("INFO: skip weight %q with zero data length.\n", name.(string)) + continue + } + + // TODO. should we just skip them? + if reflect.ValueOf(data).Len() == 1 && len(size) == 0 { + size = []int64{1} + stride = []int64{1} + } + + x1 := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true) + if x.RequiresGrad { + x1.MustRequiresGrad_(x.RequiresGrad) + } + + namedTensors[name.(string)] = x1 + } + + default: + sx, isStorageTensor := item.Value.(*StorageTensor) + + // if !isStorageTensor { + // err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String()) + // return nil, err + // } + if !isStorageTensor { + log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v, with value of %v. Skip decoding parameter %q ...\n", reflect.TypeOf(item.Value).String(), item.Value, name) + continue + } + + data := sx.Source.GetData() + size := sx.Size + dtype := sx.Source.DType() + device := sx.Source.Device() + stride := sx.Stride + storageOffset := sx.StorageOffset + + // log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset) + // log.Printf("data: %v\n", data) + + // Dealing with Pytorch `..._tracked` variables. + if reflect.ValueOf(data).Len() == 0 { + log.Printf("INFO: skip weight %q with zero data length.\n", name.(string)) + continue + } + + // TODO. should we just skip them? + if reflect.ValueOf(data).Len() == 1 && len(size) == 0 { + size = []int64{1} + stride = []int64{1} + } + + x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true) + if sx.RequiresGrad { + x.MustRequiresGrad_(sx.RequiresGrad) + } + + namedTensors[name.(string)] = x + } + } + case "*pickle.OrderedDict": + dictResult := result.(*OrderedDict) + for name, item := range dictResult.Map { + sx, isStorageTensor := item.Value.(*StorageTensor) + if !isStorageTensor { + err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String()) + return nil, err + } + + data := sx.Source.GetData() + size := sx.Size + dtype := sx.Source.DType() + device := sx.Source.Device() + stride := sx.Stride + storageOffset := sx.StorageOffset + + // log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset) + // log.Printf("data: %v\n", data) + + // Dealing with Pytorch `..._tracked` variables. + if reflect.ValueOf(data).Len() == 0 { + log.Printf("INFO: skip weigth %q with zero data length.\n", name.(string)) + continue + } + + // TODO. should we just skip them? + if reflect.ValueOf(data).Len() == 1 && len(size) == 0 { + size = []int64{1} + stride = []int64{1} + } + + x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true) + if sx.RequiresGrad { + x.MustRequiresGrad_(sx.RequiresGrad) + } + + namedTensors[name.(string)] = x + } + default: + err := fmt.Errorf("Decode() failed: expected '*pickle.OrderedDict' or '*pickle.Dict' type, got %v\n", dtype) return nil, err } - for name, item := range dictResult.Map { - sx, isStorageTensor := item.Value.(*StorageTensor) - if !isStorageTensor { - err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String()) - return nil, err - } - - data := sx.Source.GetData() - size := sx.Size - dtype := sx.Source.DType() - device := sx.Source.Device() - stride := sx.Stride - storageOffset := sx.StorageOffset - - // log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset) - // log.Printf("data: %v\n", data) - - // Dealing with Pytorch `..._tracked` variables. - if reflect.ValueOf(data).Len() == 0 { - log.Printf("INFO: skip weigth %q with zero data length.\n", name.(string)) - continue - } - - // TODO. should we just skip them? - if reflect.ValueOf(data).Len() == 1 && len(size) == 0 { - size = []int64{1} - stride = []int64{1} - } - - x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true) - if sx.RequiresGrad { - x.MustRequiresGrad_(sx.RequiresGrad) - } - - namedTensors[name.(string)] = x - } return namedTensors, nil } @@ -161,6 +247,7 @@ func loadZipFile(filename string, newUnpickler func(r io.Reader) Unpickler) (int if !dataTypeOk || !keyOk || !locationOk || !sizeOk { return nil, fmt.Errorf("PersistentLoad: unexpected data types") } + storage, storageExists := loadedStorages[key] if !storageExists { storage, err = loadTensor(dataType, size, location, key, fileRecords) diff --git a/pickle/storage.go b/pickle/storage.go index 0374fbb..ce6aba0 100644 --- a/pickle/storage.go +++ b/pickle/storage.go @@ -567,7 +567,7 @@ func setFromFile(s Storage, r io.Reader) error { } // StorageTensor: -//=============== +// =============== type StorageTensor struct { Source Storage StorageOffset int64 @@ -629,7 +629,9 @@ func (r *RebuildTensorV2) Call(args ...interface{}) (interface{}, error) { storageOffset, storageOffsetOk := args[1].(int) size, sizeOk := args[2].(*Tuple) stride, strideOk := args[3].(*Tuple) + requiresGrad, requiresGradOk := args[4].(bool) + // arg[5] "backward hooks" is unused if !storageOk || !storageOffsetOk || !sizeOk || !strideOk || !requiresGradOk { @@ -681,7 +683,9 @@ func tupleToInt64Slice(tuple *Tuple) ([]int64, error) { for i := 0; i < length; i++ { value, ok := tuple.Get(i).(int) if !ok { - return nil, fmt.Errorf("tuple of ints expected: %#v", tuple) + // return nil, fmt.Errorf("tuple of ints expected. Got %#v", tuple) + fmt.Printf("WARNING: tuple of ints expected. Got %#v\n", tuple) + continue } slice[i] = int64(value) } @@ -689,7 +693,7 @@ func tupleToInt64Slice(tuple *Tuple) ([]int64, error) { } // Rebuild Sparse Tensor: -//======================= +// ======================= // ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L178 type RebuildSparseTensor struct{}