diff --git a/example/char-rnn/main.go b/example/char-rnn/main.go index 9ed957c..7d187ee 100644 --- a/example/char-rnn/main.go +++ b/example/char-rnn/main.go @@ -121,6 +121,7 @@ func main() { if batchCount%500 == 0 { fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount) } + fmt.Printf("dataIter: progress: %v\n", dataIter.Progress()) } // infinite for-loop sampleStr := sample(data, lstm, linear, device) diff --git a/ts/data.go b/ts/data.go index 493a6d3..16dd97e 100644 --- a/ts/data.go +++ b/ts/data.go @@ -226,6 +226,12 @@ func NewTextData(filename string) (*TextData, error) { }, nil } +func (tdi *TextDataIter) Progress() float32 { + startIndex := (tdi.BatchIndex * tdi.BatchSize) + availableIndices := tdi.IndexesLen + progress := float32(startIndex) / float32(availableIndices) + return progress +} // Labels returns the number of different `character` (rune) used by the dataset. func (td *TextData) Labels() (retVal int64) { return int64(len(td.CharForLabel))