nn/rnn: corrected interface type conversion

sugarme 2020-10-31 23:30:04 +11:00
parent b69d46eae4
commit 5414b6ed57
4 changed files with 31 additions and 31 deletions
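
The fix in one line: ZeroState, Step and Seq now store pointer types (*LSTMState, *GRUState) in the State interface, so every type assertion on that interface has to name the pointer type too; asserting the old value type panics at run time. A minimal, self-contained sketch of the Go rule behind the change (illustrative only, not part of this commit):

package main

import "fmt"

// State and LSTMState stand in for the nn.State interface and the
// nn.LSTMState struct touched by this commit.
type State interface{}

type LSTMState struct{ h int }

func main() {
	var s State = &LSTMState{h: 1} // dynamic type is *LSTMState, as ZeroState now returns

	// s.(LSTMState) would panic: the interface holds *LSTMState, not LSTMState.
	// The two-value form reports the mismatch instead of panicking.
	if _, ok := s.(LSTMState); !ok {
		fmt.Println("value assertion fails; the interface holds *LSTMState")
	}

	// The pointer assertion is the correct one.
	st := s.(*LSTMState)
	fmt.Println("h =", st.h)
}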

View File

@@ -18,7 +18,7 @@ const (
SamplingLen int64 = 1024
)
-func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Device) (retVal string) {
+func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.Device) string {
labels := data.Labels()
inState := lstm.ZeroState(1)
@@ -34,15 +34,15 @@ func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Devic
state := lstm.Step(input, inState)
// 1. Delete inState tensors (from C land memory)
-inState.(nn.LSTMState).Tensor1.MustDrop()
-inState.(nn.LSTMState).Tensor2.MustDrop()
+inState.(*nn.LSTMState).Tensor1.MustDrop()
+inState.(*nn.LSTMState).Tensor2.MustDrop()
// 2. Then update with current state
inState = state
// 3. Delete intermediate tensors
input.MustDrop()
inputView.MustDrop()
-forwardTs := linear.Forward(state.(nn.LSTMState).H()).MustSqueeze1(0, true).MustSoftmax(-1, gotch.Float, true)
+forwardTs := linear.Forward(state.(*nn.LSTMState).H()).MustSqueeze1(0, true).MustSoftmax(-1, gotch.Float, true)
sampledY := forwardTs.MustMultinomial(1, false, true)
lastLabel = sampledY.Int64Values()[0]
sampledY.MustDrop()
@@ -52,8 +52,8 @@ func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Devic
}
// Delete the last state
-inState.(nn.LSTMState).Tensor1.MustDrop()
-inState.(nn.LSTMState).Tensor2.MustDrop()
+inState.(*nn.LSTMState).Tensor1.MustDrop()
+inState.(*nn.LSTMState).Tensor2.MustDrop()
return string(runes)
}
@@ -104,8 +104,8 @@ func main() {
lstmOut, outState := lstm.Seq(xsOnehot)
// NOTE: Although outState is not used, it holds memory allocated on the
// C side that must be freed. Don't discard it with `_` (a helper sketch follows this hunk).
-outState.(nn.LSTMState).Tensor1.MustDrop()
-outState.(nn.LSTMState).Tensor2.MustDrop()
+outState.(*nn.LSTMState).Tensor1.MustDrop()
+outState.(*nn.LSTMState).Tensor2.MustDrop()
logits := linear.Forward(lstmOut)
lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)
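
Dropping the two tensors of an LSTM state recurs at every call site above; a hypothetical helper (not part of this commit, assuming only the nn.State interface and the Tensor1/Tensor2 fields of *nn.LSTMState visible in this diff) could centralize the pattern:

// dropLSTMState frees both C-land tensors held by an LSTM state.
// Like any single-value assertion, it panics if s holds another state type.
func dropLSTMState(s nn.State) {
	st := s.(*nn.LSTMState)
	st.Tensor1.MustDrop()
	st.Tensor2.MustDrop()
}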

View File

@@ -124,7 +124,7 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
// Implement RNN interface for LSTM:
// =================================
-func (l *LSTM) ZeroState(batchDim int64) (retVal State) {
+func (l *LSTM) ZeroState(batchDim int64) State {
var numDirections int64 = 1
if l.config.Bidirectional {
numDirections = 2
@@ -134,7 +134,7 @@ func (l *LSTM) ZeroState(batchDim int64) (retVal State) {
shape := []int64{layerDim, batchDim, l.hiddenDim}
zeros := ts.MustZeros(shape, gotch.Float, l.device)
-retVal = LSTMState{
+retVal := &LSTMState{
Tensor1: zeros.MustShallowClone(),
Tensor2: zeros.MustShallowClone(),
}
@@ -144,7 +144,7 @@ func (l *LSTM) ZeroState(batchDim int64) (retVal State) {
return retVal
}
-func (l *LSTM) Step(input *ts.Tensor, inState State) (retVal State) {
+func (l *LSTM) Step(input *ts.Tensor, inState State) State {
ip := input.MustUnsqueeze(1, false)
output, state := l.SeqInit(ip, inState)
@@ -156,24 +156,24 @@ func (l *LSTM) Step(input *ts.Tensor, inState State) (retVal State) {
return state
}
-func (l *LSTM) Seq(input *ts.Tensor) (output *ts.Tensor, state State) {
+func (l *LSTM) Seq(input *ts.Tensor) (*ts.Tensor, State) {
batchDim := input.MustSize()[0]
inState := l.ZeroState(batchDim)
-output, state = l.SeqInit(input, inState)
+output, state := l.SeqInit(input, inState)
// Delete intermediate tensors in inState
-inState.(LSTMState).Tensor1.MustDrop()
-inState.(LSTMState).Tensor2.MustDrop()
+inState.(*LSTMState).Tensor1.MustDrop()
+inState.(*LSTMState).Tensor2.MustDrop()
return output, state
}
func (l *LSTM) SeqInit(input *ts.Tensor, inState State) (*ts.Tensor, State) {
-output, h, c := input.MustLstm([]ts.Tensor{*inState.(LSTMState).Tensor1, *inState.(LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
+output, h, c := input.MustLstm([]ts.Tensor{*inState.(*LSTMState).Tensor1, *inState.(*LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
-return output, LSTMState{
+return output, &LSTMState{
Tensor1: h,
Tensor2: c,
}
@@ -243,7 +243,7 @@ func NewGRU(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) (retVal *GRU) {
// Implement RNN interface for GRU:
// ================================
-func (g *GRU) ZeroState(batchDim int64) (retVal State) {
+func (g *GRU) ZeroState(batchDim int64) State {
var numDirections int64 = 1
if g.config.Bidirectional {
numDirections = 2
@@ -254,10 +254,10 @@ func (g *GRU) ZeroState(batchDim int64) (retVal State) {
tensor := ts.MustZeros(shape, gotch.Float, g.device)
-return GRUState{Tensor: tensor}
+return &GRUState{Tensor: tensor}
}
-func (g *GRU) Step(input *ts.Tensor, inState State) (retVal State) {
+func (g *GRU) Step(input *ts.Tensor, inState State) State {
unsqueezedInput := input.MustUnsqueeze(1, false)
output, state := g.SeqInit(unsqueezedInput, inState)
@@ -269,21 +269,21 @@ func (g *GRU) Step(input *ts.Tensor, inState State) (retVal State) {
return state
}
-func (g *GRU) Seq(input *ts.Tensor) (output *ts.Tensor, state State) {
+func (g *GRU) Seq(input *ts.Tensor) (*ts.Tensor, State) {
batchDim := input.MustSize()[0]
inState := g.ZeroState(batchDim)
-output, state = g.SeqInit(input, inState)
+output, state := g.SeqInit(input, inState)
// Delete intermediate tensors in inState
-inState.(GRUState).Tensor.MustDrop()
+inState.(*GRUState).Tensor.MustDrop()
return output, state
}
func (g *GRU) SeqInit(input *ts.Tensor, inState State) (*ts.Tensor, State) {
-output, h := input.MustGru(inState.(GRUState).Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst)
+output, h := input.MustGru(inState.(*GRUState).Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst)
-return output, GRUState{Tensor: h}
+return output, &GRUState{Tensor: h}
}
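
Both implementations now hand out pointer states through the same State interface; inside the nn package, code that may receive either can branch with a type switch instead of a bare assertion. A sketch under that assumption (not part of this commit):

// dropState frees whatever C-land tensors a state holds. The type switch
// cannot panic on a mismatched state, unlike a single-value assertion.
func dropState(s State) {
	switch st := s.(type) {
	case *LSTMState:
		st.Tensor1.MustDrop() // hidden state h
		st.Tensor2.MustDrop() // cell state c
	case *GRUState:
		st.Tensor.MustDrop()
	}
}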

View File

@@ -32,10 +32,10 @@ func gruTest(rnnConfig *nn.RNNConfig, t *testing.T) {
// Step test
input := ts.MustRandn([]int64{batchDim, inputDim}, gotch.Float, gotch.CPU)
-output := gru.Step(input, gru.ZeroState(batchDim).(nn.GRUState))
+output := gru.Step(input, gru.ZeroState(batchDim).(*nn.GRUState))
want := []int64{layerDim, batchDim, outputDim}
-got := output.(nn.GRUState).Tensor.MustSize()
+got := output.(*nn.GRUState).Tensor.MustSize()
if !reflect.DeepEqual(want, got) {
fmt.Println("Step test:")
@@ -97,12 +97,12 @@ func lstmTest(rnnConfig *nn.RNNConfig, t *testing.T) {
// Step test
input := ts.MustRandn([]int64{batchDim, inputDim}, gotch.Float, gotch.CPU)
-output := lstm.Step(input, lstm.ZeroState(batchDim).(nn.LSTMState))
+output := lstm.Step(input, lstm.ZeroState(batchDim).(*nn.LSTMState))
wantH := []int64{layerDim, batchDim, outputDim}
-gotH := output.(nn.LSTMState).Tensor1.MustSize()
+gotH := output.(*nn.LSTMState).Tensor1.MustSize()
wantC := []int64{layerDim, batchDim, outputDim}
-gotC := output.(nn.LSTMState).Tensor2.MustSize()
+gotC := output.(*nn.LSTMState).Tensor2.MustSize()
if !reflect.DeepEqual(wantH, gotH) {
fmt.Println("Step test:")
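
A failed pointer assertion in these tests would panic and abort the whole run; the two-value form turns it into an ordinary test failure. A sketch of that alternative, using the same names as the test above:

st, ok := lstm.ZeroState(batchDim).(*nn.LSTMState)
if !ok {
	t.Fatal("ZeroState: expected *nn.LSTMState")
}
output := lstm.Step(input, st)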

View File

@@ -1281,7 +1281,7 @@ type TopItem struct {
}
// Returns the top k classes as well as the associated scores.
-func (in *ImageNet) Top(input ts.Tensor, k int64) []TopItem {
+func (in *ImageNet) Top(input *ts.Tensor, k int64) []TopItem {
var tensor *ts.Tensor
shape := input.MustSize()
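
A hypothetical call site after this signature change (imageNet and img are assumed for illustration and not shown in this diff; TopItem fields are printed generically):

// Top now takes the tensor by pointer, consistent with the rest of the API.
topItems := imageNet.Top(img, 5)
for _, item := range topItems {
	fmt.Printf("%+v\n", item)
}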