diff --git a/device.go b/device.go
index 84d4398..9180466 100644
--- a/device.go
+++ b/device.go
@@ -105,3 +105,12 @@ func (d Device) CudaIfAvailable() Device {
 		return CPU
 	}
 }
+
+// IsCuda returns whether device is a Cuda device
+func (d Device) IsCuda() bool {
+	if d.Name == "CPU" {
+		return false
+	}
+
+	return true
+}
diff --git a/example/char-rnn/main.go b/example/char-rnn/main.go
index aa4b7ac..7861f7e 100644
--- a/example/char-rnn/main.go
+++ b/example/char-rnn/main.go
@@ -71,42 +71,38 @@ func main() {
 
 		dataIter := data.IterShuffle(SeqLen+1, BatchSize)
 
+		batchCount := 0
 		for {
 			batchTs, ok := dataIter.Next()
 			if !ok {
 				break
 			}
 
-			testTs := ts.MustOfSlice([]int64{0, 1, 2, 3, 4, 5, 6, 7, 8})
-			testVal := testTs.Onehot(14).Float64Values()
-			fmt.Printf("testVal: %v\n", testVal)
-
-			fmt.Printf("batchTs shape: %v\n", batchTs.MustSize())
-
 			batchNarrow := batchTs.MustNarrow(1, 0, SeqLen, false)
-			fmt.Printf("batchNarrow shape: %v\n", batchNarrow.MustSize())
 			xsOnehotTmp := batchNarrow.Onehot(labels)
 			xsOnehot := xsOnehotTmp.MustTo(device, true) // shape: [256, 180, 65]
-			ys := batchTs.MustNarrow(1, 1, SeqLen, false).MustTotype(gotch.Int64, false)
+			ysTmp1 := batchTs.MustNarrow(1, 1, SeqLen, true)
+			ysTmp2 := ysTmp1.MustTotype(gotch.Int64, true)
+			ysTmp3 := ysTmp2.MustTo(device, true)
+			ys := ysTmp3.MustView([]int64{BatchSize * SeqLen}, true)
 
-			// NOTE. WARNING occurred here...
-			// Warning: RNN module weights are not part of single contiguous chunk of memory.
-			// This means they need to be compacted at every call, possibly greatly increasing memory usage.
-			// To compact weights again call flatten_parameters().
-			// See: https://discuss.pytorch.org/t/rnn-module-weights-are-not-part-of-single-contiguous-chunk-of-memory/6011/21
 			lstmOut, _ := lstm.Seq(xsOnehot)
-
-			panic("reached")
-
 			logits := linear.Forward(lstmOut)
+			lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)
 
-			lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, false)
-
-			loss := lossView.CrossEntropyForLogits(ys.MustTo(device, false)).MustView([]int64{BatchSize * SeqLen}, false)
+			loss := lossView.CrossEntropyForLogits(ys)
 
 			opt.BackwardStepClip(loss, 0.5)
 			sumLoss += loss.Float64Values()[0]
 			cntLoss += 1.0
+
+			xsOnehot.MustDrop()
+			lstmOut.MustDrop()
+			ys.MustDrop()
+			loss.MustDrop()
+
+			batchCount++
+			fmt.Printf("Batch %v - sumLoss: %v - cntLoss %v\n", batchCount, sumLoss, cntLoss)
 		}
 
 		fmt.Printf("Epoch %v - Loss: %v", epoch, sumLoss/cntLoss)
diff --git a/nn/rnn.go b/nn/rnn.go
index 80f3f77..4788d0a 100644
--- a/nn/rnn.go
+++ b/nn/rnn.go
@@ -111,6 +111,14 @@ func NewLSTM(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal LSTM) {
 		}
 	}
 
+	// if vs.Device().IsCuda() && gotch.Cuda.CudnnIsAvailable() {
+	// TODO: check if Cudnn is available here!!!
+	if vs.Device().IsCuda() {
+		// NOTE. 2 is for LSTM
+		// ref. rnn.cpp in Pytorch
+		ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 2, hiddenDim, cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)
+	}
+
 	return LSTM{
 		flatWeights: flatWeights,
 		hiddenDim:   hiddenDim,
diff --git a/tensor/tensor.go b/tensor/tensor.go
index 3732967..808a3a9 100644
--- a/tensor/tensor.go
+++ b/tensor/tensor.go
@@ -1122,8 +1122,6 @@ func (ts Tensor) Onehot(labels int64) (retVal Tensor) {
 	dims = append(dims, labels)
 	unsqueezeTs := ts.MustUnsqueeze(-1, false).MustTotype(gotch.Int64, true)
 	zerosTs := MustZeros(dims, gotch.Float, gotch.CPU)
-	fmt.Printf("zeroTs shape: %v\n", zerosTs.MustSize())
-	fmt.Printf("unsqueezeTs shape: %v\n", unsqueezeTs.MustSize())
 	zerosTs.MustScatter1_(-1, unsqueezeTs, FloatScalar(1.0))
 	return zerosTs
 }
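For context, the Onehot helper touched in tensor/tensor.go builds the encoding by unsqueezing the index tensor and scattering 1.0 into a zero tensor. Below is a minimal usage sketch, not part of the change, adapted from the debug snippet removed from example/char-rnn/main.go; the index values and label count are illustrative only, and "ts" is assumed to be the tensor package alias used in that example.

	// Usage sketch only (assumes "ts" is the tensor package alias and "fmt" is imported).
	indices := ts.MustOfSlice([]int64{0, 1, 2, 3}) // class indices
	onehot := indices.Onehot(4)                    // one-hot tensor of shape [4, 4]
	fmt.Printf("onehot: %v\n", onehot.Float64Values())
	// Free the underlying C tensors, following the per-batch cleanup added above.
	onehot.MustDrop()
	indices.MustDrop()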
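The commented-out line and TODO in nn/rnn.go point at the intended final guard: flatten the LSTM weights only when the device is CUDA and cuDNN is actually available. A hedged sketch of that guard follows, reusing the exact call from the diff; the gotch.Cuda.CudnnIsAvailable() check is taken from the commented-out line and has not been verified here.

	// Sketch of the intended guard once the cuDNN availability check is wired up.
	// gotch.Cuda.CudnnIsAvailable() comes from the commented-out line above (assumption).
	if vs.Device().IsCuda() && gotch.Cuda.CudnnIsAvailable() {
		// 4 weight tensors per layer; mode 2 selects LSTM (see rnn.cpp in Pytorch).
		ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 2, hiddenDim,
			cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)
	}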