nn/rnn: corrected interface type conversion
This commit is contained in:
parent
b69d46eae4
commit
5414b6ed57
|
@ -18,7 +18,7 @@ const (
|
|||
SamplingLen int64 = 1024
|
||||
)
|
||||
|
||||
func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Device) (retVal string) {
|
||||
func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.Device) string {
|
||||
|
||||
labels := data.Labels()
|
||||
inState := lstm.ZeroState(1)
|
||||
|
@ -34,15 +34,15 @@ func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Devic
|
|||
state := lstm.Step(input, inState)
|
||||
|
||||
// 1. Delete inState tensors (from C land memory)
|
||||
inState.(nn.LSTMState).Tensor1.MustDrop()
|
||||
inState.(nn.LSTMState).Tensor2.MustDrop()
|
||||
inState.(*nn.LSTMState).Tensor1.MustDrop()
|
||||
inState.(*nn.LSTMState).Tensor2.MustDrop()
|
||||
// 2. Then update with current state
|
||||
inState = state
|
||||
// 3. Delete intermediate tensors
|
||||
input.MustDrop()
|
||||
inputView.MustDrop()
|
||||
|
||||
forwardTs := linear.Forward(state.(nn.LSTMState).H()).MustSqueeze1(0, true).MustSoftmax(-1, gotch.Float, true)
|
||||
forwardTs := linear.Forward(state.(*nn.LSTMState).H()).MustSqueeze1(0, true).MustSoftmax(-1, gotch.Float, true)
|
||||
sampledY := forwardTs.MustMultinomial(1, false, true)
|
||||
lastLabel = sampledY.Int64Values()[0]
|
||||
sampledY.MustDrop()
|
||||
|
@ -52,8 +52,8 @@ func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Devic
|
|||
}
|
||||
|
||||
// Delete the last state
|
||||
inState.(nn.LSTMState).Tensor1.MustDrop()
|
||||
inState.(nn.LSTMState).Tensor2.MustDrop()
|
||||
inState.(*nn.LSTMState).Tensor1.MustDrop()
|
||||
inState.(*nn.LSTMState).Tensor2.MustDrop()
|
||||
|
||||
return string(runes)
|
||||
}
|
||||
|
@ -104,8 +104,8 @@ func main() {
|
|||
lstmOut, outState := lstm.Seq(xsOnehot)
|
||||
// NOTE: Although outState will not be used, it still holds hidden memory
|
||||
// in C land that needs to be freed. Don't discard it with `_`.
|
||||
outState.(nn.LSTMState).Tensor1.MustDrop()
|
||||
outState.(nn.LSTMState).Tensor2.MustDrop()
|
||||
outState.(*nn.LSTMState).Tensor1.MustDrop()
|
||||
outState.(*nn.LSTMState).Tensor2.MustDrop()
|
||||
|
||||
logits := linear.Forward(lstmOut)
|
||||
lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)
|
||||
|
|
34
nn/rnn.go
34
nn/rnn.go
|
@ -124,7 +124,7 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
|
|||
// Implement RNN interface for LSTM:
|
||||
// =================================
|
||||
|
||||
func (l *LSTM) ZeroState(batchDim int64) (retVal State) {
|
||||
func (l *LSTM) ZeroState(batchDim int64) State {
|
||||
var numDirections int64 = 1
|
||||
if l.config.Bidirectional {
|
||||
numDirections = 2
|
||||
|
@ -134,7 +134,7 @@ func (l *LSTM) ZeroState(batchDim int64) (retVal State) {
|
|||
shape := []int64{layerDim, batchDim, l.hiddenDim}
|
||||
zeros := ts.MustZeros(shape, gotch.Float, l.device)
|
||||
|
||||
retVal = LSTMState{
|
||||
retVal := &LSTMState{
|
||||
Tensor1: zeros.MustShallowClone(),
|
||||
Tensor2: zeros.MustShallowClone(),
|
||||
}
|
||||
|
@ -144,7 +144,7 @@ func (l *LSTM) ZeroState(batchDim int64) (retVal State) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (l *LSTM) Step(input *ts.Tensor, inState State) (retVal State) {
|
||||
func (l *LSTM) Step(input *ts.Tensor, inState State) State {
|
||||
ip := input.MustUnsqueeze(1, false)
|
||||
|
||||
output, state := l.SeqInit(ip, inState)
|
||||
|
@ -156,24 +156,24 @@ func (l *LSTM) Step(input *ts.Tensor, inState State) (retVal State) {
|
|||
return state
|
||||
}
|
||||
|
||||
func (l *LSTM) Seq(input *ts.Tensor) (output *ts.Tensor, state State) {
|
||||
func (l *LSTM) Seq(input *ts.Tensor) (*ts.Tensor, State) {
|
||||
batchDim := input.MustSize()[0]
|
||||
inState := l.ZeroState(batchDim)
|
||||
|
||||
output, state = l.SeqInit(input, inState)
|
||||
output, state := l.SeqInit(input, inState)
|
||||
|
||||
// Delete intermediate tensors in inState
|
||||
inState.(LSTMState).Tensor1.MustDrop()
|
||||
inState.(LSTMState).Tensor2.MustDrop()
|
||||
inState.(*LSTMState).Tensor1.MustDrop()
|
||||
inState.(*LSTMState).Tensor2.MustDrop()
|
||||
|
||||
return output, state
|
||||
}
|
||||
|
||||
func (l *LSTM) SeqInit(input *ts.Tensor, inState State) (*ts.Tensor, State) {
|
||||
|
||||
output, h, c := input.MustLstm([]ts.Tensor{*inState.(LSTMState).Tensor1, *inState.(LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
|
||||
output, h, c := input.MustLstm([]ts.Tensor{*inState.(*LSTMState).Tensor1, *inState.(*LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
|
||||
|
||||
return output, LSTMState{
|
||||
return output, &LSTMState{
|
||||
Tensor1: h,
|
||||
Tensor2: c,
|
||||
}
|
||||
|
@ -243,7 +243,7 @@ func NewGRU(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) (retVal *GRU) {
|
|||
// Implement RNN interface for GRU:
|
||||
// ================================
|
||||
|
||||
func (g *GRU) ZeroState(batchDim int64) (retVal State) {
|
||||
func (g *GRU) ZeroState(batchDim int64) State {
|
||||
var numDirections int64 = 1
|
||||
if g.config.Bidirectional {
|
||||
numDirections = 2
|
||||
|
@ -254,10 +254,10 @@ func (g *GRU) ZeroState(batchDim int64) (retVal State) {
|
|||
|
||||
tensor := ts.MustZeros(shape, gotch.Float, g.device)
|
||||
|
||||
return GRUState{Tensor: tensor}
|
||||
return &GRUState{Tensor: tensor}
|
||||
}
|
||||
|
||||
func (g *GRU) Step(input *ts.Tensor, inState State) (retVal State) {
|
||||
func (g *GRU) Step(input *ts.Tensor, inState State) State {
|
||||
unsqueezedInput := input.MustUnsqueeze(1, false)
|
||||
output, state := g.SeqInit(unsqueezedInput, inState)
|
||||
|
||||
|
@ -269,21 +269,21 @@ func (g *GRU) Step(input *ts.Tensor, inState State) (retVal State) {
|
|||
return state
|
||||
}
|
||||
|
||||
func (g *GRU) Seq(input *ts.Tensor) (output *ts.Tensor, state State) {
|
||||
func (g *GRU) Seq(input *ts.Tensor) (*ts.Tensor, State) {
|
||||
batchDim := input.MustSize()[0]
|
||||
inState := g.ZeroState(batchDim)
|
||||
|
||||
output, state = g.SeqInit(input, inState)
|
||||
output, state := g.SeqInit(input, inState)
|
||||
|
||||
// Delete intermediate tensors in inState
|
||||
inState.(GRUState).Tensor.MustDrop()
|
||||
inState.(*GRUState).Tensor.MustDrop()
|
||||
|
||||
return output, state
|
||||
}
|
||||
|
||||
func (g *GRU) SeqInit(input *ts.Tensor, inState State) (*ts.Tensor, State) {
|
||||
|
||||
output, h := input.MustGru(inState.(GRUState).Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst)
|
||||
output, h := input.MustGru(inState.(*GRUState).Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst)
|
||||
|
||||
return output, GRUState{Tensor: h}
|
||||
return output, &GRUState{Tensor: h}
|
||||
}
|
||||
|
|
|
@ -32,10 +32,10 @@ func gruTest(rnnConfig *nn.RNNConfig, t *testing.T) {
|
|||
|
||||
// Step test
|
||||
input := ts.MustRandn([]int64{batchDim, inputDim}, gotch.Float, gotch.CPU)
|
||||
output := gru.Step(input, gru.ZeroState(batchDim).(nn.GRUState))
|
||||
output := gru.Step(input, gru.ZeroState(batchDim).(*nn.GRUState))
|
||||
|
||||
want := []int64{layerDim, batchDim, outputDim}
|
||||
got := output.(nn.GRUState).Tensor.MustSize()
|
||||
got := output.(*nn.GRUState).Tensor.MustSize()
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
fmt.Println("Step test:")
|
||||
|
@ -97,12 +97,12 @@ func lstmTest(rnnConfig *nn.RNNConfig, t *testing.T) {
|
|||
|
||||
// Step test
|
||||
input := ts.MustRandn([]int64{batchDim, inputDim}, gotch.Float, gotch.CPU)
|
||||
output := lstm.Step(input, lstm.ZeroState(batchDim).(nn.LSTMState))
|
||||
output := lstm.Step(input, lstm.ZeroState(batchDim).(*nn.LSTMState))
|
||||
|
||||
wantH := []int64{layerDim, batchDim, outputDim}
|
||||
gotH := output.(nn.LSTMState).Tensor1.MustSize()
|
||||
gotH := output.(*nn.LSTMState).Tensor1.MustSize()
|
||||
wantC := []int64{layerDim, batchDim, outputDim}
|
||||
gotC := output.(nn.LSTMState).Tensor2.MustSize()
|
||||
gotC := output.(*nn.LSTMState).Tensor2.MustSize()
|
||||
|
||||
if !reflect.DeepEqual(wantH, gotH) {
|
||||
fmt.Println("Step test:")
|
||||
|
|
|
@ -1281,7 +1281,7 @@ type TopItem struct {
|
|||
}
|
||||
|
||||
// Returns the top k classes as well as the associated scores.
|
||||
func (in *ImageNet) Top(input ts.Tensor, k int64) []TopItem {
|
||||
func (in *ImageNet) Top(input *ts.Tensor, k int64) []TopItem {
|
||||
|
||||
var tensor *ts.Tensor
|
||||
shape := input.MustSize()
|
||||
|
|
Loading…
Reference in New Issue
Block a user