fix(nn/rnn): LSTM weights were not flattened; add Device.IsCuda(); WIP(example/char-rnn): memory blowup remains

sugarme 2020-07-27 17:17:38 +10:00
parent c69229c43a
commit 11a8dd9245
4 changed files with 32 additions and 21 deletions


@@ -105,3 +105,12 @@ func (d Device) CudaIfAvailable() Device {
		return CPU
	}
}

// IsCuda returns whether the device is a CUDA device.
func (d Device) IsCuda() bool {
	return d.Name != "CPU"
}
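
A quick usage sketch of the new helper (not part of this commit), relying only on identifiers that appear in this diff (gotch.CPU, Device.CudaIfAvailable, Device.IsCuda); the import path is assumed from the repository name:

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
)

func main() {
	// Pick CUDA when available, otherwise stay on the CPU, then branch on it.
	device := gotch.CPU.CudaIfAvailable()
	if device.IsCuda() {
		fmt.Println("running on CUDA")
	} else {
		fmt.Println("running on CPU")
	}
}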


@@ -71,42 +71,38 @@ func main() {
	dataIter := data.IterShuffle(SeqLen+1, BatchSize)
	batchCount := 0
	for {
		batchTs, ok := dataIter.Next()
		if !ok {
			break
		}
		testTs := ts.MustOfSlice([]int64{0, 1, 2, 3, 4, 5, 6, 7, 8})
		testVal := testTs.Onehot(14).Float64Values()
		fmt.Printf("testVal: %v\n", testVal)
		fmt.Printf("batchTs shape: %v\n", batchTs.MustSize())
		batchNarrow := batchTs.MustNarrow(1, 0, SeqLen, false)
		fmt.Printf("batchNarrow shape: %v\n", batchNarrow.MustSize())
		xsOnehotTmp := batchNarrow.Onehot(labels)
		xsOnehot := xsOnehotTmp.MustTo(device, true) // shape: [256, 180, 65]
		ys := batchTs.MustNarrow(1, 1, SeqLen, false).MustTotype(gotch.Int64, false)
		ysTmp1 := batchTs.MustNarrow(1, 1, SeqLen, true)
		ysTmp2 := ysTmp1.MustTotype(gotch.Int64, true)
		ysTmp3 := ysTmp2.MustTo(device, true)
		ys := ysTmp3.MustView([]int64{BatchSize * SeqLen}, true)
		// NOTE: the following warning occurred here:
		// "Warning: RNN module weights are not part of single contiguous chunk of memory.
		// This means they need to be compacted at every call, possibly greatly increasing memory usage.
		// To compact weights again call flatten_parameters()."
		// See: https://discuss.pytorch.org/t/rnn-module-weights-are-not-part-of-single-contiguous-chunk-of-memory/6011/21
		lstmOut, _ := lstm.Seq(xsOnehot)
		panic("reached")
		logits := linear.Forward(lstmOut)
		lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)
		lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, false)
		loss := lossView.CrossEntropyForLogits(ys.MustTo(device, false)).MustView([]int64{BatchSize * SeqLen}, false)
		loss := lossView.CrossEntropyForLogits(ys)
		opt.BackwardStepClip(loss, 0.5)
		sumLoss += loss.Float64Values()[0]
		cntLoss += 1.0
		xsOnehot.MustDrop()
		lstmOut.MustDrop()
		ys.MustDrop()
		loss.MustDrop()
		batchCount++
		fmt.Printf("Batch %v - sumLoss: %v - cntLoss %v\n", batchCount, sumLoss, cntLoss)
	}
	fmt.Printf("Epoch %v - Loss: %v", epoch, sumLoss/cntLoss)


@@ -111,6 +111,14 @@ func NewLSTM(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal LSTM) {
		}
	}
	// if vs.Device().IsCuda() && gotch.Cuda.CudnnIsAvailable() {
	// TODO: check if Cudnn is available here!!!
	if vs.Device().IsCuda() {
		// NOTE: 2 is the CuDNN mode value for LSTM
		// ref. rnn.cpp in PyTorch
		ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 2, hiddenDim, cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)
	}
	return LSTM{
		flatWeights: flatWeights,
		hiddenDim:   hiddenDim,
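
For readability, here is the flatten call from this hunk restated with an annotated argument list; the names and the mode table are an assumption based on PyTorch's rnn.cpp and cudnnRNNMode_t, not something stated in this repo, and the snippet only makes sense inside NewLSTM:

	// Assumed argument mapping for the call added above:
	//   flatWeights       - weight list (w_ih, w_hh, b_ih, b_hh per layer/direction)
	//   4                 - weight stride: number of weight tensors per layer/direction
	//   inDim             - input size
	//   2                 - RNN mode; in cudnnRNNMode_t, 2 means LSTM (0/1 are plain RNN, 3 is GRU)
	//   hiddenDim         - hidden size
	//   cfg.NumLayers     - number of stacked layers
	//   cfg.BatchFirst    - input layout flag
	//   cfg.Bidirectional - whether the LSTM is bidirectional
	ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 2, hiddenDim,
		cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)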


@@ -1122,8 +1122,6 @@ func (ts Tensor) Onehot(labels int64) (retVal Tensor) {
	dims = append(dims, labels)
	unsqueezeTs := ts.MustUnsqueeze(-1, false).MustTotype(gotch.Int64, true)
	zerosTs := MustZeros(dims, gotch.Float, gotch.CPU)
	fmt.Printf("zeroTs shape: %v\n", zerosTs.MustSize())
	fmt.Printf("unsqueezeTs shape: %v\n", unsqueezeTs.MustSize())
	zerosTs.MustScatter1_(-1, unsqueezeTs, FloatScalar(1.0))
	return zerosTs
}
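
A small usage sketch of Onehot (illustrative, not part of the commit), using only calls that already appear in this diff; the expected values follow from the scatter of 1.0 along the last dimension:

package main

import (
	"fmt"

	ts "github.com/sugarme/gotch/tensor"
)

func main() {
	// One-hot encode the class indices {0, 2, 1} over 3 labels.
	xs := ts.MustOfSlice([]int64{0, 2, 1})
	oh := xs.Onehot(3) // shape [3, 3]
	fmt.Printf("%v\n", oh.Float64Values())
	// Expected row-major values: [1 0 0  0 0 1  0 1 0]
	xs.MustDrop()
	oh.MustDrop()
}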