fixed #88 - example/char-rnn memory leak

sugarme 2023-01-31 17:07:50 +11:00
parent 17f2c49e34
commit f8ce489c09
3 changed files with 14 additions and 21 deletions
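
The changes below lean on gotch's manual memory model: a `ts.Tensor` wraps C-allocated libtorch memory that Go's garbage collector cannot reclaim, so every tensor must either be freed explicitly with `MustDrop()` or consumed in place by passing `del = true` as the trailing bool of a `Must*` op, which frees the receiver as part of the call. Below is a minimal self-contained sketch of that convention, not part of this commit; the `MustTotype`/`MustTo` signatures match those used in the diff, while `ts.MustOnes` is an assumption here:

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/ts"
)

func main() {
	x := ts.MustOnes([]int64{2, 3}, gotch.Float, gotch.CPU)

	// del = false: x stays alive, so both x and y must be dropped later.
	y := x.MustTotype(gotch.Int64, false)

	// del = true: y's C memory is freed inside the call; only z survives.
	z := y.MustTo(gotch.CPU, true)

	fmt.Println(z.MustSize()) // [2 3]

	// Anything still alive must be dropped by hand, or it leaks.
	x.MustDrop()
	z.MustDrop()
}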

CHANGELOG.md

@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 - Added `nn.MSELoss()`
 - reworked `ts.Format()`
 - Added conv2d benchmark
+- Fixed #88 memory leak at `example/char-rnn`

 ## [Nofix]
 - ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.

example/char-rnn/README.md

@@ -11,7 +11,7 @@ At the end of each training epoch, some sample text is generated and printed.
 Any text file can be used as an input, as long as it's large enough for training.
 A typical example would be the
 [tiny Shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt).
-The training text file should be stored in `data/input.txt`.
+The training text file should be stored in `data/char-rnn/input.txt`.
 To run the example:

example/char-rnn/main.go

@@ -14,12 +14,11 @@ const (
 	HiddenSize  int64 = 256
 	SeqLen      int64 = 180
 	BatchSize   int64 = 256
-	Epochs      int   = 100
+	Epochs      int   = 3
 	SamplingLen int64 = 1024
 )

 func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.Device) string {
 	labels := data.Labels()
 	inState := lstm.ZeroState(1)
 	lastLabel := int64(0)
@@ -59,13 +58,12 @@ func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.Device) string {
 }

 func main() {
-	cuda := gotch.NewCuda()
-	device := cuda.CudaIfAvailable()
+	device := gotch.CudaIfAvailable()
 	vs := nn.NewVarStore(device)

 	data, err := ts.NewTextData("../../data/char-rnn/input.txt")
 	if err != nil {
-		log.Fatal(err)
+		panic(err)
 	}

 	labels := data.Labels()
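
The two changes above are API tidying rather than leak fixes: the package-level `gotch.CudaIfAvailable()` replaces the two-step `gotch.NewCuda()` / `cuda.CudaIfAvailable()` idiom, and `panic(err)` stands in for `log.Fatal(err)`. A small sketch of the new setup (a hypothetical standalone snippet, not part of the commit):

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
)

func main() {
	device := gotch.CudaIfAvailable() // CUDA when a GPU is visible, else CPU
	vs := nn.NewVarStore(device)      // model variables will live on device
	_ = vs
	fmt.Printf("running on %v\n", device)
}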
@@ -94,41 +92,36 @@ func main() {
 			}
 			batchNarrow := batchTs.MustNarrow(1, 0, SeqLen, false)
-			xsOnehotTmp := batchNarrow.Onehot(labels)
-			xsOnehot := xsOnehotTmp.MustTo(device, true) // shape: [256, 180, 65]
-			ysTmp1 := batchTs.MustNarrow(1, 1, SeqLen, true)
-			ysTmp2 := ysTmp1.MustTotype(gotch.Int64, true)
-			ysTmp3 := ysTmp2.MustTo(device, true)
-			ys := ysTmp3.MustView([]int64{BatchSize * SeqLen}, true)
+			xsOnehot := batchNarrow.Onehot(labels).MustTo(device, true) // [256, 180, 65]
+			batchNarrow.MustDrop()
+			ys := batchTs.MustNarrow(1, 1, SeqLen, true).MustTotype(gotch.Int64, true).MustTo(device, true).MustView([]int64{BatchSize * SeqLen}, true)

 			lstmOut, outState := lstm.Seq(xsOnehot)
 			// NOTE: although outState is not used, it holds hidden C-land
 			// memory that must be freed explicitly. Don't discard it with `_`.
 			outState.(*nn.LSTMState).Tensor1.MustDrop()
 			outState.(*nn.LSTMState).Tensor2.MustDrop()
+			xsOnehot.MustDrop()

 			logits := linear.Forward(lstmOut)
+			lstmOut.MustDrop()
 			lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)

 			loss := lossView.CrossEntropyForLogits(ys)
+			ys.MustDrop()
+			lossView.MustDrop()

 			opt.BackwardStepClip(loss, 0.5)
 			sumLoss += loss.Float64Values()[0]
 			cntLoss += 1.0

-			// batchTs.MustDrop()
-			// batchNarrow.MustDrop()
-			// xsOnehotTmp.MustDrop()
-			xsOnehot.MustDrop()
-			ys.MustDrop()
-			lstmOut.MustDrop()
 			loss.MustDrop()

 			batchCount++
 			if batchCount%500 == 0 {
 				fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
 			}
-		}
+		} // infinite for-loop

 		sampleStr := sample(data, lstm, linear, device)
 		fmt.Printf("Epoch %v - Loss: %v \n", epoch, sumLoss/cntLoss)
@@ -137,5 +130,4 @@ func main() {
 		dataIter.Data.MustDrop()
 		dataIter.Indexes.MustDrop()
 	}
 }