Stateful RNNs


In this reading notebook you will learn how to retain the state of an RNN when processing long sequences.

In [1]:
import tensorflow as tf
tf.__version__
Out[1]:
'2.0.0'

So far you have trained RNNs on entire sequences, possibly of varying length. In some applications, such as financial time series modeling or real-time speech processing, the input sequence can be very long.

One way to process such sequences is simply to chop them up into shorter sub-sequences and feed these to the network as separate batches. However, the internal state of the RNN is normally reset between batches, so no information can flow from one sub-sequence to the next. Persisting an RNN cell's state between batches solves this problem.

Stateful and non-stateful RNN models

We will begin by creating two versions of the same RNN model. The first is a regular RNN that does not retain its state.

In [2]:
# Create a regular (non-stateful) RNN

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GRU

gru = Sequential([
    GRU(5, input_shape=(None, 1), name='rnn')
])

To persist RNN cell states between batches, you can use the stateful argument when you initialize an RNN layer. The default value of this argument is False. This argument is available for all RNN layer types.

In [3]:
# Create a stateful RNN

stateful_gru = Sequential([
    GRU(5, stateful=True, batch_input_shape=(2, None, 1), name='stateful_rnn')
])

Note that, as well as setting stateful=True, we have also specified the batch_input_shape. This fixes the number of sequences in a batch, in addition to the sequence length and the number of features. So the above model will always require batches of exactly 2 sequences.

When using stateful RNNs, it is necessary to supply this argument to the first layer of a Sequential model. This is because the model will always assume that each element of every subsequent batch it receives will be a continuation of the sequence from the corresponding element in the previous batch.
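
To see the effect of the fixed batch size, here is a minimal sketch (not part of the original notebook) of what happens if you pass a batch of the wrong size to the stateful model; the exact exception type may vary between TensorFlow versions:

# Sketch: a batch of 3 sequences violates batch_input_shape=(2, None, 1)

try:
    stateful_gru(tf.zeros((3, 4, 1)))
except Exception as e:
    print(type(e).__name__)  # fails: the state Variable was created for a batch size of 2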

Another detail is that when defining a model with a stateful RNN using the functional API, you will need to specify the batch_shape argument as follows:

In [4]:
# Redefine the same stateful RNN using the functional API

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input

inputs = Input(batch_shape=(2, None, 1))
outputs = GRU(5, stateful=True, name='stateful_rnn')(inputs)

stateful_gru = Model(inputs=inputs, outputs=outputs)

Inspect the RNN states

We can inspect the RNN layer states by retrieving the recurrent layer from each model, and looking at the states property.

In [5]:
# Retrieve the RNN layer and inspect the internal state

gru.get_layer('rnn').states
Out[5]:
[None]
In [6]:
# Retrieve the RNN layer and inspect the internal state

stateful_gru.get_layer('stateful_rnn').states
Out[6]:
[<tf.Variable 'stateful_rnn_1/Variable:0' shape=(2, 5) dtype=float32, numpy=
 array([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], dtype=float32)>]

Note that the stateful RNN stores a separate state for each element in the batch, which is why the shape of the state Variable is (2, 5): one row of 5 units per sequence.
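
As an aside (not in the original notebook), RNN layers with more than one state tensor store one Variable per state. For example, a stateful LSTM keeps both a hidden state and a cell state, each with one row per batch element:

# Sketch: a stateful LSTM stores two state Variables (hidden and cell state)

from tensorflow.keras.layers import LSTM

stateful_lstm = Sequential([
    LSTM(5, stateful=True, batch_input_shape=(2, None, 1), name='stateful_lstm')
])
print(len(stateful_lstm.get_layer('stateful_lstm').states))  # 2, each of shape (2, 5)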

Create a simple sequence dataset

We will demonstrate the effect of statefulness on a simple sequence dataset consisting of two sequences.

In [7]:
# Create the sequence dataset

sequence_data = tf.constant([
    [[-4.], [-3.], [-2.], [-1.], [0.], [1.], [2.], [3.], [4.]],
    [[-40.], [-30.], [-20.], [-10.], [0.], [10.], [20.], [30.], [40.]]
], dtype=tf.float32)
sequence_data.shape
Out[7]:
TensorShape([2, 9, 1])

Process the sequence batch with both models

Now see what happens when you pass the batch of sequences through each model:

In [8]:
# Process the batch with both models

_1 = gru(sequence_data)
_2 = stateful_gru(sequence_data)
In [9]:
# Retrieve the RNN layer and inspect the internal state

gru.get_layer('rnn').states
Out[9]:
[None]
In [10]:
# Retrieve the RNN layer and inspect the internal state

stateful_gru.get_layer('stateful_rnn').states
Out[10]:
[<tf.Variable 'stateful_rnn_1/Variable:0' shape=(2, 5) dtype=float32, numpy=
 array([[-0.7095808 , -0.8731987 ,  0.38492265,  0.7823224 ,  0.5795383 ],
        [-1.        , -0.9999519 , -0.27701804,  0.9697435 , -0.1136002 ]],
       dtype=float32)>]

The stateful RNN model has updated and retained its state after having processed the input sequence batch. This internal state could then be used as the initial state for processing a continuation of both sequences in the next batch.

Resetting the internal state

If you need a stateful RNN to forget (or re-initialise) its state, then you can call an RNN layer's reset_states() method.

In [11]:
# Reset the internal state of the stateful RNN model

stateful_gru.get_layer('stateful_rnn').reset_states()
In [12]:
# Retrieve the RNN layer and inspect the internal state

stateful_gru.get_layer('stateful_rnn').states
Out[12]:
[<tf.Variable 'stateful_rnn_1/Variable:0' shape=(2, 5) dtype=float32, numpy=
 array([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], dtype=float32)>]

Note that reset_states() resets the state to zeros, which is the default initial state for RNN layers in TensorFlow.
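
The Keras RNN API also lets you re-initialise the state with custom values: reset_states accepts an optional states argument. A minimal sketch (not in the original notebook), assuming we want to set the state to ones:

# Sketch: re-initialise the stateful RNN with custom values instead of zeros

import numpy as np

custom_state = np.ones((2, 5), dtype=np.float32)  # one row of 5 units per batch element
stateful_gru.get_layer('stateful_rnn').reset_states(states=custom_state)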

Retaining internal state across batches

Passing a sequence to a stateful layer as several consecutive sub-sequences produces the same final output and internal state as passing the whole sequence at once.

In [13]:
# Reset the internal state of the stateful RNN model and process the full sequences

stateful_gru.get_layer('stateful_rnn').reset_states()
_ = stateful_gru(sequence_data)
stateful_gru.get_layer('stateful_rnn').states
Out[13]:
[<tf.Variable 'stateful_rnn_1/Variable:0' shape=(2, 5) dtype=float32, numpy=
 array([[-0.7095808 , -0.8731987 ,  0.38492265,  0.7823224 ,  0.5795383 ],
        [-1.        , -0.9999519 , -0.27701804,  0.9697435 , -0.1136002 ]],
       dtype=float32)>]
In [14]:
# Break the sequences into batches

sequence_batch1 = sequence_data[:, :3, :]
sequence_batch2 = sequence_data[:, 3:6, :]
sequence_batch3 = sequence_data[:, 6:, :]

print("First batch:", sequence_batch1)
print("\nSecond batch:", sequence_batch2)
print("\nThird batch:", sequence_batch3)
First batch: tf.Tensor(
[[[ -4.]
  [ -3.]
  [ -2.]]

 [[-40.]
  [-30.]
  [-20.]]], shape=(2, 3, 1), dtype=float32)

Second batch: tf.Tensor(
[[[ -1.]
  [  0.]
  [  1.]]

 [[-10.]
  [  0.]
  [ 10.]]], shape=(2, 3, 1), dtype=float32)

Third batch: tf.Tensor(
[[[ 2.]
  [ 3.]
  [ 4.]]

 [[20.]
  [30.]
  [40.]]], shape=(2, 3, 1), dtype=float32)

Note that, across the three batches, the first elements all come from the first original sequence, and the second elements all come from the second original sequence.

In [18]:
# Reset the internal state of the stateful RNN model and process the batches in order

stateful_gru.get_layer('stateful_rnn').reset_states()
_ = stateful_gru(sequence_batch1)
_ = stateful_gru(sequence_batch2)
_ = stateful_gru(sequence_batch3)
stateful_gru.get_layer('stateful_rnn').states
Out[18]:
[<tf.Variable 'stateful_rnn_1/Variable:0' shape=(2, 5) dtype=float32, numpy=
 array([[-0.7095808 , -0.8731987 ,  0.38492265,  0.7823224 ,  0.5795383 ],
        [-1.        , -0.9999519 , -0.27701804,  0.9697435 , -0.1136002 ]],
       dtype=float32)>]

Notice that, after processing the three batches in order, the internal state of the stateful RNN is the same as it was earlier when we processed the entire sequences at once.
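
You can verify this programmatically. A minimal check (not in the original notebook) that compares the state from the batched run against a fresh full-sequence pass:

# Sketch: confirm the batched run reproduces the full-sequence state

import numpy as np

state_after_batches = stateful_gru.get_layer('stateful_rnn').states[0].numpy()

stateful_gru.get_layer('stateful_rnn').reset_states()
_ = stateful_gru(sequence_data)
state_after_full_pass = stateful_gru.get_layer('stateful_rnn').states[0].numpy()

print(np.allclose(state_after_batches, state_after_full_pass))  # expected: True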

This property can be used when training stateful RNNs, if we ensure that each example in a batch is a continuation of the same sequence as the corresponding example in the previous batch.
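
As a sketch of how this might look in practice (the Dense head, zero targets and loss below are hypothetical placeholders, purely for illustration), you could reset the state at the start of each epoch only, and then feed the aligned batches in order:

# Sketch: training a stateful RNN on consecutive sub-sequence batches

from tensorflow.keras.layers import Dense

model = Sequential([
    GRU(5, stateful=True, batch_input_shape=(2, None, 1), name='stateful_rnn'),
    Dense(1)
])
model.compile(optimizer='adam', loss='mse')

batches = [sequence_batch1, sequence_batch2, sequence_batch3]
targets = [tf.zeros((2, 1)) for _ in batches]  # placeholder targets

for epoch in range(2):
    model.get_layer('stateful_rnn').reset_states()  # forget state between epochs only
    for x, y in zip(batches, targets):
        model.train_on_batch(x, y)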