fix(nn/rnn): correct LSTM weight flattening; add Device.IsCuda(); WIP(example/char-rnn): memory blowup remains
parent c69229c43a
commit 11a8dd9245
@@ -105,3 +105,12 @@ func (d Device) CudaIfAvailable() Device {
 		return CPU
 	}
 }
+
+// IsCuda returns whether device is a Cuda device
+func (d Device) IsCuda() bool {
+	if d.Name == "CPU" {
+		return false
+	}
+
+	return true
+}
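Note: the new helper treats any device whose Name is not "CPU" as a CUDA device. A minimal usage sketch, built only from identifiers visible in this diff (gotch.CPU, CudaIfAvailable, IsCuda); the import path and the surrounding main are assumptions, not part of the commit:

    package main

    import (
        "fmt"

        "github.com/sugarme/gotch"
    )

    func main() {
        // Fall back to CPU when no CUDA device is available, then branch
        // on the new helper, e.g. to decide whether weights can be
        // flattened with cuDNN (see the NewLSTM hunk below).
        device := gotch.CPU.CudaIfAvailable()
        if device.IsCuda() {
            fmt.Println("running on CUDA")
        } else {
            fmt.Println("running on CPU")
        }
    }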
@@ -71,42 +71,38 @@ func main() {

 		dataIter := data.IterShuffle(SeqLen+1, BatchSize)

+		batchCount := 0
 		for {
 			batchTs, ok := dataIter.Next()
 			if !ok {
 				break
 			}

-			testTs := ts.MustOfSlice([]int64{0, 1, 2, 3, 4, 5, 6, 7, 8})
-			testVal := testTs.Onehot(14).Float64Values()
-			fmt.Printf("testVal: %v\n", testVal)
-
-			fmt.Printf("batchTs shape: %v\n", batchTs.MustSize())
-
 			batchNarrow := batchTs.MustNarrow(1, 0, SeqLen, false)
-			fmt.Printf("batchNarrow shape: %v\n", batchNarrow.MustSize())
 			xsOnehotTmp := batchNarrow.Onehot(labels)
 			xsOnehot := xsOnehotTmp.MustTo(device, true) // shape: [256, 180, 65]
-			ys := batchTs.MustNarrow(1, 1, SeqLen, false).MustTotype(gotch.Int64, false)
+			ysTmp1 := batchTs.MustNarrow(1, 1, SeqLen, true)
+			ysTmp2 := ysTmp1.MustTotype(gotch.Int64, true)
+			ysTmp3 := ysTmp2.MustTo(device, true)
+			ys := ysTmp3.MustView([]int64{BatchSize * SeqLen}, true)

-			// NOTE. WARNING occurred here...
-			// Warning: RNN module weights are not part of single contiguous chunk of memory.
-			// This means they need to be compacted at every call, possibly greatly increasing memory usage.
-			// To compact weights again call flatten_parameters().
-			// See: https://discuss.pytorch.org/t/rnn-module-weights-are-not-part-of-single-contiguous-chunk-of-memory/6011/21
 			lstmOut, _ := lstm.Seq(xsOnehot)
-
-			panic("reached")
-
 			logits := linear.Forward(lstmOut)
-			lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, false)
-
-			loss := lossView.CrossEntropyForLogits(ys.MustTo(device, false)).MustView([]int64{BatchSize * SeqLen}, false)
+			lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)
+			loss := lossView.CrossEntropyForLogits(ys)

 			opt.BackwardStepClip(loss, 0.5)
 			sumLoss += loss.Float64Values()[0]
 			cntLoss += 1.0

+			xsOnehot.MustDrop()
+			lstmOut.MustDrop()
+			ys.MustDrop()
+			loss.MustDrop()
+
+			batchCount++
+			fmt.Printf("Batch %v - sumLoss: %v - cntLoss %v\n", batchCount, sumLoss, cntLoss)
 		}

 		fmt.Printf("Epoch %v - Loss: %v", epoch, sumLoss/cntLoss)
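The rewritten loop leans on gotch's manual memory model: every Tensor wraps C-allocated storage that Go's garbage collector cannot reclaim, so each intermediate must either be consumed via the trailing del flag (the `true` arguments above, which drop the receiver once the result is built) or freed explicitly with MustDrop, as the four added drops do. Note that batchNarrow and lossView are still neither consumed nor dropped, which is one plausible source of the memory blowup the commit message reports. A small sketch of the two idioms, assuming only calls that appear in this diff:

    // a -> b with del=false keeps `a` alive; the caller now owns both tensors.
    a := ts.MustOfSlice([]int64{1, 2, 3})
    b := a.MustTotype(gotch.Int64, false)

    // b -> c with del=true drops `b` as soon as `c` exists.
    c := b.MustTo(gotch.CPU, true)

    // Whatever is still alive must be dropped by hand.
    a.MustDrop()
    c.MustDrop()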
@@ -111,6 +111,14 @@ func NewLSTM(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal LSTM) {
 		}
 	}

+	// if vs.Device().IsCuda() && gotch.Cuda.CudnnIsAvailable() {
+	// TODO: check if Cudnn is available here!!!
+	if vs.Device().IsCuda() {
+		// NOTE. 2 is for LSTM
+		// ref. rnn.cpp in Pytorch
+		ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 2, hiddenDim, cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)
+	}
+
 	return LSTM{
 		flatWeights: flatWeights,
 		hiddenDim:   hiddenDim,
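For context: cuDNN requires all RNN parameters in one contiguous buffer, which is what Must_CudnnRnnFlattenWeight arranges; without it PyTorch re-compacts the weights on every call, producing the warning quoted in the char-rnn hunk. The literal 4 is the per-layer weight stride (weight_ih, weight_hh, bias_ih, bias_hh) and 2 is PyTorch's internal mode code for LSTM; per rnn.cpp the codes are 0 RNN_RELU, 1 RNN_TANH, 2 LSTM, 3 GRU. A hypothetical parallel fix for a GRU constructor in the same style; it assumes a NewGRU that builds a flatWeights slice the same way, which is not shown in this diff:

    // Hypothetical: same guard, GRU mode code 3, still 4 tensors per layer.
    if vs.Device().IsCuda() {
        ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 3, hiddenDim,
            cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)
    }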
@@ -1122,8 +1122,6 @@ func (ts Tensor) Onehot(labels int64) (retVal Tensor) {
 	dims = append(dims, labels)
 	unsqueezeTs := ts.MustUnsqueeze(-1, false).MustTotype(gotch.Int64, true)
 	zerosTs := MustZeros(dims, gotch.Float, gotch.CPU)
-	fmt.Printf("zeroTs shape: %v\n", zerosTs.MustSize())
-	fmt.Printf("unsqueezeTs shape: %v\n", unsqueezeTs.MustSize())
 	zerosTs.MustScatter1_(-1, unsqueezeTs, FloatScalar(1.0))
 	return zerosTs
 }
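As a usage reference for the scatter-based implementation above (an [N]-shaped integer tensor becomes an [N, labels] float tensor with a single 1.0 per row), a self-contained example built from calls that appear elsewhere in this diff; the tensor import path and alias are assumptions:

    package main

    import (
        "fmt"

        ts "github.com/sugarme/gotch/tensor"
    )

    func main() {
        t := ts.MustOfSlice([]int64{0, 2, 1})
        oh := t.Onehot(3) // shape [3, 3]
        // rows: [1 0 0], [0 0 1], [0 1 0]
        fmt.Printf("%v\n", oh.Float64Values())
        oh.MustDrop()
        t.MustDrop()
    }

Note that unsqueezeTs is never dropped inside Onehot itself, so each call appears to leak one intermediate tensor, consistent with the blowup flagged in the commit message.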