This flag is used to perform truncated back-propagation through time: within a batch, the gradient is propagated through the hidden states of the LSTM across the time dimension, and then, in the next batch, the last hidden states are used as the initial states of the LSTM.
This lets the LSTM use longer context at training time while constraining how many steps back the gradient is propagated.
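For reference, here is a minimal sketch of such a model in `tf.keras` (assuming the TF 2.x API; the layer size and shapes are hypothetical). A stateful LSTM needs a fixed batch size, declared via `batch_input_shape`, so that the states computed for batch k can be reused as the initial states for batch k+1.

```python
import tensorflow as tf

batch_size, timesteps, features = 32, 50, 10  # hypothetical dimensions

model = tf.keras.Sequential([
    tf.keras.layers.LSTM(
        128,
        stateful=True,                 # keep the hidden states between batches
        return_sequences=True,
        batch_input_shape=(batch_size, timesteps, features),
    ),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
```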
I know of two scenarios where this is common:
- Language modeling (LM).
- Time series modeling.
The training set is a list of sequences, potentially coming from a few documents (LM) or from complete time series. During data preparation, the batches are created so that each sequence in a batch is the continuation of the sequence at the same position in the previous batch. This gives the model document-level (or long-range time series) context when computing predictions.
In these cases, your data is longer than the sequence length dimension of the batch. This may be due to constraints in the available GPU memory (which limit the maximum batch size) or simply by design.
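As a concrete illustration of the batch layout described above, here is a sketch (with hypothetical shapes and a made-up helper name) that splits one long series into `batch_size` contiguous streams and then yields consecutive windows, so that the sequence at position i of one batch continues at position i of the next batch:

```python
import numpy as np

def make_stateful_batches(series, batch_size, timesteps):
    # Trim the series so it divides evenly into batch_size streams of whole windows.
    n_windows = len(series) // (batch_size * timesteps)
    series = series[: n_windows * batch_size * timesteps]
    # One contiguous stream per batch row: shape (batch_size, n_windows * timesteps).
    streams = series.reshape(batch_size, -1)
    # Consecutive windows: batch k + 1 continues batch k at every position.
    for start in range(0, streams.shape[1], timesteps):
        yield streams[:, start:start + timesteps]

series = np.arange(10_000, dtype=np.float32)
for batch in make_stateful_batches(series, batch_size=32, timesteps=50):
    print(batch.shape)  # (32, 50); add a feature axis before feeding the LSTM
```

With `stateful=True`, you would typically call `model.reset_states()` only at the boundaries where the continuation breaks (e.g. at the start of each epoch or of a new document).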
Update: note that the `stateful` flag affects both training and inference time. If you disable it, you must ensure that at inference time each prediction gets the previous hidden state. To do this, you can either create a new model with `stateful=True` and copy the parameters from the trained model with `model.set_weights()`, or pass the state manually. Due to this inconvenience, some people simply set `stateful=True` always and force the model not to use the stored hidden state during training by invoking `model.reset_states()`.
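Both workarounds can be sketched as follows (assuming `tf.keras` and a hypothetical `build_model` helper that creates the same architecture with a configurable `stateful` flag and batch size):

```python
import tensorflow as tf

def build_model(stateful, batch_size, timesteps, features):
    return tf.keras.Sequential([
        tf.keras.layers.LSTM(
            128,
            stateful=stateful,
            return_sequences=True,
            batch_input_shape=(batch_size, timesteps, features),
        ),
        tf.keras.layers.Dense(1),
    ])

# Option 1: train without statefulness, then copy the weights into a stateful
# copy of the model (often with batch size 1) for step-by-step inference.
train_model = build_model(stateful=False, batch_size=32, timesteps=50, features=10)
# ... train train_model here ...
infer_model = build_model(stateful=True, batch_size=1, timesteps=1, features=10)
infer_model.set_weights(train_model.get_weights())

# Option 2: keep stateful=True for training as well, and clear the stored
# hidden state whenever it should not carry over between batches.
model = build_model(stateful=True, batch_size=32, timesteps=50, features=10)
model.reset_states()
```

The weights of an LSTM do not depend on the batch size or on the number of timesteps, which is why `set_weights()` works across the two configurations.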