Merge branch 'master' of ssh://github.com/sugarme/gotch into mem
This commit is contained in:
commit
4aad97b216
|
@ -121,6 +121,7 @@ func main() {
|
||||||
if batchCount%500 == 0 {
|
if batchCount%500 == 0 {
|
||||||
fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
|
fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
|
||||||
}
|
}
|
||||||
|
fmt.Printf("dataIter: progress: %v\n", dataIter.Progress())
|
||||||
} // infinite for-loop
|
} // infinite for-loop
|
||||||
|
|
||||||
sampleStr := sample(data, lstm, linear, device)
|
sampleStr := sample(data, lstm, linear, device)
|
||||||
|
|
|
@ -10,6 +10,9 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This file implements Python Pickle Machinery.
|
// This file implements Python Pickle Machinery.
|
||||||
|
@ -552,6 +555,7 @@ func loadBinInt1(up *Unpickler) error {
|
||||||
err = fmt.Errorf("loadBinInt1() failed: %w", err)
|
err = fmt.Errorf("loadBinInt1() failed: %w", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
up.append(int(b))
|
up.append(int(b))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -1007,10 +1011,9 @@ func loadTuple3(up *Unpickler) error {
|
||||||
objects := make([]interface{}, 3)
|
objects := make([]interface{}, 3)
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
n int = 0
|
|
||||||
)
|
)
|
||||||
for i := 2; i >= 0; i-- {
|
for i := 2; i >= 0; i-- {
|
||||||
objects[n], err = up.stackPop()
|
objects[i], err = up.stackPop()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("loadTuple3() failed: %w", err)
|
err = fmt.Errorf("loadTuple3() failed: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -1251,6 +1254,7 @@ func loadGlobal(up *Unpickler) error {
|
||||||
err = fmt.Errorf("loadGlobal() failed: %w", err)
|
err = fmt.Errorf("loadGlobal() failed: %w", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
up.append(class)
|
up.append(class)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1378,6 +1382,7 @@ func loadReduce(up *Unpickler) error {
|
||||||
err = fmt.Errorf("loadReduce() failed: %w", err)
|
err = fmt.Errorf("loadReduce() failed: %w", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
callable, callableOk := function.(Callable)
|
callable, callableOk := function.(Callable)
|
||||||
if !callableOk {
|
if !callableOk {
|
||||||
err := fmt.Errorf("REDUCE requires a Callable object: %#v", function)
|
err := fmt.Errorf("REDUCE requires a Callable object: %#v", function)
|
||||||
|
@ -1935,3 +1940,7 @@ func Loads(s string) (interface{}, error) {
|
||||||
|
|
||||||
return up.Load()
|
return up.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetFunctionName(i interface{}) string {
|
||||||
|
return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name()
|
||||||
|
}
|
||||||
|
|
|
@ -51,49 +51,135 @@ func Decode(filename string) (map[string]*ts.Tensor, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dtype := reflect.TypeOf(result).String()
|
||||||
|
|
||||||
// Rebuild tensors from Storage tensors
|
// Rebuild tensors from Storage tensors
|
||||||
namedTensors := make(map[string]*ts.Tensor)
|
namedTensors := make(map[string]*ts.Tensor)
|
||||||
dictResult, isOrderedDict := result.(*OrderedDict)
|
switch dtype {
|
||||||
if !isOrderedDict {
|
case "*pickle.Dict":
|
||||||
err := fmt.Errorf("Decode() failed: expected 'OrderedDict' type, got %v\n", reflect.TypeOf(result).String())
|
dictResult := *result.(*Dict)
|
||||||
|
for _, item := range dictResult {
|
||||||
|
name := item.Key
|
||||||
|
itemTyp := reflect.TypeOf(item.Value).String()
|
||||||
|
switch itemTyp {
|
||||||
|
case "*pickle.Dict": // Nested *pickle.Dict case
|
||||||
|
subResult := *item.Value.(*Dict)
|
||||||
|
for _, subItem := range subResult {
|
||||||
|
subName := subItem.Key
|
||||||
|
x, ok := subItem.Value.(*StorageTensor)
|
||||||
|
if !ok {
|
||||||
|
log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v. Skip decoding parameter %q ...\n", reflect.TypeOf(subItem.Value).String(), subName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := x.Source.GetData()
|
||||||
|
size := x.Size
|
||||||
|
dtype := x.Source.DType()
|
||||||
|
device := x.Source.Device()
|
||||||
|
stride := x.Stride
|
||||||
|
storageOffset := x.StorageOffset
|
||||||
|
if reflect.ValueOf(data).Len() == 0 {
|
||||||
|
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO. should we just skip them?
|
||||||
|
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
|
||||||
|
size = []int64{1}
|
||||||
|
stride = []int64{1}
|
||||||
|
}
|
||||||
|
|
||||||
|
x1 := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
||||||
|
if x.RequiresGrad {
|
||||||
|
x1.MustRequiresGrad_(x.RequiresGrad)
|
||||||
|
}
|
||||||
|
|
||||||
|
namedTensors[name.(string)] = x1
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
sx, isStorageTensor := item.Value.(*StorageTensor)
|
||||||
|
|
||||||
|
// if !isStorageTensor {
|
||||||
|
// err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
if !isStorageTensor {
|
||||||
|
log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v, with value of %v. Skip decoding parameter %q ...\n", reflect.TypeOf(item.Value).String(), item.Value, name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
data := sx.Source.GetData()
|
||||||
|
size := sx.Size
|
||||||
|
dtype := sx.Source.DType()
|
||||||
|
device := sx.Source.Device()
|
||||||
|
stride := sx.Stride
|
||||||
|
storageOffset := sx.StorageOffset
|
||||||
|
|
||||||
|
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
|
||||||
|
// log.Printf("data: %v\n", data)
|
||||||
|
|
||||||
|
// Dealing with Pytorch `..._tracked` variables.
|
||||||
|
if reflect.ValueOf(data).Len() == 0 {
|
||||||
|
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO. should we just skip them?
|
||||||
|
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
|
||||||
|
size = []int64{1}
|
||||||
|
stride = []int64{1}
|
||||||
|
}
|
||||||
|
|
||||||
|
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
||||||
|
if sx.RequiresGrad {
|
||||||
|
x.MustRequiresGrad_(sx.RequiresGrad)
|
||||||
|
}
|
||||||
|
|
||||||
|
namedTensors[name.(string)] = x
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "*pickle.OrderedDict":
|
||||||
|
dictResult := result.(*OrderedDict)
|
||||||
|
for name, item := range dictResult.Map {
|
||||||
|
sx, isStorageTensor := item.Value.(*StorageTensor)
|
||||||
|
if !isStorageTensor {
|
||||||
|
err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data := sx.Source.GetData()
|
||||||
|
size := sx.Size
|
||||||
|
dtype := sx.Source.DType()
|
||||||
|
device := sx.Source.Device()
|
||||||
|
stride := sx.Stride
|
||||||
|
storageOffset := sx.StorageOffset
|
||||||
|
|
||||||
|
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
|
||||||
|
// log.Printf("data: %v\n", data)
|
||||||
|
|
||||||
|
// Dealing with Pytorch `..._tracked` variables.
|
||||||
|
if reflect.ValueOf(data).Len() == 0 {
|
||||||
|
log.Printf("INFO: skip weigth %q with zero data length.\n", name.(string))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO. should we just skip them?
|
||||||
|
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
|
||||||
|
size = []int64{1}
|
||||||
|
stride = []int64{1}
|
||||||
|
}
|
||||||
|
|
||||||
|
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
||||||
|
if sx.RequiresGrad {
|
||||||
|
x.MustRequiresGrad_(sx.RequiresGrad)
|
||||||
|
}
|
||||||
|
|
||||||
|
namedTensors[name.(string)] = x
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
err := fmt.Errorf("Decode() failed: expected '*pickle.OrderedDict' or '*pickle.Dict' type, got %v\n", dtype)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for name, item := range dictResult.Map {
|
|
||||||
sx, isStorageTensor := item.Value.(*StorageTensor)
|
|
||||||
if !isStorageTensor {
|
|
||||||
err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
data := sx.Source.GetData()
|
|
||||||
size := sx.Size
|
|
||||||
dtype := sx.Source.DType()
|
|
||||||
device := sx.Source.Device()
|
|
||||||
stride := sx.Stride
|
|
||||||
storageOffset := sx.StorageOffset
|
|
||||||
|
|
||||||
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
|
|
||||||
// log.Printf("data: %v\n", data)
|
|
||||||
|
|
||||||
// Dealing with Pytorch `..._tracked` variables.
|
|
||||||
if reflect.ValueOf(data).Len() == 0 {
|
|
||||||
log.Printf("INFO: skip weigth %q with zero data length.\n", name.(string))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO. should we just skip them?
|
|
||||||
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
|
|
||||||
size = []int64{1}
|
|
||||||
stride = []int64{1}
|
|
||||||
}
|
|
||||||
|
|
||||||
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
|
||||||
if sx.RequiresGrad {
|
|
||||||
x.MustRequiresGrad_(sx.RequiresGrad)
|
|
||||||
}
|
|
||||||
|
|
||||||
namedTensors[name.(string)] = x
|
|
||||||
}
|
|
||||||
|
|
||||||
return namedTensors, nil
|
return namedTensors, nil
|
||||||
}
|
}
|
||||||
|
@ -161,6 +247,7 @@ func loadZipFile(filename string, newUnpickler func(r io.Reader) Unpickler) (int
|
||||||
if !dataTypeOk || !keyOk || !locationOk || !sizeOk {
|
if !dataTypeOk || !keyOk || !locationOk || !sizeOk {
|
||||||
return nil, fmt.Errorf("PersistentLoad: unexpected data types")
|
return nil, fmt.Errorf("PersistentLoad: unexpected data types")
|
||||||
}
|
}
|
||||||
|
|
||||||
storage, storageExists := loadedStorages[key]
|
storage, storageExists := loadedStorages[key]
|
||||||
if !storageExists {
|
if !storageExists {
|
||||||
storage, err = loadTensor(dataType, size, location, key, fileRecords)
|
storage, err = loadTensor(dataType, size, location, key, fileRecords)
|
||||||
|
|
|
@ -567,7 +567,7 @@ func setFromFile(s Storage, r io.Reader) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// StorageTensor:
|
// StorageTensor:
|
||||||
//===============
|
// ===============
|
||||||
type StorageTensor struct {
|
type StorageTensor struct {
|
||||||
Source Storage
|
Source Storage
|
||||||
StorageOffset int64
|
StorageOffset int64
|
||||||
|
@ -629,7 +629,9 @@ func (r *RebuildTensorV2) Call(args ...interface{}) (interface{}, error) {
|
||||||
storageOffset, storageOffsetOk := args[1].(int)
|
storageOffset, storageOffsetOk := args[1].(int)
|
||||||
size, sizeOk := args[2].(*Tuple)
|
size, sizeOk := args[2].(*Tuple)
|
||||||
stride, strideOk := args[3].(*Tuple)
|
stride, strideOk := args[3].(*Tuple)
|
||||||
|
|
||||||
requiresGrad, requiresGradOk := args[4].(bool)
|
requiresGrad, requiresGradOk := args[4].(bool)
|
||||||
|
|
||||||
// arg[5] "backward hooks" is unused
|
// arg[5] "backward hooks" is unused
|
||||||
if !storageOk || !storageOffsetOk || !sizeOk || !strideOk ||
|
if !storageOk || !storageOffsetOk || !sizeOk || !strideOk ||
|
||||||
!requiresGradOk {
|
!requiresGradOk {
|
||||||
|
@ -681,7 +683,9 @@ func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
|
||||||
for i := 0; i < length; i++ {
|
for i := 0; i < length; i++ {
|
||||||
value, ok := tuple.Get(i).(int)
|
value, ok := tuple.Get(i).(int)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("tuple of ints expected: %#v", tuple)
|
// return nil, fmt.Errorf("tuple of ints expected. Got %#v", tuple)
|
||||||
|
fmt.Printf("WARNING: tuple of ints expected. Got %#v\n", tuple)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
slice[i] = int64(value)
|
slice[i] = int64(value)
|
||||||
}
|
}
|
||||||
|
@ -689,7 +693,7 @@ func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rebuild Sparse Tensor:
|
// Rebuild Sparse Tensor:
|
||||||
//=======================
|
// =======================
|
||||||
// ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L178
|
// ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L178
|
||||||
type RebuildSparseTensor struct{}
|
type RebuildSparseTensor struct{}
|
||||||
|
|
||||||
|
|
|
@ -226,6 +226,12 @@ func NewTextData(filename string) (*TextData, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tdi *TextDataIter) Progress() float32 {
|
||||||
|
startIndex := (tdi.BatchIndex * tdi.BatchSize)
|
||||||
|
availableIndices := tdi.IndexesLen
|
||||||
|
progress := float32(startIndex) / float32(availableIndices)
|
||||||
|
return progress
|
||||||
|
}
|
||||||
// Labels returns the number of different `character` (rune) used by the dataset.
|
// Labels returns the number of different `character` (rune) used by the dataset.
|
||||||
func (td *TextData) Labels() (retVal int64) {
|
func (td *TextData) Labels() (retVal int64) {
|
||||||
return int64(len(td.CharForLabel))
|
return int64(len(td.CharForLabel))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user