Can Batch Normalization replace tanh in RNN?
Question
Can Batch Normalization (BN) be inserted in an RNN after $x_t W_{xh}$ and after $h_{t-1} W_{hh}$ to remove $f = \tanh$ and the bias $b_h$? If so, would this eliminate both the exploding and the vanishing gradient problems?
I believe the effect of tanh, which squashes values from $(-\infty, +\infty)$ into $(-1, 1)$, can be replaced by the standardization step in BN, and that BN then makes the bias unnecessary at both $x_t W_{xh}$ and $h_{t-1} W_{hh}$.
The derivative of tanh in automatic differentiation should likewise be replaceable by that of BN.
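A minimal sketch of the cell I have in mind, assuming PyTorch (the class and variable names are mine, purely for illustration):

```python
import torch.nn as nn

class BNRNNCell(nn.Module):
    # Hypothetical RNN cell: BN replaces both tanh and the bias b_h.
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.W_xh = nn.Linear(input_size, hidden_size, bias=False)
        self.W_hh = nn.Linear(hidden_size, hidden_size, bias=False)
        # One BN layer per branch, applied right after each matrix product.
        self.bn_x = nn.BatchNorm1d(hidden_size)
        self.bn_h = nn.BatchNorm1d(hidden_size)

    def forward(self, x_t, h_prev):
        # h_t = BN(x_t @ W_xh) + BN(h_{t-1} @ W_hh)  -- no tanh, no b_h
        return self.bn_x(self.W_xh(x_t)) + self.bn_h(self.W_hh(h_prev))
```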
Background
It is said that tanh addresses the exploding gradient problem in RNNs, hence it is used:
An issue with recurrent neural networks is potentially exploding gradients given the repeated back-propagation mechanism.
After the addition operator the absolute value of c(t) is potentially larger than 1. Passing it through a tanh operator ensures the values are scaled between -1 and 1 again, thus increasing stability during back-propagation over many timesteps.
The vanishing gradient problem is the main problem in RNN. Also, to keep the gradient in the linear region of the activation function, we need a function whose second derivative can sustain for a long range before going to zero. Tanh is pretty good with these properties.
It is also said that Batch Normalization can address both the exploding and the vanishing gradient problems:
Solving the vanishing gradient problem. Although internal covariate shift may not improve accuracy, it is somewhat involved in the vanishing gradient problem. When the distribution of inputs shifts, it will fall prone to inherently smaller gradient updates due to the activation functions (for example, sigmoid, which puts minuscule values at anything beyond ±2.5, or ReLU, which sets any x < 0 to 0). Batch normalization helps make sure that the signal is heard and not diminished by shifting distributions from the end to the beginning of the network during backpropagation.
Solving the exploding gradient problem. As batch normalization smooths the optimization landscape, it gets rid of the extreme gradients that accumulate, leading to the elimination of the major weight fluctuations that result from gradient build-up. This dramatically stabilizes learning.
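For reference, the standardization I am referring to is the usual BN transform, which normalizes each pre-activation $z$ with mini-batch statistics and then applies a learned scale $\gamma$ and shift $\beta$; the learned $\beta$ is what would make a separate bias $b_h$ redundant:

$$\text{BN}(z) = \gamma \, \frac{z - \mu_{\mathcal{B}}}{\sqrt{\sigma^2_{\mathcal{B}} + \epsilon}} + \beta$$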
Research
There are several articles and papers which suggest it is possible, but I could not find a simple implementation diagram or a code example:
For RNNs, this means computing the relevant statistics over the mini-batch and the time/step dimension, so the normalization is applied only over the vector depths. This also means that you only batch normalize the transformed input (so in the vertical directions, e.g. BN(W_x * x)) since the horizontal (across time) connections are time-dependent and shouldn't just be plainly averaged.
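A rough sketch of that "vertical-only" normalization, assuming PyTorch and made-up dimensions, with the statistics pooled over both the mini-batch and the time dimension:

```python
import torch
import torch.nn as nn

T, B, I, H = 20, 32, 10, 64           # time steps, batch size, input size, hidden size
x = torch.randn(T, B, I)              # (time, batch, features)

W_x = nn.Linear(I, H, bias=False)
bn = nn.BatchNorm1d(H)                # per-feature statistics

wx = W_x(x)                                          # (T, B, H)
wx_bn = bn(wx.reshape(T * B, H)).reshape(T, B, H)    # stats computed over T*B samples
# The recurrent term h_{t-1} @ W_hh (the "horizontal" connection) is left un-normalized.
```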
We propose a reparameterization of LSTM that brings the benefits of batch normalization to recurrent neural networks. Whereas previous works only apply batch normalization to the input-to-hidden transformation of RNNs, we demonstrate that it is both possible and beneficial to batch-normalize the hidden-to-hidden transition, thereby reducing internal covariate shift between time steps.
Although batch normalization has demonstrated significant training speed-ups and generalization benefits in feed-forward networks, it is proven to be difficult to apply in recurrent architectures (Laurent et al., 2016; Amodei et al., 2015). It has found limited use in stacked RNNs, where the normalization is applied “vertically”, i.e. to the input of each RNN, but not “horizontally” between timesteps. RNNs are deeper in the time direction, and as such batch normalization would be most beneficial when applied horizontally. However, Laurent et al. (2016) hypothesized that applying batch normalization in this way hurts training because of exploding gradients due to repeated rescaling. Our findings run counter to this hypothesis. We show that it is both possible and highly beneficial to apply batch normalization in the hidden-to-hidden transition of recurrent models.
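Stripped down from the BN-LSTM in that paper to a vanilla RNN (my simplification, omitting the LSTM gates and the separate per-time-step statistics the authors use), each term gets its own BN transform while tanh is kept:

$$h_t = \tanh\!\big(\text{BN}(h_{t-1} W_{hh};\, \gamma_h, \beta_h) + \text{BN}(x_t W_{xh};\, \gamma_x, \beta_x) + b_h\big)$$

This contrasts with my proposal above, where tanh and $b_h$ are dropped entirely.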
Topic batch-normalization rnn
Category Data Science