diff --git a/example/char-rnn/main.go b/example/char-rnn/main.go
index 24880a9..b1af728 100644
--- a/example/char-rnn/main.go
+++ b/example/char-rnn/main.go
@@ -18,7 +18,7 @@ const (
 	SamplingLen int64 = 1024
 )
 
-func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Device) (retVal string) {
+func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.Device) string {
 
 	labels := data.Labels()
 	inState := lstm.ZeroState(1)
@@ -34,15 +34,15 @@ func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Devic
 		state := lstm.Step(input, inState)
 
 		// 1. Delete inState tensors (from C land memory)
-		inState.(nn.LSTMState).Tensor1.MustDrop()
-		inState.(nn.LSTMState).Tensor2.MustDrop()
+		inState.(*nn.LSTMState).Tensor1.MustDrop()
+		inState.(*nn.LSTMState).Tensor2.MustDrop()
 		// 2. Then update with current state
 		inState = state
 		// 3. Delete intermediate tensors
 		input.MustDrop()
 		inputView.MustDrop()
 
-		forwardTs := linear.Forward(state.(nn.LSTMState).H()).MustSqueeze1(0, true).MustSoftmax(-1, gotch.Float, true)
+		forwardTs := linear.Forward(state.(*nn.LSTMState).H()).MustSqueeze1(0, true).MustSoftmax(-1, gotch.Float, true)
 		sampledY := forwardTs.MustMultinomial(1, false, true)
 		lastLabel = sampledY.Int64Values()[0]
 		sampledY.MustDrop()
@@ -52,8 +52,8 @@ func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Devic
 	}
 
 	// Delete the last state
-	inState.(nn.LSTMState).Tensor1.MustDrop()
-	inState.(nn.LSTMState).Tensor2.MustDrop()
+	inState.(*nn.LSTMState).Tensor1.MustDrop()
+	inState.(*nn.LSTMState).Tensor2.MustDrop()
 
 	return string(runes)
 }
@@ -104,8 +104,8 @@ func main() {
 			lstmOut, outState := lstm.Seq(xsOnehot)
 			// NOTE. Although outState will not be used. There a hidden memory usage
 			// on C land memory that is needed to free up. Don't use `_`
-			outState.(nn.LSTMState).Tensor1.MustDrop()
-			outState.(nn.LSTMState).Tensor2.MustDrop()
+			outState.(*nn.LSTMState).Tensor1.MustDrop()
+			outState.(*nn.LSTMState).Tensor2.MustDrop()
 
 			logits := linear.Forward(lstmOut)
 			lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)
diff --git a/nn/rnn.go b/nn/rnn.go
index 1a4d875..67c4212 100644
--- a/nn/rnn.go
+++ b/nn/rnn.go
@@ -124,7 +124,7 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
 // Implement RNN interface for LSTM:
 // =================================
 
-func (l *LSTM) ZeroState(batchDim int64) (retVal State) {
+func (l *LSTM) ZeroState(batchDim int64) State {
 	var numDirections int64 = 1
 	if l.config.Bidirectional {
 		numDirections = 2
@@ -134,7 +134,7 @@ func (l *LSTM) ZeroState(batchDim int64) (retVal State) {
 	shape := []int64{layerDim, batchDim, l.hiddenDim}
 	zeros := ts.MustZeros(shape, gotch.Float, l.device)
 
-	retVal = LSTMState{
+	retVal := &LSTMState{
 		Tensor1: zeros.MustShallowClone(),
 		Tensor2: zeros.MustShallowClone(),
 	}
@@ -144,7 +144,7 @@ func (l *LSTM) ZeroState(batchDim int64) (retVal State) {
 	return retVal
 }
 
-func (l *LSTM) Step(input *ts.Tensor, inState State) (retVal State) {
+func (l *LSTM) Step(input *ts.Tensor, inState State) State {
 	ip := input.MustUnsqueeze(1, false)
 
 	output, state := l.SeqInit(ip, inState)
@@ -156,24 +156,24 @@ func (l *LSTM) Step(input *ts.Tensor, inState State) (retVal State) {
 	return state
 }
 
-func (l *LSTM) Seq(input *ts.Tensor) (output *ts.Tensor, state State) {
+func (l *LSTM) Seq(input *ts.Tensor) (*ts.Tensor, State) {
 	batchDim := input.MustSize()[0]
 	inState := l.ZeroState(batchDim)
 
-	output, state = l.SeqInit(input, inState)
+	output, state := l.SeqInit(input, inState)
 
 	// Delete intermediate tensors in inState
-	inState.(LSTMState).Tensor1.MustDrop()
-	inState.(LSTMState).Tensor2.MustDrop()
+	inState.(*LSTMState).Tensor1.MustDrop()
+	inState.(*LSTMState).Tensor2.MustDrop()
 
 	return output, state
 }
 
 func (l *LSTM) SeqInit(input *ts.Tensor, inState State) (*ts.Tensor, State) {
 
-	output, h, c := input.MustLstm([]ts.Tensor{*inState.(LSTMState).Tensor1, *inState.(LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
+	output, h, c := input.MustLstm([]ts.Tensor{*inState.(*LSTMState).Tensor1, *inState.(*LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
 
-	return output, LSTMState{
+	return output, &LSTMState{
 		Tensor1: h,
 		Tensor2: c,
 	}
@@ -243,7 +243,7 @@ func NewGRU(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) (retVal *GRU) {
 // Implement RNN interface for GRU:
 // ================================
 
-func (g *GRU) ZeroState(batchDim int64) (retVal State) {
+func (g *GRU) ZeroState(batchDim int64) State {
 	var numDirections int64 = 1
 	if g.config.Bidirectional {
 		numDirections = 2
@@ -254,10 +254,10 @@ func (g *GRU) ZeroState(batchDim int64) (retVal State) {
 
 	tensor := ts.MustZeros(shape, gotch.Float, g.device)
 
-	return GRUState{Tensor: tensor}
+	return &GRUState{Tensor: tensor}
 }
 
-func (g *GRU) Step(input *ts.Tensor, inState State) (retVal State) {
+func (g *GRU) Step(input *ts.Tensor, inState State) State {
 	unsqueezedInput := input.MustUnsqueeze(1, false)
 
 	output, state := g.SeqInit(unsqueezedInput, inState)
@@ -269,21 +269,21 @@ func (g *GRU) Step(input *ts.Tensor, inState State) (retVal State) {
 	return state
 }
 
-func (g *GRU) Seq(input *ts.Tensor) (output *ts.Tensor, state State) {
+func (g *GRU) Seq(input *ts.Tensor) (*ts.Tensor, State) {
 	batchDim := input.MustSize()[0]
 	inState := g.ZeroState(batchDim)
 
-	output, state = g.SeqInit(input, inState)
+	output, state := g.SeqInit(input, inState)
 
 	// Delete intermediate tensors in inState
-	inState.(GRUState).Tensor.MustDrop()
+	inState.(*GRUState).Tensor.MustDrop()
 
 	return output, state
 }
 
 func (g *GRU) SeqInit(input *ts.Tensor, inState State) (*ts.Tensor, State) {
 
-	output, h := input.MustGru(inState.(GRUState).Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst)
+	output, h := input.MustGru(inState.(*GRUState).Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst)
 
-	return output, GRUState{Tensor: h}
+	return output, &GRUState{Tensor: h}
 }
diff --git a/nn/rnn_test.go b/nn/rnn_test.go
index 762e022..47f7263 100644
--- a/nn/rnn_test.go
+++ b/nn/rnn_test.go
@@ -32,10 +32,10 @@ func gruTest(rnnConfig *nn.RNNConfig, t *testing.T) {
 
 	// Step test
 	input := ts.MustRandn([]int64{batchDim, inputDim}, gotch.Float, gotch.CPU)
-	output := gru.Step(input, gru.ZeroState(batchDim).(nn.GRUState))
+	output := gru.Step(input, gru.ZeroState(batchDim).(*nn.GRUState))
 
 	want := []int64{layerDim, batchDim, outputDim}
-	got := output.(nn.GRUState).Tensor.MustSize()
+	got := output.(*nn.GRUState).Tensor.MustSize()
 
 	if !reflect.DeepEqual(want, got) {
 		fmt.Println("Step test:")
@@ -97,12 +97,12 @@ func lstmTest(rnnConfig *nn.RNNConfig, t *testing.T) {
 
 	// Step test
 	input := ts.MustRandn([]int64{batchDim, inputDim}, gotch.Float, gotch.CPU)
-	output := lstm.Step(input, lstm.ZeroState(batchDim).(nn.LSTMState))
+	output := lstm.Step(input, lstm.ZeroState(batchDim).(*nn.LSTMState))
 
 	wantH := []int64{layerDim, batchDim, outputDim}
-	gotH := output.(nn.LSTMState).Tensor1.MustSize()
+	gotH := output.(*nn.LSTMState).Tensor1.MustSize()
 	wantC := []int64{layerDim, batchDim, outputDim}
-	gotC := output.(nn.LSTMState).Tensor2.MustSize()
+	gotC := output.(*nn.LSTMState).Tensor2.MustSize()
 
 	if !reflect.DeepEqual(wantH, gotH) {
 		fmt.Println("Step test:")
diff --git a/vision/imagenet.go b/vision/imagenet.go
index d6c6e2a..ab81586 100644
--- a/vision/imagenet.go
+++ b/vision/imagenet.go
@@ -1281,7 +1281,7 @@ type TopItem struct {
 }
 
 // Returns the top k classes as well as the associated scores.
-func (in *ImageNet) Top(input ts.Tensor, k int64) []TopItem {
+func (in *ImageNet) Top(input *ts.Tensor, k int64) []TopItem {
 	var tensor *ts.Tensor
 
 	shape := input.MustSize()
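
Caller-side impact, as a minimal sketch (not part of the diff): `State` values returned by `ZeroState`, `Step`, and `Seq` now wrap the pointer types `*nn.LSTMState` / `*nn.GRUState`, so type assertions must use the pointer form, and the state tensors still have to be dropped explicitly to release the C-land memory, as the char-rnn example and rnn_test.go above do. The function name `stepOnce` and the import paths are assumptions for illustration only; only the calls shown in the diff are used.

package sketch

import (
	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
)

// stepOnce is a hypothetical helper showing one LSTM step with the
// pointer-typed state after this change.
func stepOnce(lstm *nn.LSTM, batchDim, inputDim int64) {
	input := ts.MustRandn([]int64{batchDim, inputDim}, gotch.Float, gotch.CPU)

	inState := lstm.ZeroState(batchDim) // State now wraps *nn.LSTMState
	state := lstm.Step(input, inState)

	// Assertions target the pointer type; drop the state tensors and the
	// input to free the underlying C-land memory, mirroring the example above.
	inState.(*nn.LSTMState).Tensor1.MustDrop()
	inState.(*nn.LSTMState).Tensor2.MustDrop()
	state.(*nn.LSTMState).Tensor1.MustDrop()
	state.(*nn.LSTMState).Tensor2.MustDrop()
	input.MustDrop()
}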