fixed #88 - example/char-rnn memory leak
This commit is contained in:
parent
17f2c49e34
commit
f8ce489c09
|
@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
- Added `nn.MSELoss()`
|
- Added `nn.MSELoss()`
|
||||||
- reworked `ts.Format()`
|
- reworked `ts.Format()`
|
||||||
- Added conv2d benchmark
|
- Added conv2d benchmark
|
||||||
|
- Fixed #88 memory leak at `example/char-rnn`
|
||||||
|
|
||||||
## [Nofix]
|
## [Nofix]
|
||||||
- ctype `long` caused a compile error on macOS as noted on [#44]. Not working on linux box.
|
- ctype `long` caused a compile error on macOS as noted on [#44]. Not working on linux box.
|
||||||
|
|
|
@ -11,7 +11,7 @@ At the end of each training epoch, some sample text is generated and printed.
|
||||||
Any text file can be used as an input, as long as it's large enough for training.
|
Any text file can be used as an input, as long as it's large enough for training.
|
||||||
A typical example would be the
|
A typical example would be the
|
||||||
[tiny Shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt).
|
[tiny Shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt).
|
||||||
The training text file should be stored in `data/input.txt`.
|
The training text file should be stored in `data/char-rnn/input.txt`.
|
||||||
|
|
||||||
To run the example:
|
To run the example:
|
||||||
|
|
||||||
|
|
|
@ -14,12 +14,11 @@ const (
|
||||||
HiddenSize int64 = 256
|
HiddenSize int64 = 256
|
||||||
SeqLen int64 = 180
|
SeqLen int64 = 180
|
||||||
BatchSize int64 = 256
|
BatchSize int64 = 256
|
||||||
Epochs int = 100
|
Epochs int = 3
|
||||||
SamplingLen int64 = 1024
|
SamplingLen int64 = 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.Device) string {
|
func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.Device) string {
|
||||||
|
|
||||||
labels := data.Labels()
|
labels := data.Labels()
|
||||||
inState := lstm.ZeroState(1)
|
inState := lstm.ZeroState(1)
|
||||||
lastLabel := int64(0)
|
lastLabel := int64(0)
|
||||||
|
@ -59,13 +58,12 @@ func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.De
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
cuda := gotch.NewCuda()
|
device := gotch.CudaIfAvailable()
|
||||||
device := cuda.CudaIfAvailable()
|
|
||||||
|
|
||||||
vs := nn.NewVarStore(device)
|
vs := nn.NewVarStore(device)
|
||||||
data, err := ts.NewTextData("../../data/char-rnn/input.txt")
|
data, err := ts.NewTextData("../../data/char-rnn/input.txt")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
labels := data.Labels()
|
labels := data.Labels()
|
||||||
|
@ -94,41 +92,36 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
batchNarrow := batchTs.MustNarrow(1, 0, SeqLen, false)
|
batchNarrow := batchTs.MustNarrow(1, 0, SeqLen, false)
|
||||||
xsOnehotTmp := batchNarrow.Onehot(labels)
|
xsOnehot := batchNarrow.Onehot(labels).MustTo(device, true) // [256, 180, 65]
|
||||||
xsOnehot := xsOnehotTmp.MustTo(device, true) // shape: [256, 180, 65]
|
batchNarrow.MustDrop()
|
||||||
ysTmp1 := batchTs.MustNarrow(1, 1, SeqLen, true)
|
|
||||||
ysTmp2 := ysTmp1.MustTotype(gotch.Int64, true)
|
ys := batchTs.MustNarrow(1, 1, SeqLen, true).MustTotype(gotch.Int64, true).MustTo(device, true).MustView([]int64{BatchSize * SeqLen}, true)
|
||||||
ysTmp3 := ysTmp2.MustTo(device, true)
|
|
||||||
ys := ysTmp3.MustView([]int64{BatchSize * SeqLen}, true)
|
|
||||||
|
|
||||||
lstmOut, outState := lstm.Seq(xsOnehot)
|
lstmOut, outState := lstm.Seq(xsOnehot)
|
||||||
// NOTE. Although outState will not be used, there is a hidden memory usage
|
// NOTE. Although outState will not be used, there is a hidden memory usage
|
||||||
// on C land memory that needs to be freed up. Don't use `_`
|
// on C land memory that needs to be freed up. Don't use `_`
|
||||||
outState.(*nn.LSTMState).Tensor1.MustDrop()
|
outState.(*nn.LSTMState).Tensor1.MustDrop()
|
||||||
outState.(*nn.LSTMState).Tensor2.MustDrop()
|
outState.(*nn.LSTMState).Tensor2.MustDrop()
|
||||||
|
xsOnehot.MustDrop()
|
||||||
|
|
||||||
logits := linear.Forward(lstmOut)
|
logits := linear.Forward(lstmOut)
|
||||||
|
lstmOut.MustDrop()
|
||||||
lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)
|
lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)
|
||||||
|
|
||||||
loss := lossView.CrossEntropyForLogits(ys)
|
loss := lossView.CrossEntropyForLogits(ys)
|
||||||
|
ys.MustDrop()
|
||||||
|
lossView.MustDrop()
|
||||||
|
|
||||||
opt.BackwardStepClip(loss, 0.5)
|
opt.BackwardStepClip(loss, 0.5)
|
||||||
sumLoss += loss.Float64Values()[0]
|
sumLoss += loss.Float64Values()[0]
|
||||||
cntLoss += 1.0
|
cntLoss += 1.0
|
||||||
|
|
||||||
// batchTs.MustDrop()
|
|
||||||
// batchNarrow.MustDrop()
|
|
||||||
// xsOnehotTmp.MustDrop()
|
|
||||||
xsOnehot.MustDrop()
|
|
||||||
ys.MustDrop()
|
|
||||||
lstmOut.MustDrop()
|
|
||||||
loss.MustDrop()
|
loss.MustDrop()
|
||||||
|
|
||||||
batchCount++
|
batchCount++
|
||||||
if batchCount%500 == 0 {
|
if batchCount%500 == 0 {
|
||||||
fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
|
fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
|
||||||
}
|
}
|
||||||
}
|
} // infinite for-loop
|
||||||
|
|
||||||
sampleStr := sample(data, lstm, linear, device)
|
sampleStr := sample(data, lstm, linear, device)
|
||||||
fmt.Printf("Epoch %v - Loss: %v \n", epoch, sumLoss/cntLoss)
|
fmt.Printf("Epoch %v - Loss: %v \n", epoch, sumLoss/cntLoss)
|
||||||
|
@ -137,5 +130,4 @@ func main() {
|
||||||
dataIter.Data.MustDrop()
|
dataIter.Data.MustDrop()
|
||||||
dataIter.Indexes.MustDrop()
|
dataIter.Indexes.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user