fixed #88 - example/char-rnn memory leak

sugarme 2023-01-31 17:07:50 +11:00
parent 17f2c49e34
commit f8ce489c09
3 changed files with 14 additions and 21 deletions

View File

@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `nn.MSELoss()`
- reworked `ts.Format()`
- Added conv2d benchmark
- Fixed #88 memory leak at `example/char-rnn`
## [Nofix]
- ctype `long` causes a compile error on macOS, as noted in [#44]; it does not occur on a Linux box.

View File

@@ -11,7 +11,7 @@ At the end of each training epoch, some sample text is generated and printed.
Any text file can be used as an input, as long as it's large enough for training.
A typical example would be the
[tiny Shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt).
The training text file should be stored in `data/input.txt`.
The training text file should be stored in `data/char-rnn/input.txt`.
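For reference, a minimal sketch of loading that file with `ts.NewTextData`, mirroring the call in `main.go` further down; the import path is assumed from the gotch layout and the snippet is illustrative, not part of this commit:

```go
package main

import (
	"fmt"

	"github.com/sugarme/gotch/ts" // assumed import path for the ts package
)

func main() {
	// Same relative path that example/char-rnn/main.go (below) uses.
	data, err := ts.NewTextData("../../data/char-rnn/input.txt")
	if err != nil {
		panic(err)
	}
	// Labels() reports how many distinct characters the corpus contains.
	fmt.Printf("loaded corpus with %v distinct characters\n", data.Labels())
}
```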
To run the example:

View File

@@ -14,12 +14,11 @@ const (
HiddenSize int64 = 256
SeqLen int64 = 180
BatchSize int64 = 256
Epochs int = 100
Epochs int = 3
SamplingLen int64 = 1024
)
func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.Device) string {
labels := data.Labels()
inState := lstm.ZeroState(1)
lastLabel := int64(0)
@@ -59,13 +58,12 @@ func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.De
}
func main() {
cuda := gotch.NewCuda()
device := cuda.CudaIfAvailable()
device := gotch.CudaIfAvailable()
vs := nn.NewVarStore(device)
data, err := ts.NewTextData("../../data/char-rnn/input.txt")
if err != nil {
log.Fatal(err)
panic(err)
}
labels := data.Labels()
@@ -94,41 +92,36 @@ func main() {
}
batchNarrow := batchTs.MustNarrow(1, 0, SeqLen, false)
xsOnehotTmp := batchNarrow.Onehot(labels)
xsOnehot := xsOnehotTmp.MustTo(device, true) // shape: [256, 180, 65]
ysTmp1 := batchTs.MustNarrow(1, 1, SeqLen, true)
ysTmp2 := ysTmp1.MustTotype(gotch.Int64, true)
ysTmp3 := ysTmp2.MustTo(device, true)
ys := ysTmp3.MustView([]int64{BatchSize * SeqLen}, true)
xsOnehot := batchNarrow.Onehot(labels).MustTo(device, true) // [256, 180, 65]
batchNarrow.MustDrop()
ys := batchTs.MustNarrow(1, 1, SeqLen, true).MustTotype(gotch.Int64, true).MustTo(device, true).MustView([]int64{BatchSize * SeqLen}, true)
lstmOut, outState := lstm.Seq(xsOnehot)
// NOTE: although outState is not used, it still holds C-allocated memory
// that must be freed explicitly. Don't discard it with `_`.
outState.(*nn.LSTMState).Tensor1.MustDrop()
outState.(*nn.LSTMState).Tensor2.MustDrop()
xsOnehot.MustDrop()
logits := linear.Forward(lstmOut)
lstmOut.MustDrop()
lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)
loss := lossView.CrossEntropyForLogits(ys)
ys.MustDrop()
lossView.MustDrop()
opt.BackwardStepClip(loss, 0.5)
sumLoss += loss.Float64Values()[0]
cntLoss += 1.0
// batchTs.MustDrop()
// batchNarrow.MustDrop()
// xsOnehotTmp.MustDrop()
xsOnehot.MustDrop()
ys.MustDrop()
lstmOut.MustDrop()
loss.MustDrop()
batchCount++
if batchCount%500 == 0 {
fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
}
}
} // infinite for-loop
sampleStr := sample(data, lstm, linear, device)
fmt.Printf("Epoch %v - Loss: %v \n", epoch, sumLoss/cntLoss)
@@ -137,5 +130,4 @@ func main() {
dataIter.Data.MustDrop()
dataIter.Indexes.MustDrop()
}
}
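
Taken together, the change moves from end-of-iteration (and partly commented-out) cleanup to freeing each tensor as soon as its last consumer has run, using two mechanisms: the trailing `del` flag on chained `Must*` calls, and an explicit `MustDrop()` for anything produced but never chained onward, notably the LSTM output state. A condensed, illustrative sketch of that per-batch pattern follows; the function wrapper and parameter types are assumptions for self-containment, not the committed code:

```go
// trainBatch runs one optimisation step and frees every tensor it creates.
// Identifiers mirror the diff above; parameter types are assumed from the
// gotch API (illustrative sketch, not the committed code).
func trainBatch(batchTs *ts.Tensor, lstm *nn.LSTM, linear *nn.Linear,
	opt *nn.Optimizer, labels int64, device gotch.Device) float64 {

	batchNarrow := batchTs.MustNarrow(1, 0, SeqLen, false)      // del=false: batchTs is still needed below
	xsOnehot := batchNarrow.Onehot(labels).MustTo(device, true) // del=true frees the CPU one-hot tensor
	batchNarrow.MustDrop()

	// The trailing `true` on each Must* call is gotch's del flag: it frees the
	// receiver as soon as the result exists, so this chain leaves only ys alive
	// (the first call also frees batchTs, which is no longer needed).
	ys := batchTs.MustNarrow(1, 1, SeqLen, true).
		MustTotype(gotch.Int64, true).
		MustTo(device, true).
		MustView([]int64{BatchSize * SeqLen}, true)

	lstmOut, outState := lstm.Seq(xsOnehot)
	// outState is never read, but it still owns C-allocated memory; bind it to
	// a variable (not `_`) and drop both state tensors explicitly.
	outState.(*nn.LSTMState).Tensor1.MustDrop()
	outState.(*nn.LSTMState).Tensor2.MustDrop()
	xsOnehot.MustDrop()

	logits := linear.Forward(lstmOut)
	lstmOut.MustDrop()

	lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true) // frees logits
	loss := lossView.CrossEntropyForLogits(ys)
	ys.MustDrop()
	lossView.MustDrop()

	opt.BackwardStepClip(loss, 0.5)
	lossVal := loss.Float64Values()[0]
	loss.MustDrop() // nothing allocated in this iteration survives it

	return lossVal
}
```

Because these allocations live on the C side, they are invisible to Go's garbage collector, which is why relying on GC (or discarding values with `_`) let memory grow across the training loop in the first place.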