How to quantize weights in forward pass during training in Keras?
In Keras, I would like to train a network with binary weights in the manner of Courbariaux et al., but I cannot figure out where the quantization (binarization) should occur within the code.
A core aspect of the training method is this (a toy sketch of one training step follows the list):
- At the beginning of each batch during training, the stored real-valued (e.g., float32) weights are converted to binary values (either by rounding or in a stochastic/probabilistic manner) and stored separately from the real-valued weights.
- Binary-valued weights are used in the forward pass, to compute activations.
- Real-valued weights are used in the backward pass, to compute the gradients.
- Weight updates are applied to the real-valued weights, not the binary ones.
- Binary-valued weights are not altered until the next batch, at which time they are recomputed by binarizing the recently updated real-valued weights.
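To make the flow concrete, here is a toy NumPy sketch of a single training step for one linear layer with a squared-error loss. This is only my illustration of which copy of the weights is used where, not code from the paper, and all the names are made up:

```
import numpy as np

rng = np.random.default_rng(0)

def train_step(W_real, x_batch, y_batch, lr=0.01):
    # 1. Once per batch: binarize the stored real-valued weights
    #    (deterministic rounding via sign; the stochastic variant would
    #    sample the sign instead).
    W_bin = np.sign(W_real)

    # 2. Forward pass: activations are computed from the binary weights.
    y_pred = x_batch @ W_bin

    # 3. Backward pass: gradient of the mean squared error on this batch.
    grad = x_batch.T @ (y_pred - y_batch) / len(x_batch)

    # 4. The update is applied to the real-valued weights; W_bin is simply
    #    discarded and recomputed from the updated W_real at the next batch.
    return W_real - lr * grad

# Toy usage: a 4-input, 2-output linear layer and one batch of 8 samples.
W_real = rng.normal(size=(4, 2)).astype(np.float32)
x = rng.normal(size=(8, 4)).astype(np.float32)
y = rng.normal(size=(8, 2)).astype(np.float32)
W_real = train_step(W_real, x, y)
```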
If I do the binarization in `Layer.call()`, I believe it will occur for every forward pass (for every sample), but it should only occur once per batch.
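For illustration, this is roughly what I picture the `call()`-based version looking like, assuming the TensorFlow backend; BinaryDense and binarize_ste are names I made up, and the stop_gradient expression is just one way I have seen the "binary forward, real backward" behaviour written:

```
import tensorflow as tf
from tensorflow import keras

def binarize_ste(w):
    # Forward value is sign(w); the stop_gradient trick makes the backward
    # pass behave like the identity, so gradients flow to the real-valued
    # kernel (a straight-through estimator).
    return w + tf.stop_gradient(tf.sign(w) - w)

class BinaryDense(keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Real-valued kernel: this is the variable the optimizer updates.
        self.kernel = self.add_weight(
            name="kernel",
            shape=(int(input_shape[-1]), self.units),
            initializer="glorot_uniform",
            trainable=True,
        )

    def call(self, inputs):
        # Binarization happens here, every time the layer is executed,
        # which is exactly the part I am unsure about.
        return tf.matmul(inputs, binarize_ste(self.kernel))
```

With this approach the binary weights never exist as a separately stored variable; they are recomputed whenever `call()` runs, and I am not sure whether that happens once per batch or more often.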
If I do the binarization in a `Callback.on_batch_begin()`, I don't think I can specify the use of binary weights for the forward pass and real-valued weights for the gradient computation.
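To show what I mean, here is a rough sketch of such a callback that keeps a shadow copy of the real-valued weights (BinarizeWeights is a name I made up); as far as I can tell, both the forward and backward pass would then see the binary weights, and the optimizer would update that binary copy rather than the real-valued one:

```
import numpy as np
from tensorflow import keras

class BinarizeWeights(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        # Shadow copy of the real-valued weights, refreshed after each batch.
        self.real_weights = self.model.get_weights()

    def on_batch_begin(self, batch, logs=None):
        # Push binarized weights into the model for this batch's forward
        # (and, unavoidably, backward) pass.
        self.model.set_weights([np.sign(w) for w in self.real_weights])

    def on_batch_end(self, batch, logs=None):
        # The optimizer has just updated the binary copy that was swapped
        # in, so the real-valued weights are effectively lost here, which
        # is exactly the problem described above.
        self.real_weights = self.model.get_weights()
```

(Usage would be something like model.fit(x, y, callbacks=[BinarizeWeights()]).)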
Any suggestions? Unfortunately, my Python knowledge is not deep, so it's been somewhat challenging for me to understand the flow of the code.
**Note:** when I talk about binary values, I do not mean a 1-bit storage format (e.g., 8 binary values packed into an int8). The binary-valued weights can still be represented by int32, float32, whatever; it is only their values that are binary.