fix(nn/rnn): fix memory leak and remove defaultSeq func; example/char-rnn: complete the example

sugarme 2020-07-28 16:08:40 +10:00
parent 11a8dd9245
commit 74480db4d6
4 changed files with 85 additions and 32 deletions
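For context: gotch tensors wrap memory allocated on the C side, which Go's garbage collector cannot see, so every tensor must be freed with an explicit MustDrop(). The leak this commit fixes comes from reassigning a variable that holds such a tensor without dropping the old value first. A minimal sketch of the safe pattern (illustrative only; it uses only calls that appear in the diff below, with import paths as of this commit):

    package main

    import (
        "github.com/sugarme/gotch"
        ts "github.com/sugarme/gotch/tensor"
    )

    func main() {
        t := ts.MustZeros([]int64{2, 3}, gotch.Float, gotch.CPU)
        for i := 0; i < 10; i++ {
            next := ts.MustZeros([]int64{2, 3}, gotch.Float, gotch.CPU) // fresh C allocation
            t.MustDrop() // free the old tensor BEFORE overwriting the Go handle
            t = next     // without the drop above, the old C memory leaks
        }
        t.MustDrop()
    }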

View File

@@ -21,27 +21,41 @@ const (
 func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Device) (retVal string) {
     labels := data.Labels()
-    state := lstm.ZeroState(1)
+    inState := lstm.ZeroState(1)
     lastLabel := int64(0)
-    var result string
+    var runes []rune
     for i := 0; i < int(SamplingLen); i++ {
         input := ts.MustZeros([]int64{1, labels}, gotch.Float, device)
-        input.MustNarrow(1, lastLabel, 1, false).MustFill_(ts.FloatScalar(1.0))
-        state = lstm.Step(input, state)
+        // NOTE: `Narrow` creates a view tensor that shares the same storage
+        inputView := input.MustNarrow(1, lastLabel, 1, false)
+        inputView.MustFill_(ts.FloatScalar(1.0))
-        forwardTs := linear.Forward(state.(nn.LSTMState).H())
-        squeeze1Ts := forwardTs.MustSqueeze1(0, false)
-        softmaxTs := squeeze1Ts.MustSoftmax(-1, gotch.Float, false)
-        sampledY := softmaxTs.MustMultinomial(1, false, false)
+        state := lstm.Step(input, inState)
+        // 1. Delete inState tensors (C-land memory)
+        inState.(nn.LSTMState).Tensor1.MustDrop()
+        inState.(nn.LSTMState).Tensor2.MustDrop()
+        // 2. Then update with the current state
+        inState = state
+        // 3. Delete intermediate tensors
+        input.MustDrop()
+        inputView.MustDrop()
+        forwardTs := linear.Forward(state.(nn.LSTMState).H()).MustSqueeze1(0, true).MustSoftmax(-1, gotch.Float, true)
+        sampledY := forwardTs.MustMultinomial(1, false, true)
         lastLabel = sampledY.Int64Values()[0]
+        sampledY.MustDrop()
         char := data.LabelForChar(lastLabel)
-        result += fmt.Sprintf("%v", lastLabel)
+        runes = append(runes, char)
     }
-    return result
+    // Delete the last state
+    inState.(nn.LSTMState).Tensor1.MustDrop()
+    inState.(nn.LSTMState).Tensor2.MustDrop()
+    return string(runes)
 }
 func main() {
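The `Narrow` note above matters for both correctness and cleanup: the narrowed tensor is a view over the same storage, so MustFill_ through the view writes into `input`, and the view is nonetheless a separate C handle that needs its own drop. A small fragment (assumes the imports from the sketch above plus "fmt"):

    input := ts.MustZeros([]int64{1, 4}, gotch.Float, gotch.CPU)
    view := input.MustNarrow(1, 2, 1, false) // del=false keeps `input` alive
    view.MustFill_(ts.FloatScalar(1.0))      // mutates `input` through shared storage
    fmt.Println(input.Float64Values())       // [0 0 1 0]
    view.MustDrop()                          // the view is its own handle...
    input.MustDrop()                         // ...and so is the base tensor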
@@ -59,6 +73,7 @@ func main() {
     lstm := nn.NewLSTM(vs.Root(), labels, HiddenSize, nn.DefaultRNNConfig())
     linear := nn.NewLinear(vs.Root(), HiddenSize, labels, nn.DefaultLinearConfig())
+    optConfig := nn.DefaultAdamConfig()
     opt, err := optConfig.Build(vs, LearningRate)
     if err != nil {
@@ -86,7 +101,12 @@ func main() {
         ysTmp3 := ysTmp2.MustTo(device, true)
         ys := ysTmp3.MustView([]int64{BatchSize * SeqLen}, true)
-        lstmOut, _ := lstm.Seq(xsOnehot)
+        lstmOut, outState := lstm.Seq(xsOnehot)
+        // NOTE: although outState is not used, it still holds hidden C-land
+        // memory that must be freed. Don't discard it with `_`.
+        outState.(nn.LSTMState).Tensor1.MustDrop()
+        outState.(nn.LSTMState).Tensor2.MustDrop()
         logits := linear.Forward(lstmOut)
         lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)
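Binding the state just to drop its two tensors is a recurring dance; a tiny helper (hypothetical, not part of the gotch API) would keep call sites tidy:

    // dropLSTMState frees both C-allocated tensors held by an LSTMState.
    // Hypothetical helper, not part of gotch itself.
    func dropLSTMState(s nn.State) {
        st := s.(nn.LSTMState)
        st.Tensor1.MustDrop()
        st.Tensor2.MustDrop()
    }

    // usage:
    //   lstmOut, outState := lstm.Seq(xsOnehot)
    //   dropLSTMState(outState)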
@@ -96,18 +116,26 @@ func main() {
             sumLoss += loss.Float64Values()[0]
             cntLoss += 1.0
             batchTs.MustDrop()
+            batchNarrow.MustDrop()
+            xsOnehotTmp.MustDrop()
             xsOnehot.MustDrop()
-            lstmOut.MustDrop()
             ys.MustDrop()
+            lstmOut.MustDrop()
             loss.MustDrop()
             batchCount++
-            fmt.Printf("Batch %v - sumLoss: %v - cntLoss %v\n", batchCount, sumLoss, cntLoss)
+            if batchCount%500 == 0 {
+                fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
+            }
         }
-        fmt.Printf("Epoch %v - Loss: %v", epoch, sumLoss/cntLoss)
-        fmt.Printf("Sample: %v", sample(data, lstm, linear, device))
+        sampleStr := sample(data, lstm, linear, device)
+        fmt.Printf("Epoch %v - Loss: %v \n", epoch, sumLoss/cntLoss)
+        fmt.Println(sampleStr)
+        dataIter.Data.MustDrop()
+        dataIter.Indexes.MustDrop()
     }
 }

View File

@@ -29,13 +29,6 @@ type RNN interface {
     SeqInit(input ts.Tensor, inState State) (ts.Tensor, State)
 }
-func defaultSeq(self interface{}, input ts.Tensor) (ts.Tensor, State) {
-    batchDim := input.MustSize()[0]
-    inState := self.(RNN).ZeroState(batchDim)
-    return self.(RNN).SeqInit(input, inState)
-}
 // The state for an LSTM network; it contains two tensors.
 type LSTMState struct {
     Tensor1 ts.Tensor
@@ -141,10 +134,14 @@ func (l LSTM) ZeroState(batchDim int64) (retVal State) {
     shape := []int64{layerDim, batchDim, l.hiddenDim}
     zeros := ts.MustZeros(shape, gotch.Float, l.device)
-    return LSTMState{
+    retVal = LSTMState{
         Tensor1: zeros.MustShallowClone(),
         Tensor2: zeros.MustShallowClone(),
     }
+    zeros.MustDrop()
+    return retVal
 }
 func (l LSTM) Step(input ts.Tensor, inState State) (retVal State) {
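The added zeros.MustDrop() is safe because MustShallowClone returns a new handle over the same (reference-counted) libtorch storage: dropping the original handle does not invalidate the clones. A quick fragment, assuming those semantics:

    zeros := ts.MustZeros([]int64{1}, gotch.Float, gotch.CPU)
    clone := zeros.MustShallowClone()
    zeros.MustDrop()                   // clone still valid: storage is shared
    fmt.Println(clone.Float64Values()) // [0]
    clone.MustDrop()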
@@ -159,8 +156,17 @@ func (l LSTM) Step(input ts.Tensor, inState State) (retVal State) {
     return state
 }
-func (l LSTM) Seq(input ts.Tensor) (ts.Tensor, State) {
-    return defaultSeq(l, input)
+func (l LSTM) Seq(input ts.Tensor) (output ts.Tensor, state State) {
+    batchDim := input.MustSize()[0]
+    inState := l.ZeroState(batchDim)
+    output, state = l.SeqInit(input, inState)
+    // Delete intermediate tensors in inState
+    inState.(LSTMState).Tensor1.MustDrop()
+    inState.(LSTMState).Tensor2.MustDrop()
+    return output, state
 }
 func (l LSTM) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) {
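With defaultSeq gone, ownership is explicit at each call site: Seq allocates and frees its own zero state, while the caller owns the returned output and final state and must drop all three tensors. Hypothetical caller-side code under that contract:

    out, state := lstm.Seq(input) // input shape: [batch, seq, features]
    // ... consume out (and state if needed) ...
    out.MustDrop()
    state.(nn.LSTMState).Tensor1.MustDrop()
    state.(nn.LSTMState).Tensor2.MustDrop()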
@@ -254,8 +260,17 @@ func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
     return state
 }
-func (g GRU) Seq(input ts.Tensor) (ts.Tensor, State) {
-    return defaultSeq(g, input)
+func (g GRU) Seq(input ts.Tensor) (output ts.Tensor, state State) {
+    batchDim := input.MustSize()[0]
+    inState := g.ZeroState(batchDim)
+    output, state = g.SeqInit(input, inState)
+    // Delete intermediate tensors in inState
+    inState.(LSTMState).Tensor1.MustDrop()
+    inState.(LSTMState).Tensor2.MustDrop()
+    return output, state
 }
 func (g GRU) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) {

View File

@@ -269,6 +269,7 @@ func (tdi *TextDataIter) Next() (retVal Tensor, ok bool) {
     indexesTs := tdi.Indexes.Idx(narrowIdx)
     indexes := indexesTs.Int64Values()
+    indexesTs.MustDrop()
     var batch []Tensor
@@ -280,6 +281,11 @@ func (tdi *TextDataIter) Next() (retVal Tensor, ok bool) {
     retVal = MustStack(batch, 0)
+    // Delete intermediate tensors
+    for _, xs := range batch {
+        xs.MustDrop()
+    }
     return retVal, true
 }
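With the iterator now dropping its own intermediates, the only tensor a caller owns is the stacked batch returned by Next. A typical consumption loop, sketched:

    for {
        batch, ok := dataIter.Next()
        if !ok {
            break
        }
        // ... use batch ...
        batch.MustDrop() // the caller owns the stacked result
    }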

View File

@@ -621,7 +621,7 @@ func (ts Tensor) Numel() (retVal uint) {
     return uint(FlattenDim(shape))
 }
-// ShallowCopy returns a new tensor that share storage with the input tensor.
+// ShallowClone returns a new tensor that shares storage with the input tensor.
 func (ts Tensor) ShallowClone() (retVal Tensor, err error) {
     ctensor := lib.AtShallowClone(ts.ctensor)
@@ -1120,10 +1120,14 @@ func (ts Tensor) MustZeroPad2d(left, right, top, bottom int64, del bool) (retVal Tensor) {
 func (ts Tensor) Onehot(labels int64) (retVal Tensor) {
     dims := ts.MustSize()
     dims = append(dims, labels)
-    unsqueezeTs := ts.MustUnsqueeze(-1, false).MustTotype(gotch.Int64, true)
+    unsqueezeTs := ts.MustUnsqueeze(-1, false)
+    inputTs := unsqueezeTs.MustTotype(gotch.Int64, true)
     zerosTs := MustZeros(dims, gotch.Float, gotch.CPU)
-    zerosTs.MustScatter1_(-1, unsqueezeTs, FloatScalar(1.0))
-    return zerosTs
+    retVal = zerosTs.MustScatter1(-1, inputTs, FloatScalar(1.0), true)
+    inputTs.MustDrop()
+    return retVal
 }
 func (ts Tensor) Swish() (retVal Tensor) {
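A worked example of the rewritten Onehot: an input of shape [2] with labels=3 becomes a [2, 3] float tensor with a 1.0 scattered at each label index. The fragment assumes a slice constructor like ts.MustOfSlice; if gotch names it differently, substitute accordingly:

    xs := ts.MustOfSlice([]int64{0, 2}) // shape [2] (constructor name assumed)
    oh := xs.Onehot(3)                  // shape [2, 3]
    fmt.Println(oh.Float64Values())     // [1 0 0 0 0 1]
    oh.MustDrop()
    xs.MustDrop()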