add Progress() method to TextDataIter
This commit is contained in:
parent
d9a3a69e11
commit
8a08a1d6ab
|
@ -121,6 +121,7 @@ func main() {
|
|||
if batchCount%500 == 0 {
|
||||
fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
|
||||
}
|
||||
fmt.Printf("dataIter: progress: %v\n", dataIter.Progress())
|
||||
} // infinite for-loop
|
||||
|
||||
sampleStr := sample(data, lstm, linear, device)
|
||||
|
|
|
@ -226,6 +226,12 @@ func NewTextData(filename string) (*TextData, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (tdi *TextDataIter) Progress() float32 {
|
||||
startIndex := (tdi.BatchIndex * tdi.BatchSize)
|
||||
availableIndices := tdi.IndexesLen
|
||||
progress := float32(startIndex) / float32(availableIndices)
|
||||
return progress
|
||||
}
|
||||
// Labels returns the number of different `character` (rune) used by the dataset.
|
||||
func (td *TextData) Labels() (retVal int64) {
|
||||
return int64(len(td.CharForLabel))
|
||||
|
|
Loading…
Reference in New Issue
Block a user