add Progress() method to TextDataIter
This commit is contained in:
parent
d9a3a69e11
commit
8a08a1d6ab
|
@ -121,6 +121,7 @@ func main() {
|
||||||
if batchCount%500 == 0 {
|
if batchCount%500 == 0 {
|
||||||
fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
|
fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
|
||||||
}
|
}
|
||||||
|
fmt.Printf("dataIter: progress: %v\n", dataIter.Progress())
|
||||||
} // infinite for-loop
|
} // infinite for-loop
|
||||||
|
|
||||||
sampleStr := sample(data, lstm, linear, device)
|
sampleStr := sample(data, lstm, linear, device)
|
||||||
|
|
|
@ -226,6 +226,12 @@ func NewTextData(filename string) (*TextData, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tdi *TextDataIter) Progress() float32 {
|
||||||
|
startIndex := (tdi.BatchIndex * tdi.BatchSize)
|
||||||
|
availableIndices := tdi.IndexesLen
|
||||||
|
progress := float32(startIndex) / float32(availableIndices)
|
||||||
|
return progress
|
||||||
|
}
|
||||||
// Labels returns the number of different `character` (rune) used by the dataset.
|
// Labels returns the number of different `character` (rune) used by the dataset.
|
||||||
func (td *TextData) Labels() (retVal int64) {
|
func (td *TextData) Labels() (retVal int64) {
|
||||||
return int64(len(td.CharForLabel))
|
return int64(len(td.CharForLabel))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user