fix(nn/rnn): fixed memory leak and removed defaultSeq func; example/char-rnn: completed
This commit is contained in:
parent
11a8dd9245
commit
74480db4d6
|
@ -21,27 +21,41 @@ const (
|
|||
func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Device) (retVal string) {
|
||||
|
||||
labels := data.Labels()
|
||||
state := lstm.ZeroState(1)
|
||||
inState := lstm.ZeroState(1)
|
||||
lastLabel := int64(0)
|
||||
var result string
|
||||
var runes []rune
|
||||
|
||||
for i := 0; i < int(SamplingLen); i++ {
|
||||
|
||||
input := ts.MustZeros([]int64{1, labels}, gotch.Float, device)
|
||||
input.MustNarrow(1, lastLabel, 1, false).MustFill_(ts.FloatScalar(1.0))
|
||||
state = lstm.Step(input, state)
|
||||
// NOTE. `Narrow` creates tensor that shares same storage
|
||||
inputView := input.MustNarrow(1, lastLabel, 1, false)
|
||||
inputView.MustFill_(ts.FloatScalar(1.0))
|
||||
|
||||
forwardTs := linear.Forward(state.(nn.LSTMState).H())
|
||||
squeeze1Ts := forwardTs.MustSqueeze1(0, false)
|
||||
softmaxTs := squeeze1Ts.MustSoftmax(-1, gotch.Float, false)
|
||||
sampledY := softmaxTs.MustMultinomial(1, false, false)
|
||||
state := lstm.Step(input, inState)
|
||||
|
||||
// 1. Delete inState tensors (from C land memory)
|
||||
inState.(nn.LSTMState).Tensor1.MustDrop()
|
||||
inState.(nn.LSTMState).Tensor2.MustDrop()
|
||||
// 2. Then update with current state
|
||||
inState = state
|
||||
// 3. Delete intermediate tensors
|
||||
input.MustDrop()
|
||||
inputView.MustDrop()
|
||||
|
||||
forwardTs := linear.Forward(state.(nn.LSTMState).H()).MustSqueeze1(0, true).MustSoftmax(-1, gotch.Float, true)
|
||||
sampledY := forwardTs.MustMultinomial(1, false, true)
|
||||
lastLabel = sampledY.Int64Values()[0]
|
||||
sampledY.MustDrop()
|
||||
char := data.LabelForChar(lastLabel)
|
||||
|
||||
result += fmt.Sprintf("%v", lastLabel)
|
||||
runes = append(runes, char)
|
||||
}
|
||||
|
||||
return result
|
||||
// Delete the last state
|
||||
inState.(nn.LSTMState).Tensor1.MustDrop()
|
||||
inState.(nn.LSTMState).Tensor2.MustDrop()
|
||||
|
||||
return string(runes)
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
@ -59,6 +73,7 @@ func main() {
|
|||
|
||||
lstm := nn.NewLSTM(vs.Root(), labels, HiddenSize, nn.DefaultRNNConfig())
|
||||
linear := nn.NewLinear(vs.Root(), HiddenSize, labels, nn.DefaultLinearConfig())
|
||||
|
||||
optConfig := nn.DefaultAdamConfig()
|
||||
opt, err := optConfig.Build(vs, LearningRate)
|
||||
if err != nil {
|
||||
|
@ -86,7 +101,12 @@ func main() {
|
|||
ysTmp3 := ysTmp2.MustTo(device, true)
|
||||
ys := ysTmp3.MustView([]int64{BatchSize * SeqLen}, true)
|
||||
|
||||
lstmOut, _ := lstm.Seq(xsOnehot)
|
||||
lstmOut, outState := lstm.Seq(xsOnehot)
|
||||
// NOTE. Although outState will not be used, there is a hidden memory usage
|
||||
// on C land memory that is needed to free up. Don't use `_`
|
||||
outState.(nn.LSTMState).Tensor1.MustDrop()
|
||||
outState.(nn.LSTMState).Tensor2.MustDrop()
|
||||
|
||||
logits := linear.Forward(lstmOut)
|
||||
lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)
|
||||
|
||||
|
@ -96,18 +116,26 @@ func main() {
|
|||
sumLoss += loss.Float64Values()[0]
|
||||
cntLoss += 1.0
|
||||
|
||||
batchTs.MustDrop()
|
||||
batchNarrow.MustDrop()
|
||||
xsOnehotTmp.MustDrop()
|
||||
xsOnehot.MustDrop()
|
||||
lstmOut.MustDrop()
|
||||
ys.MustDrop()
|
||||
lstmOut.MustDrop()
|
||||
loss.MustDrop()
|
||||
|
||||
batchCount++
|
||||
fmt.Printf("Batch %v - sumLoss: %v - cntLoss %v\n", batchCount, sumLoss, cntLoss)
|
||||
if batchCount%500 == 0 {
|
||||
fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("Epoch %v - Loss: %v", epoch, sumLoss/cntLoss)
|
||||
fmt.Printf("Sample: %v", sample(data, lstm, linear, device))
|
||||
sampleStr := sample(data, lstm, linear, device)
|
||||
fmt.Printf("Epoch %v - Loss: %v \n", epoch, sumLoss/cntLoss)
|
||||
fmt.Println(sampleStr)
|
||||
|
||||
dataIter.Data.MustDrop()
|
||||
dataIter.Indexes.MustDrop()
|
||||
}
|
||||
|
||||
}
|
||||
|
|
39
nn/rnn.go
39
nn/rnn.go
|
@ -29,13 +29,6 @@ type RNN interface {
|
|||
SeqInit(input ts.Tensor, inState State) (ts.Tensor, State)
|
||||
}
|
||||
|
||||
func defaultSeq(self interface{}, input ts.Tensor) (ts.Tensor, State) {
|
||||
batchDim := input.MustSize()[0]
|
||||
inState := self.(RNN).ZeroState(batchDim)
|
||||
|
||||
return self.(RNN).SeqInit(input, inState)
|
||||
}
|
||||
|
||||
// The state for a LSTM network, this contains two tensors.
|
||||
type LSTMState struct {
|
||||
Tensor1 ts.Tensor
|
||||
|
@ -141,10 +134,14 @@ func (l LSTM) ZeroState(batchDim int64) (retVal State) {
|
|||
shape := []int64{layerDim, batchDim, l.hiddenDim}
|
||||
zeros := ts.MustZeros(shape, gotch.Float, l.device)
|
||||
|
||||
return LSTMState{
|
||||
retVal = LSTMState{
|
||||
Tensor1: zeros.MustShallowClone(),
|
||||
Tensor2: zeros.MustShallowClone(),
|
||||
}
|
||||
|
||||
zeros.MustDrop()
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (l LSTM) Step(input ts.Tensor, inState State) (retVal State) {
|
||||
|
@ -159,8 +156,17 @@ func (l LSTM) Step(input ts.Tensor, inState State) (retVal State) {
|
|||
return state
|
||||
}
|
||||
|
||||
func (l LSTM) Seq(input ts.Tensor) (ts.Tensor, State) {
|
||||
return defaultSeq(l, input)
|
||||
func (l LSTM) Seq(input ts.Tensor) (output ts.Tensor, state State) {
|
||||
batchDim := input.MustSize()[0]
|
||||
inState := l.ZeroState(batchDim)
|
||||
|
||||
output, state = l.SeqInit(input, inState)
|
||||
|
||||
// Delete intermediate tensors in inState
|
||||
inState.(LSTMState).Tensor1.MustDrop()
|
||||
inState.(LSTMState).Tensor2.MustDrop()
|
||||
|
||||
return output, state
|
||||
}
|
||||
|
||||
func (l LSTM) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) {
|
||||
|
@ -254,8 +260,17 @@ func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
|
|||
return state
|
||||
}
|
||||
|
||||
func (g GRU) Seq(input ts.Tensor) (ts.Tensor, State) {
|
||||
return defaultSeq(g, input)
|
||||
func (g GRU) Seq(input ts.Tensor) (output ts.Tensor, state State) {
|
||||
batchDim := input.MustSize()[0]
|
||||
inState := g.ZeroState(batchDim)
|
||||
|
||||
output, state = g.SeqInit(input, inState)
|
||||
|
||||
// Delete intermediate tensors in inState
|
||||
inState.(LSTMState).Tensor1.MustDrop()
|
||||
inState.(LSTMState).Tensor2.MustDrop()
|
||||
|
||||
return output, state
|
||||
}
|
||||
|
||||
func (g GRU) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) {
|
||||
|
|
|
@ -269,6 +269,7 @@ func (tdi *TextDataIter) Next() (retVal Tensor, ok bool) {
|
|||
indexesTs := tdi.Indexes.Idx(narrowIdx)
|
||||
|
||||
indexes := indexesTs.Int64Values()
|
||||
indexesTs.MustDrop()
|
||||
|
||||
var batch []Tensor
|
||||
|
||||
|
@ -280,6 +281,11 @@ func (tdi *TextDataIter) Next() (retVal Tensor, ok bool) {
|
|||
|
||||
retVal = MustStack(batch, 0)
|
||||
|
||||
// Delete intermediate tensors
|
||||
for _, xs := range batch {
|
||||
xs.MustDrop()
|
||||
}
|
||||
|
||||
return retVal, true
|
||||
}
|
||||
|
||||
|
|
|
@ -621,7 +621,7 @@ func (ts Tensor) Numel() (retVal uint) {
|
|||
return uint(FlattenDim(shape))
|
||||
}
|
||||
|
||||
// ShallowCopy returns a new tensor that share storage with the input tensor.
|
||||
// ShallowClone returns a new tensor that share storage with the input tensor.
|
||||
func (ts Tensor) ShallowClone() (retVal Tensor, err error) {
|
||||
|
||||
ctensor := lib.AtShallowClone(ts.ctensor)
|
||||
|
@ -1120,10 +1120,14 @@ func (ts Tensor) MustZeroPad2d(left, right, top, bottom int64, del bool) (retVal
|
|||
func (ts Tensor) Onehot(labels int64) (retVal Tensor) {
|
||||
dims := ts.MustSize()
|
||||
dims = append(dims, labels)
|
||||
unsqueezeTs := ts.MustUnsqueeze(-1, false).MustTotype(gotch.Int64, true)
|
||||
unsqueezeTs := ts.MustUnsqueeze(-1, false)
|
||||
inputTs := unsqueezeTs.MustTotype(gotch.Int64, true)
|
||||
|
||||
zerosTs := MustZeros(dims, gotch.Float, gotch.CPU)
|
||||
zerosTs.MustScatter1_(-1, unsqueezeTs, FloatScalar(1.0))
|
||||
return zerosTs
|
||||
retVal = zerosTs.MustScatter1(-1, inputTs, FloatScalar(1.0), true)
|
||||
inputTs.MustDrop()
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Swish() (retVal Tensor) {
|
||||
|
|
Loading…
Reference in New Issue
Block a user