nn/rnn: corrected interface type conversion
This commit is contained in:
parent
b69d46eae4
commit
5414b6ed57
|
@ -18,7 +18,7 @@ const (
|
||||||
SamplingLen int64 = 1024
|
SamplingLen int64 = 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Device) (retVal string) {
|
func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.Device) string {
|
||||||
|
|
||||||
labels := data.Labels()
|
labels := data.Labels()
|
||||||
inState := lstm.ZeroState(1)
|
inState := lstm.ZeroState(1)
|
||||||
|
@ -34,15 +34,15 @@ func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Devic
|
||||||
state := lstm.Step(input, inState)
|
state := lstm.Step(input, inState)
|
||||||
|
|
||||||
// 1. Delete inState tensors (from C land memory)
|
// 1. Delete inState tensors (from C land memory)
|
||||||
inState.(nn.LSTMState).Tensor1.MustDrop()
|
inState.(*nn.LSTMState).Tensor1.MustDrop()
|
||||||
inState.(nn.LSTMState).Tensor2.MustDrop()
|
inState.(*nn.LSTMState).Tensor2.MustDrop()
|
||||||
// 2. Then update with current state
|
// 2. Then update with current state
|
||||||
inState = state
|
inState = state
|
||||||
// 3. Delete intermediate tensors
|
// 3. Delete intermediate tensors
|
||||||
input.MustDrop()
|
input.MustDrop()
|
||||||
inputView.MustDrop()
|
inputView.MustDrop()
|
||||||
|
|
||||||
forwardTs := linear.Forward(state.(nn.LSTMState).H()).MustSqueeze1(0, true).MustSoftmax(-1, gotch.Float, true)
|
forwardTs := linear.Forward(state.(*nn.LSTMState).H()).MustSqueeze1(0, true).MustSoftmax(-1, gotch.Float, true)
|
||||||
sampledY := forwardTs.MustMultinomial(1, false, true)
|
sampledY := forwardTs.MustMultinomial(1, false, true)
|
||||||
lastLabel = sampledY.Int64Values()[0]
|
lastLabel = sampledY.Int64Values()[0]
|
||||||
sampledY.MustDrop()
|
sampledY.MustDrop()
|
||||||
|
@ -52,8 +52,8 @@ func sample(data ts.TextData, lstm nn.LSTM, linear nn.Linear, device gotch.Devic
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete the last state
|
// Delete the last state
|
||||||
inState.(nn.LSTMState).Tensor1.MustDrop()
|
inState.(*nn.LSTMState).Tensor1.MustDrop()
|
||||||
inState.(nn.LSTMState).Tensor2.MustDrop()
|
inState.(*nn.LSTMState).Tensor2.MustDrop()
|
||||||
|
|
||||||
return string(runes)
|
return string(runes)
|
||||||
}
|
}
|
||||||
|
@ -104,8 +104,8 @@ func main() {
|
||||||
lstmOut, outState := lstm.Seq(xsOnehot)
|
lstmOut, outState := lstm.Seq(xsOnehot)
|
||||||
// NOTE: Although outState will not be used, there is hidden memory usage
|
// NOTE: Although outState will not be used, there is hidden memory usage
|
||||||
// in C-land memory that needs to be freed. Do not discard it with `_`.
|
// in C-land memory that needs to be freed. Do not discard it with `_`.
|
||||||
outState.(nn.LSTMState).Tensor1.MustDrop()
|
outState.(*nn.LSTMState).Tensor1.MustDrop()
|
||||||
outState.(nn.LSTMState).Tensor2.MustDrop()
|
outState.(*nn.LSTMState).Tensor2.MustDrop()
|
||||||
|
|
||||||
logits := linear.Forward(lstmOut)
|
logits := linear.Forward(lstmOut)
|
||||||
lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)
|
lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)
|
||||||
|
|
34
nn/rnn.go
34
nn/rnn.go
|
@ -124,7 +124,7 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
|
||||||
// Implement RNN interface for LSTM:
|
// Implement RNN interface for LSTM:
|
||||||
// =================================
|
// =================================
|
||||||
|
|
||||||
func (l *LSTM) ZeroState(batchDim int64) (retVal State) {
|
func (l *LSTM) ZeroState(batchDim int64) State {
|
||||||
var numDirections int64 = 1
|
var numDirections int64 = 1
|
||||||
if l.config.Bidirectional {
|
if l.config.Bidirectional {
|
||||||
numDirections = 2
|
numDirections = 2
|
||||||
|
@ -134,7 +134,7 @@ func (l *LSTM) ZeroState(batchDim int64) (retVal State) {
|
||||||
shape := []int64{layerDim, batchDim, l.hiddenDim}
|
shape := []int64{layerDim, batchDim, l.hiddenDim}
|
||||||
zeros := ts.MustZeros(shape, gotch.Float, l.device)
|
zeros := ts.MustZeros(shape, gotch.Float, l.device)
|
||||||
|
|
||||||
retVal = LSTMState{
|
retVal := &LSTMState{
|
||||||
Tensor1: zeros.MustShallowClone(),
|
Tensor1: zeros.MustShallowClone(),
|
||||||
Tensor2: zeros.MustShallowClone(),
|
Tensor2: zeros.MustShallowClone(),
|
||||||
}
|
}
|
||||||
|
@ -144,7 +144,7 @@ func (l *LSTM) ZeroState(batchDim int64) (retVal State) {
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *LSTM) Step(input *ts.Tensor, inState State) (retVal State) {
|
func (l *LSTM) Step(input *ts.Tensor, inState State) State {
|
||||||
ip := input.MustUnsqueeze(1, false)
|
ip := input.MustUnsqueeze(1, false)
|
||||||
|
|
||||||
output, state := l.SeqInit(ip, inState)
|
output, state := l.SeqInit(ip, inState)
|
||||||
|
@ -156,24 +156,24 @@ func (l *LSTM) Step(input *ts.Tensor, inState State) (retVal State) {
|
||||||
return state
|
return state
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *LSTM) Seq(input *ts.Tensor) (output *ts.Tensor, state State) {
|
func (l *LSTM) Seq(input *ts.Tensor) (*ts.Tensor, State) {
|
||||||
batchDim := input.MustSize()[0]
|
batchDim := input.MustSize()[0]
|
||||||
inState := l.ZeroState(batchDim)
|
inState := l.ZeroState(batchDim)
|
||||||
|
|
||||||
output, state = l.SeqInit(input, inState)
|
output, state := l.SeqInit(input, inState)
|
||||||
|
|
||||||
// Delete intermediate tensors in inState
|
// Delete intermediate tensors in inState
|
||||||
inState.(LSTMState).Tensor1.MustDrop()
|
inState.(*LSTMState).Tensor1.MustDrop()
|
||||||
inState.(LSTMState).Tensor2.MustDrop()
|
inState.(*LSTMState).Tensor2.MustDrop()
|
||||||
|
|
||||||
return output, state
|
return output, state
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *LSTM) SeqInit(input *ts.Tensor, inState State) (*ts.Tensor, State) {
|
func (l *LSTM) SeqInit(input *ts.Tensor, inState State) (*ts.Tensor, State) {
|
||||||
|
|
||||||
output, h, c := input.MustLstm([]ts.Tensor{*inState.(LSTMState).Tensor1, *inState.(LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
|
output, h, c := input.MustLstm([]ts.Tensor{*inState.(*LSTMState).Tensor1, *inState.(*LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
|
||||||
|
|
||||||
return output, LSTMState{
|
return output, &LSTMState{
|
||||||
Tensor1: h,
|
Tensor1: h,
|
||||||
Tensor2: c,
|
Tensor2: c,
|
||||||
}
|
}
|
||||||
|
@ -243,7 +243,7 @@ func NewGRU(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) (retVal *GRU) {
|
||||||
// Implement RNN interface for GRU:
|
// Implement RNN interface for GRU:
|
||||||
// ================================
|
// ================================
|
||||||
|
|
||||||
func (g *GRU) ZeroState(batchDim int64) (retVal State) {
|
func (g *GRU) ZeroState(batchDim int64) State {
|
||||||
var numDirections int64 = 1
|
var numDirections int64 = 1
|
||||||
if g.config.Bidirectional {
|
if g.config.Bidirectional {
|
||||||
numDirections = 2
|
numDirections = 2
|
||||||
|
@ -254,10 +254,10 @@ func (g *GRU) ZeroState(batchDim int64) (retVal State) {
|
||||||
|
|
||||||
tensor := ts.MustZeros(shape, gotch.Float, g.device)
|
tensor := ts.MustZeros(shape, gotch.Float, g.device)
|
||||||
|
|
||||||
return GRUState{Tensor: tensor}
|
return &GRUState{Tensor: tensor}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GRU) Step(input *ts.Tensor, inState State) (retVal State) {
|
func (g *GRU) Step(input *ts.Tensor, inState State) State {
|
||||||
unsqueezedInput := input.MustUnsqueeze(1, false)
|
unsqueezedInput := input.MustUnsqueeze(1, false)
|
||||||
output, state := g.SeqInit(unsqueezedInput, inState)
|
output, state := g.SeqInit(unsqueezedInput, inState)
|
||||||
|
|
||||||
|
@ -269,21 +269,21 @@ func (g *GRU) Step(input *ts.Tensor, inState State) (retVal State) {
|
||||||
return state
|
return state
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GRU) Seq(input *ts.Tensor) (output *ts.Tensor, state State) {
|
func (g *GRU) Seq(input *ts.Tensor) (*ts.Tensor, State) {
|
||||||
batchDim := input.MustSize()[0]
|
batchDim := input.MustSize()[0]
|
||||||
inState := g.ZeroState(batchDim)
|
inState := g.ZeroState(batchDim)
|
||||||
|
|
||||||
output, state = g.SeqInit(input, inState)
|
output, state := g.SeqInit(input, inState)
|
||||||
|
|
||||||
// Delete intermediate tensors in inState
|
// Delete intermediate tensors in inState
|
||||||
inState.(GRUState).Tensor.MustDrop()
|
inState.(*GRUState).Tensor.MustDrop()
|
||||||
|
|
||||||
return output, state
|
return output, state
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GRU) SeqInit(input *ts.Tensor, inState State) (*ts.Tensor, State) {
|
func (g *GRU) SeqInit(input *ts.Tensor, inState State) (*ts.Tensor, State) {
|
||||||
|
|
||||||
output, h := input.MustGru(inState.(GRUState).Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst)
|
output, h := input.MustGru(inState.(*GRUState).Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst)
|
||||||
|
|
||||||
return output, GRUState{Tensor: h}
|
return output, &GRUState{Tensor: h}
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,10 +32,10 @@ func gruTest(rnnConfig *nn.RNNConfig, t *testing.T) {
|
||||||
|
|
||||||
// Step test
|
// Step test
|
||||||
input := ts.MustRandn([]int64{batchDim, inputDim}, gotch.Float, gotch.CPU)
|
input := ts.MustRandn([]int64{batchDim, inputDim}, gotch.Float, gotch.CPU)
|
||||||
output := gru.Step(input, gru.ZeroState(batchDim).(nn.GRUState))
|
output := gru.Step(input, gru.ZeroState(batchDim).(*nn.GRUState))
|
||||||
|
|
||||||
want := []int64{layerDim, batchDim, outputDim}
|
want := []int64{layerDim, batchDim, outputDim}
|
||||||
got := output.(nn.GRUState).Tensor.MustSize()
|
got := output.(*nn.GRUState).Tensor.MustSize()
|
||||||
|
|
||||||
if !reflect.DeepEqual(want, got) {
|
if !reflect.DeepEqual(want, got) {
|
||||||
fmt.Println("Step test:")
|
fmt.Println("Step test:")
|
||||||
|
@ -97,12 +97,12 @@ func lstmTest(rnnConfig *nn.RNNConfig, t *testing.T) {
|
||||||
|
|
||||||
// Step test
|
// Step test
|
||||||
input := ts.MustRandn([]int64{batchDim, inputDim}, gotch.Float, gotch.CPU)
|
input := ts.MustRandn([]int64{batchDim, inputDim}, gotch.Float, gotch.CPU)
|
||||||
output := lstm.Step(input, lstm.ZeroState(batchDim).(nn.LSTMState))
|
output := lstm.Step(input, lstm.ZeroState(batchDim).(*nn.LSTMState))
|
||||||
|
|
||||||
wantH := []int64{layerDim, batchDim, outputDim}
|
wantH := []int64{layerDim, batchDim, outputDim}
|
||||||
gotH := output.(nn.LSTMState).Tensor1.MustSize()
|
gotH := output.(*nn.LSTMState).Tensor1.MustSize()
|
||||||
wantC := []int64{layerDim, batchDim, outputDim}
|
wantC := []int64{layerDim, batchDim, outputDim}
|
||||||
gotC := output.(nn.LSTMState).Tensor2.MustSize()
|
gotC := output.(*nn.LSTMState).Tensor2.MustSize()
|
||||||
|
|
||||||
if !reflect.DeepEqual(wantH, gotH) {
|
if !reflect.DeepEqual(wantH, gotH) {
|
||||||
fmt.Println("Step test:")
|
fmt.Println("Step test:")
|
||||||
|
|
|
@ -1281,7 +1281,7 @@ type TopItem struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Top returns the top k classes as well as the associated scores.
|
// Top returns the top k classes as well as the associated scores.
|
||||||
func (in *ImageNet) Top(input ts.Tensor, k int64) []TopItem {
|
func (in *ImageNet) Top(input *ts.Tensor, k int64) []TopItem {
|
||||||
|
|
||||||
var tensor *ts.Tensor
|
var tensor *ts.Tensor
|
||||||
shape := input.MustSize()
|
shape := input.MustSize()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user