fixed #88 - example/char-rnn memory leak

parent 17f2c49e34
commit f8ce489c09
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `nn.MSELoss()`
 - reworked `ts.Format()`
 - Added conv2d benchmark
+- Fixed #88 memory leak at `example/char-rnn`
 
 ## [Nofix]
 - ctype `long` caused a compile error on MacOS, as noted in [#44]. Could not be reproduced on a Linux box.
@@ -11,7 +11,7 @@ At the end of each training epoch, some sample text is generated and printed.
 
 Any text file can be used as an input, as long as it's large enough for training.
 A typical example would be the
 [tiny Shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt).
-The training text file should be stored in `data/input.txt`.
+The training text file should be stored in `data/char-rnn/input.txt`.
 
 To run the example:
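The run command itself is cut off in this view. As a quick sanity check that the training text landed in the right place, something along these lines should work; it is a minimal sketch that assumes only the `ts.NewTextData` and `data.Labels()` calls visible in `main.go` below, plus the usual `github.com/sugarme/gotch/tensor` import path aliased as `ts`, with the file path given relative to the repo root as in the README:

```go
package main

import (
	"fmt"
	"log"

	// Import path is an assumption; the example aliases the tensor package as `ts`.
	ts "github.com/sugarme/gotch/tensor"
)

func main() {
	// NewTextData reads the training file and builds the character vocabulary.
	data, err := ts.NewTextData("data/char-rnn/input.txt")
	if err != nil {
		log.Fatal(err) // most likely cause: the file is not where the README says
	}

	// Labels() is the number of distinct characters (65 for tiny Shakespeare).
	fmt.Printf("vocabulary size: %v\n", data.Labels())
}
```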
@@ -14,12 +14,11 @@ const (
 	HiddenSize  int64 = 256
 	SeqLen      int64 = 180
 	BatchSize   int64 = 256
-	Epochs      int   = 100
+	Epochs      int   = 3
 	SamplingLen int64 = 1024
 )
 
 func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.Device) string {
-
 	labels := data.Labels()
 	inState := lstm.ZeroState(1)
 	lastLabel := int64(0)
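For orientation, these constants pin down every shape in the training loop below: with tiny Shakespeare's 65-character vocabulary, the one-hot input is `[BatchSize, SeqLen, labels] = [256, 180, 65]` (the shape noted in the comment further down), and targets and logits are flattened to `BatchSize * SeqLen = 46080` rows. A throwaway sketch of that arithmetic, with the vocabulary size as the only assumption:

```go
package main

import "fmt"

func main() {
	const (
		SeqLen    int64 = 180
		BatchSize int64 = 256
		labels    int64 = 65 // assumption: distinct characters in tiny Shakespeare
	)

	fmt.Println([]int64{BatchSize, SeqLen, labels}) // one-hot input: [256 180 65]
	fmt.Println(BatchSize * SeqLen)                 // flattened target rows: 46080
}
```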
@@ -59,13 +58,12 @@ func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.Device) string {
 }
 
 func main() {
-	cuda := gotch.NewCuda()
-	device := cuda.CudaIfAvailable()
+	device := gotch.CudaIfAvailable()
 
 	vs := nn.NewVarStore(device)
 	data, err := ts.NewTextData("../../data/char-rnn/input.txt")
 	if err != nil {
-		log.Fatal(err)
+		panic(err)
 	}
 
 	labels := data.Labels()
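The device setup here moves from the two-step `gotch.NewCuda()` / `cuda.CudaIfAvailable()` form to the package-level `gotch.CudaIfAvailable()`. A minimal sketch of the new form, assuming only the calls that appear in this diff plus the standard `github.com/sugarme/gotch` import paths:

```go
package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
)

func main() {
	// Use the GPU when libtorch can see one, otherwise fall back to CPU.
	device := gotch.CudaIfAvailable()

	// Variables created through this store live on the chosen device.
	vs := nn.NewVarStore(device)

	fmt.Printf("training on %v\n", device)
	_ = vs
}
```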
@@ -94,41 +92,36 @@ func main() {
 			}
 
 			batchNarrow := batchTs.MustNarrow(1, 0, SeqLen, false)
-			xsOnehotTmp := batchNarrow.Onehot(labels)
-			xsOnehot := xsOnehotTmp.MustTo(device, true) // shape: [256, 180, 65]
-			ysTmp1 := batchTs.MustNarrow(1, 1, SeqLen, true)
-			ysTmp2 := ysTmp1.MustTotype(gotch.Int64, true)
-			ysTmp3 := ysTmp2.MustTo(device, true)
-			ys := ysTmp3.MustView([]int64{BatchSize * SeqLen}, true)
+			xsOnehot := batchNarrow.Onehot(labels).MustTo(device, true) // [256, 180, 65]
+			batchNarrow.MustDrop()
+
+			ys := batchTs.MustNarrow(1, 1, SeqLen, true).MustTotype(gotch.Int64, true).MustTo(device, true).MustView([]int64{BatchSize * SeqLen}, true)
 
 			lstmOut, outState := lstm.Seq(xsOnehot)
+			// NOTE: although outState is never read, it still owns memory on the
+			// C side that has to be freed explicitly. Don't discard it with `_`.
+			outState.(*nn.LSTMState).Tensor1.MustDrop()
+			outState.(*nn.LSTMState).Tensor2.MustDrop()
+			xsOnehot.MustDrop()
 
 			logits := linear.Forward(lstmOut)
+			lstmOut.MustDrop()
 			lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)
 
 			loss := lossView.CrossEntropyForLogits(ys)
+			ys.MustDrop()
+			lossView.MustDrop()
 
 			opt.BackwardStepClip(loss, 0.5)
 			sumLoss += loss.Float64Values()[0]
 			cntLoss += 1.0
 
-			// batchTs.MustDrop()
-			// batchNarrow.MustDrop()
-			// xsOnehotTmp.MustDrop()
-			xsOnehot.MustDrop()
-			ys.MustDrop()
-			lstmOut.MustDrop()
 			loss.MustDrop()
 
 			batchCount++
 			if batchCount%500 == 0 {
 				fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
 			}
 		}
 	} // infinite for-loop
 
 	sampleStr := sample(data, lstm, linear, device)
 	fmt.Printf("Epoch %v - Loss: %v \n", epoch, sumLoss/cntLoss)
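The shape of the fix, in one idiom: gotch tensors wrap libtorch memory allocated on the C side, which Go's garbage collector never sees, so every intermediate tensor must either be consumed by a `Must*` call whose final `del` argument is `true` (which frees the receiver as part of the call) or be released explicitly with `MustDrop()`; that includes the LSTM state returned by `Seq`, even though it is never read. A minimal standalone sketch of the idiom, assuming the generated `MustOnes` / `MustAdd` signatures and the usual import paths:

```go
package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	// Import path is an assumption; the example aliases the tensor package as `ts`.
	ts "github.com/sugarme/gotch/tensor"
)

func main() {
	// Both tensors live in C-allocated memory, invisible to Go's GC.
	a := ts.MustOnes([]int64{256, 180}, gotch.Float, gotch.CPU)
	b := ts.MustOnes([]int64{256, 180}, gotch.Float, gotch.CPU)

	// del=true frees the receiver `a` as part of the call,
	// so that intermediate never needs a separate drop ...
	sum := a.MustAdd(b, true)

	// ... but everything else must be dropped by hand, exactly as the
	// training loop above drops xsOnehot, ys, lstmOut, and the LSTM state.
	b.MustDrop()
	fmt.Println(sum.Float64Values()[0]) // 2
	sum.MustDrop()
}
```

Miss any one of these drops inside the batch loop and that tensor's buffer leaks on every iteration, which is the failure mode #88 describes.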
@@ -137,5 +130,4 @@ func main() {
 		dataIter.Data.MustDrop()
 		dataIter.Indexes.MustDrop()
 	}
-
 }