Merge branch 'master' of ssh://github.com/sugarme/gotch into mem
This commit is contained in:
commit
4aad97b216
|
@ -121,6 +121,7 @@ func main() {
|
|||
if batchCount%500 == 0 {
|
||||
fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
|
||||
}
|
||||
fmt.Printf("dataIter: progress: %v\n", dataIter.Progress())
|
||||
} // infinite for-loop
|
||||
|
||||
sampleStr := sample(data, lstm, linear, device)
|
||||
|
|
|
@ -10,6 +10,9 @@ import (
|
|||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"reflect"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// This file implements Python Pickle Machinery.
|
||||
|
@ -552,6 +555,7 @@ func loadBinInt1(up *Unpickler) error {
|
|||
err = fmt.Errorf("loadBinInt1() failed: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
up.append(int(b))
|
||||
|
||||
return nil
|
||||
|
@ -1007,10 +1011,9 @@ func loadTuple3(up *Unpickler) error {
|
|||
objects := make([]interface{}, 3)
|
||||
var (
|
||||
err error
|
||||
n int = 0
|
||||
)
|
||||
for i := 2; i >= 0; i-- {
|
||||
objects[n], err = up.stackPop()
|
||||
objects[i], err = up.stackPop()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("loadTuple3() failed: %w", err)
|
||||
}
|
||||
|
@ -1251,6 +1254,7 @@ func loadGlobal(up *Unpickler) error {
|
|||
err = fmt.Errorf("loadGlobal() failed: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
up.append(class)
|
||||
return nil
|
||||
}
|
||||
|
@ -1378,6 +1382,7 @@ func loadReduce(up *Unpickler) error {
|
|||
err = fmt.Errorf("loadReduce() failed: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
callable, callableOk := function.(Callable)
|
||||
if !callableOk {
|
||||
err := fmt.Errorf("REDUCE requires a Callable object: %#v", function)
|
||||
|
@ -1935,3 +1940,7 @@ func Loads(s string) (interface{}, error) {
|
|||
|
||||
return up.Load()
|
||||
}
|
||||
|
||||
func GetFunctionName(i interface{}) string {
|
||||
return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name()
|
||||
}
|
||||
|
|
|
@ -51,49 +51,135 @@ func Decode(filename string) (map[string]*ts.Tensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
dtype := reflect.TypeOf(result).String()
|
||||
|
||||
// Rebuild tensors from Storage tensors
|
||||
namedTensors := make(map[string]*ts.Tensor)
|
||||
dictResult, isOrderedDict := result.(*OrderedDict)
|
||||
if !isOrderedDict {
|
||||
err := fmt.Errorf("Decode() failed: expected 'OrderedDict' type, got %v\n", reflect.TypeOf(result).String())
|
||||
switch dtype {
|
||||
case "*pickle.Dict":
|
||||
dictResult := *result.(*Dict)
|
||||
for _, item := range dictResult {
|
||||
name := item.Key
|
||||
itemTyp := reflect.TypeOf(item.Value).String()
|
||||
switch itemTyp {
|
||||
case "*pickle.Dict": // Nested *pickle.Dict case
|
||||
subResult := *item.Value.(*Dict)
|
||||
for _, subItem := range subResult {
|
||||
subName := subItem.Key
|
||||
x, ok := subItem.Value.(*StorageTensor)
|
||||
if !ok {
|
||||
log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v. Skip decoding parameter %q ...\n", reflect.TypeOf(subItem.Value).String(), subName)
|
||||
continue
|
||||
}
|
||||
data := x.Source.GetData()
|
||||
size := x.Size
|
||||
dtype := x.Source.DType()
|
||||
device := x.Source.Device()
|
||||
stride := x.Stride
|
||||
storageOffset := x.StorageOffset
|
||||
if reflect.ValueOf(data).Len() == 0 {
|
||||
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO. should we just skip them?
|
||||
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
|
||||
size = []int64{1}
|
||||
stride = []int64{1}
|
||||
}
|
||||
|
||||
x1 := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
||||
if x.RequiresGrad {
|
||||
x1.MustRequiresGrad_(x.RequiresGrad)
|
||||
}
|
||||
|
||||
namedTensors[name.(string)] = x1
|
||||
}
|
||||
|
||||
default:
|
||||
sx, isStorageTensor := item.Value.(*StorageTensor)
|
||||
|
||||
// if !isStorageTensor {
|
||||
// err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
|
||||
// return nil, err
|
||||
// }
|
||||
if !isStorageTensor {
|
||||
log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v, with value of %v. Skip decoding parameter %q ...\n", reflect.TypeOf(item.Value).String(), item.Value, name)
|
||||
continue
|
||||
}
|
||||
|
||||
data := sx.Source.GetData()
|
||||
size := sx.Size
|
||||
dtype := sx.Source.DType()
|
||||
device := sx.Source.Device()
|
||||
stride := sx.Stride
|
||||
storageOffset := sx.StorageOffset
|
||||
|
||||
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
|
||||
// log.Printf("data: %v\n", data)
|
||||
|
||||
// Dealing with Pytorch `..._tracked` variables.
|
||||
if reflect.ValueOf(data).Len() == 0 {
|
||||
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO. should we just skip them?
|
||||
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
|
||||
size = []int64{1}
|
||||
stride = []int64{1}
|
||||
}
|
||||
|
||||
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
||||
if sx.RequiresGrad {
|
||||
x.MustRequiresGrad_(sx.RequiresGrad)
|
||||
}
|
||||
|
||||
namedTensors[name.(string)] = x
|
||||
}
|
||||
}
|
||||
case "*pickle.OrderedDict":
|
||||
dictResult := result.(*OrderedDict)
|
||||
for name, item := range dictResult.Map {
|
||||
sx, isStorageTensor := item.Value.(*StorageTensor)
|
||||
if !isStorageTensor {
|
||||
err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data := sx.Source.GetData()
|
||||
size := sx.Size
|
||||
dtype := sx.Source.DType()
|
||||
device := sx.Source.Device()
|
||||
stride := sx.Stride
|
||||
storageOffset := sx.StorageOffset
|
||||
|
||||
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
|
||||
// log.Printf("data: %v\n", data)
|
||||
|
||||
// Dealing with Pytorch `..._tracked` variables.
|
||||
if reflect.ValueOf(data).Len() == 0 {
|
||||
log.Printf("INFO: skip weigth %q with zero data length.\n", name.(string))
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO. should we just skip them?
|
||||
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
|
||||
size = []int64{1}
|
||||
stride = []int64{1}
|
||||
}
|
||||
|
||||
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
||||
if sx.RequiresGrad {
|
||||
x.MustRequiresGrad_(sx.RequiresGrad)
|
||||
}
|
||||
|
||||
namedTensors[name.(string)] = x
|
||||
}
|
||||
default:
|
||||
err := fmt.Errorf("Decode() failed: expected '*pickle.OrderedDict' or '*pickle.Dict' type, got %v\n", dtype)
|
||||
return nil, err
|
||||
}
|
||||
for name, item := range dictResult.Map {
|
||||
sx, isStorageTensor := item.Value.(*StorageTensor)
|
||||
if !isStorageTensor {
|
||||
err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data := sx.Source.GetData()
|
||||
size := sx.Size
|
||||
dtype := sx.Source.DType()
|
||||
device := sx.Source.Device()
|
||||
stride := sx.Stride
|
||||
storageOffset := sx.StorageOffset
|
||||
|
||||
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
|
||||
// log.Printf("data: %v\n", data)
|
||||
|
||||
// Dealing with Pytorch `..._tracked` variables.
|
||||
if reflect.ValueOf(data).Len() == 0 {
|
||||
log.Printf("INFO: skip weigth %q with zero data length.\n", name.(string))
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO. should we just skip them?
|
||||
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
|
||||
size = []int64{1}
|
||||
stride = []int64{1}
|
||||
}
|
||||
|
||||
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
||||
if sx.RequiresGrad {
|
||||
x.MustRequiresGrad_(sx.RequiresGrad)
|
||||
}
|
||||
|
||||
namedTensors[name.(string)] = x
|
||||
}
|
||||
|
||||
return namedTensors, nil
|
||||
}
|
||||
|
@ -161,6 +247,7 @@ func loadZipFile(filename string, newUnpickler func(r io.Reader) Unpickler) (int
|
|||
if !dataTypeOk || !keyOk || !locationOk || !sizeOk {
|
||||
return nil, fmt.Errorf("PersistentLoad: unexpected data types")
|
||||
}
|
||||
|
||||
storage, storageExists := loadedStorages[key]
|
||||
if !storageExists {
|
||||
storage, err = loadTensor(dataType, size, location, key, fileRecords)
|
||||
|
|
|
@ -567,7 +567,7 @@ func setFromFile(s Storage, r io.Reader) error {
|
|||
}
|
||||
|
||||
// StorageTensor:
|
||||
//===============
|
||||
// ===============
|
||||
type StorageTensor struct {
|
||||
Source Storage
|
||||
StorageOffset int64
|
||||
|
@ -629,7 +629,9 @@ func (r *RebuildTensorV2) Call(args ...interface{}) (interface{}, error) {
|
|||
storageOffset, storageOffsetOk := args[1].(int)
|
||||
size, sizeOk := args[2].(*Tuple)
|
||||
stride, strideOk := args[3].(*Tuple)
|
||||
|
||||
requiresGrad, requiresGradOk := args[4].(bool)
|
||||
|
||||
// arg[5] "backward hooks" is unused
|
||||
if !storageOk || !storageOffsetOk || !sizeOk || !strideOk ||
|
||||
!requiresGradOk {
|
||||
|
@ -681,7 +683,9 @@ func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
|
|||
for i := 0; i < length; i++ {
|
||||
value, ok := tuple.Get(i).(int)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tuple of ints expected: %#v", tuple)
|
||||
// return nil, fmt.Errorf("tuple of ints expected. Got %#v", tuple)
|
||||
fmt.Printf("WARNING: tuple of ints expected. Got %#v\n", tuple)
|
||||
continue
|
||||
}
|
||||
slice[i] = int64(value)
|
||||
}
|
||||
|
@ -689,7 +693,7 @@ func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
|
|||
}
|
||||
|
||||
// Rebuild Sparse Tensor:
|
||||
//=======================
|
||||
// =======================
|
||||
// ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L178
|
||||
type RebuildSparseTensor struct{}
|
||||
|
||||
|
|
|
@ -226,6 +226,12 @@ func NewTextData(filename string) (*TextData, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (tdi *TextDataIter) Progress() float32 {
|
||||
startIndex := (tdi.BatchIndex * tdi.BatchSize)
|
||||
availableIndices := tdi.IndexesLen
|
||||
progress := float32(startIndex) / float32(availableIndices)
|
||||
return progress
|
||||
}
|
||||
// Labels returns the number of different `character` (rune) used by the dataset.
|
||||
func (td *TextData) Labels() (retVal int64) {
|
||||
return int64(len(td.CharForLabel))
|
||||
|
|
Loading…
Reference in New Issue
Block a user