Merge branch 'master' of ssh://github.com/sugarme/gotch into mem

This commit is contained in:
sugarme 2023-07-04 23:22:04 +10:00
commit 4aad97b216
5 changed files with 151 additions and 44 deletions

View File

@ -121,6 +121,7 @@ func main() {
if batchCount%500 == 0 {
fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
}
fmt.Printf("dataIter: progress: %v\n", dataIter.Progress())
} // infinite for-loop
sampleStr := sample(data, lstm, linear, device)

View File

@ -10,6 +10,9 @@ import (
"os"
"strconv"
"strings"
"reflect"
"runtime"
)
// This file implements Python Pickle Machinery.
@ -552,6 +555,7 @@ func loadBinInt1(up *Unpickler) error {
err = fmt.Errorf("loadBinInt1() failed: %w", err)
return err
}
up.append(int(b))
return nil
@ -1007,10 +1011,9 @@ func loadTuple3(up *Unpickler) error {
objects := make([]interface{}, 3)
var (
err error
n int = 0
)
for i := 2; i >= 0; i-- {
objects[n], err = up.stackPop()
objects[i], err = up.stackPop()
if err != nil {
err = fmt.Errorf("loadTuple3() failed: %w", err)
}
@ -1251,6 +1254,7 @@ func loadGlobal(up *Unpickler) error {
err = fmt.Errorf("loadGlobal() failed: %w", err)
return err
}
up.append(class)
return nil
}
@ -1378,6 +1382,7 @@ func loadReduce(up *Unpickler) error {
err = fmt.Errorf("loadReduce() failed: %w", err)
return err
}
callable, callableOk := function.(Callable)
if !callableOk {
err := fmt.Errorf("REDUCE requires a Callable object: %#v", function)
@ -1935,3 +1940,7 @@ func Loads(s string) (interface{}, error) {
return up.Load()
}
// GetFunctionName returns the fully-qualified name (package path plus
// identifier) of the function value held by i, as reported by the runtime.
// NOTE(review): i must hold a func value — reflect.Value.Pointer panics
// otherwise; callers are expected to guarantee this.
func GetFunctionName(i interface{}) string {
	pc := reflect.ValueOf(i).Pointer()
	fn := runtime.FuncForPC(pc)
	return fn.Name()
}

View File

@ -51,49 +51,135 @@ func Decode(filename string) (map[string]*ts.Tensor, error) {
return nil, err
}
dtype := reflect.TypeOf(result).String()
// Rebuild tensors from Storage tensors
namedTensors := make(map[string]*ts.Tensor)
dictResult, isOrderedDict := result.(*OrderedDict)
if !isOrderedDict {
err := fmt.Errorf("Decode() failed: expected 'OrderedDict' type, got %v\n", reflect.TypeOf(result).String())
switch dtype {
case "*pickle.Dict":
dictResult := *result.(*Dict)
for _, item := range dictResult {
name := item.Key
itemTyp := reflect.TypeOf(item.Value).String()
switch itemTyp {
case "*pickle.Dict": // Nested *pickle.Dict case
subResult := *item.Value.(*Dict)
for _, subItem := range subResult {
subName := subItem.Key
x, ok := subItem.Value.(*StorageTensor)
if !ok {
log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v. Skip decoding parameter %q ...\n", reflect.TypeOf(subItem.Value).String(), subName)
continue
}
data := x.Source.GetData()
size := x.Size
dtype := x.Source.DType()
device := x.Source.Device()
stride := x.Stride
storageOffset := x.StorageOffset
if reflect.ValueOf(data).Len() == 0 {
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
continue
}
// TODO. should we just skip them?
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
size = []int64{1}
stride = []int64{1}
}
x1 := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
if x.RequiresGrad {
x1.MustRequiresGrad_(x.RequiresGrad)
}
namedTensors[name.(string)] = x1
}
default:
sx, isStorageTensor := item.Value.(*StorageTensor)
// if !isStorageTensor {
// err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
// return nil, err
// }
if !isStorageTensor {
log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v, with value of %v. Skip decoding parameter %q ...\n", reflect.TypeOf(item.Value).String(), item.Value, name)
continue
}
data := sx.Source.GetData()
size := sx.Size
dtype := sx.Source.DType()
device := sx.Source.Device()
stride := sx.Stride
storageOffset := sx.StorageOffset
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
// log.Printf("data: %v\n", data)
// Dealing with Pytorch `..._tracked` variables.
if reflect.ValueOf(data).Len() == 0 {
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
continue
}
// TODO. should we just skip them?
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
size = []int64{1}
stride = []int64{1}
}
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
if sx.RequiresGrad {
x.MustRequiresGrad_(sx.RequiresGrad)
}
namedTensors[name.(string)] = x
}
}
case "*pickle.OrderedDict":
dictResult := result.(*OrderedDict)
for name, item := range dictResult.Map {
sx, isStorageTensor := item.Value.(*StorageTensor)
if !isStorageTensor {
err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
return nil, err
}
data := sx.Source.GetData()
size := sx.Size
dtype := sx.Source.DType()
device := sx.Source.Device()
stride := sx.Stride
storageOffset := sx.StorageOffset
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
// log.Printf("data: %v\n", data)
// Dealing with Pytorch `..._tracked` variables.
if reflect.ValueOf(data).Len() == 0 {
log.Printf("INFO: skip weigth %q with zero data length.\n", name.(string))
continue
}
// TODO. should we just skip them?
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
size = []int64{1}
stride = []int64{1}
}
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
if sx.RequiresGrad {
x.MustRequiresGrad_(sx.RequiresGrad)
}
namedTensors[name.(string)] = x
}
default:
err := fmt.Errorf("Decode() failed: expected '*pickle.OrderedDict' or '*pickle.Dict' type, got %v\n", dtype)
return nil, err
}
for name, item := range dictResult.Map {
sx, isStorageTensor := item.Value.(*StorageTensor)
if !isStorageTensor {
err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
return nil, err
}
data := sx.Source.GetData()
size := sx.Size
dtype := sx.Source.DType()
device := sx.Source.Device()
stride := sx.Stride
storageOffset := sx.StorageOffset
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
// log.Printf("data: %v\n", data)
// Dealing with Pytorch `..._tracked` variables.
if reflect.ValueOf(data).Len() == 0 {
log.Printf("INFO: skip weigth %q with zero data length.\n", name.(string))
continue
}
// TODO. should we just skip them?
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
size = []int64{1}
stride = []int64{1}
}
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
if sx.RequiresGrad {
x.MustRequiresGrad_(sx.RequiresGrad)
}
namedTensors[name.(string)] = x
}
return namedTensors, nil
}
@ -161,6 +247,7 @@ func loadZipFile(filename string, newUnpickler func(r io.Reader) Unpickler) (int
if !dataTypeOk || !keyOk || !locationOk || !sizeOk {
return nil, fmt.Errorf("PersistentLoad: unexpected data types")
}
storage, storageExists := loadedStorages[key]
if !storageExists {
storage, err = loadTensor(dataType, size, location, key, fileRecords)

View File

@ -567,7 +567,7 @@ func setFromFile(s Storage, r io.Reader) error {
}
// StorageTensor:
//===============
// ===============
type StorageTensor struct {
Source Storage
StorageOffset int64
@ -629,7 +629,9 @@ func (r *RebuildTensorV2) Call(args ...interface{}) (interface{}, error) {
storageOffset, storageOffsetOk := args[1].(int)
size, sizeOk := args[2].(*Tuple)
stride, strideOk := args[3].(*Tuple)
requiresGrad, requiresGradOk := args[4].(bool)
// arg[5] "backward hooks" is unused
if !storageOk || !storageOffsetOk || !sizeOk || !strideOk ||
!requiresGradOk {
@ -681,7 +683,9 @@ func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
for i := 0; i < length; i++ {
value, ok := tuple.Get(i).(int)
if !ok {
return nil, fmt.Errorf("tuple of ints expected: %#v", tuple)
// return nil, fmt.Errorf("tuple of ints expected. Got %#v", tuple)
fmt.Printf("WARNING: tuple of ints expected. Got %#v\n", tuple)
continue
}
slice[i] = int64(value)
}
@ -689,7 +693,7 @@ func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
}
// Rebuild Sparse Tensor:
//=======================
// =======================
// ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L178
type RebuildSparseTensor struct{}

View File

@ -226,6 +226,12 @@ func NewTextData(filename string) (*TextData, error) {
}, nil
}
// Progress reports how far the iterator has advanced through its index
// pool, as the fraction of indexes consumed so far
// (BatchIndex*BatchSize / IndexesLen).
func (tdi *TextDataIter) Progress() float32 {
	// Indexes consumed by the batches served so far.
	consumed := tdi.BatchIndex * tdi.BatchSize
	total := tdi.IndexesLen
	// Guard the degenerate empty dataset: float division by zero would
	// silently yield NaN/+Inf instead of a usable progress value.
	if total == 0 {
		return 0
	}
	return float32(consumed) / float32(total)
}
// Labels returns the number of different `character` (rune) used by the dataset.
func (td *TextData) Labels() (retVal int64) {
return int64(len(td.CharForLabel))