add Progress() method to TextDataIter

This commit is contained in:
Tim Cassidy 2023-02-10 19:30:09 -08:00
parent d9a3a69e11
commit 8a08a1d6ab
2 changed files with 7 additions and 0 deletions

View File

@ -121,6 +121,7 @@ func main() {
if batchCount%500 == 0 {
fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
}
fmt.Printf("dataIter: progress: %v\n", dataIter.Progress())
} // infinite for-loop
sampleStr := sample(data, lstm, linear, device)

View File

@ -226,6 +226,12 @@ func NewTextData(filename string) (*TextData, error) {
}, nil
}
func (tdi *TextDataIter) Progress() float32 {
startIndex := (tdi.BatchIndex * tdi.BatchSize)
availableIndices := tdi.IndexesLen
progress := float32(startIndex) / float32(availableIndices)
return progress
}
// Labels returns the number of different `character` (rune) used by the dataset.
func (td *TextData) Labels() (retVal int64) {
return int64(len(td.CharForLabel))