DEPRECATED

Since writing this, I learned about a paper describing Online Normalization, which solves the problem I was trying to solve here, but does a much better job. If you use the method I describe below, it probably won't work well. Luckily, I now have a simplified version of this implemented in Cottonwood. As of version 28, a line like

from cottonwood.experimental.online_normalization import OnlineNormalization

will get it into your code. To get a deeper sense of what batch normalization is doing and how, check out this post on the topic.



Batch Normalization (BN) is a learned element-wise shift and scale transformation used in deep neural networks. (Originally presented in "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" by Sergey Ioffe and Christian Szegedy, https://arxiv.org/abs/1502.03167.)
Nobody is completely sure why, but everyone agrees that it works really well. It's not perfect. There are variations and improvements with names like Batch Renormalization, Streaming Normalization, Online Normalization, and Filter Response Normalization. But straight-up BN has been used for so long in so many places that it can't be ignored.

The challenge of using BN is that it operates on entire batches of inputs at once, up to several thousand. It uses each neuron's mean and variance across the entire batch to adjust the scale and shift it applies to that neuron's activity. This can require a tremendous amount of memory. It is well suited to large parallel computation hardware, such as GPUs (and often runs into resource limitations even there), but it's not well suited to users running ML on their local CPUs. Cottonwood is committed to remaining CPU-friendly, but until now it has lacked BN capability. This implementation modifies BN to be online, that is, to operate on one input at a time.
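To make the memory issue concrete, here is a minimal NumPy sketch of the standard batch-at-a-time forward transform. The names are illustrative, not Cottonwood's; the point is that the whole batch has to be held at once before the per-neuron statistics, and therefore the output, can be computed.

import numpy as np

def batch_norm_forward(x, gamma, beta, epsilon=1e-5):
    # x has shape (batch_size, n_neurons).
    # gamma (scale) and beta (shift) are learned, each of shape (n_neurons,).
    batch_mean = x.mean(axis=0)   # per-neuron mean over the whole batch
    batch_var = x.var(axis=0)     # per-neuron variance over the whole batch
    x_hat = (x - batch_mean) / np.sqrt(batch_var + epsilon)
    return gamma * x_hat + beta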

Two modifications to the original BN method are necessary to make this happen. (This explanation assumes you are familiar with BN. Feel free to refer to the original paper or to the informative Wikipedia page.) First, the batch statistics, batch mean and batch variance, as well as their partial derivatives, dL_d(batch mean) and dL_d(batch variance), are taken from the previous batch rather than the current one. They can only be calculated at the end of a batch, so they are not available for the current batch.
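Here is a rough sketch of that first modification, using hypothetical names: each input is normalized with the statistics rolled over from the previous batch, while the current batch's statistics are accumulated for use next time.

import numpy as np

class DelayedBatchStats:
    def __init__(self, n_neurons, epsilon=1e-5):
        self.epsilon = epsilon
        # Statistics carried over from the previous batch.
        self.prev_mean = np.zeros(n_neurons)
        self.prev_var = np.ones(n_neurons)
        # Accumulators for the batch currently in progress.
        self._sum = np.zeros(n_neurons)
        self._sum_sq = np.zeros(n_neurons)
        self._count = 0

    def normalize(self, x):
        # x is a single input of shape (n_neurons,).
        # Accumulate statistics for the current batch...
        self._sum += x
        self._sum_sq += x ** 2
        self._count += 1
        # ...but normalize with the previous batch's statistics.
        return (x - self.prev_mean) / np.sqrt(self.prev_var + self.epsilon)

    def end_of_batch(self):
        # Roll the just-completed batch's statistics over for the next batch.
        mean = self._sum / self._count
        self.prev_mean = mean
        self.prev_var = self._sum_sq / self._count - mean ** 2
        self._sum[:] = 0.0
        self._sum_sq[:] = 0.0
        self._count = 0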

Second, the updated shift (beta) and scale (gamma) parameters aren't applied to the current batch, but to the following batch. The update can't take place until both gradients, dL_d(shift) and dL_d(scale), are calculated, which doesn't happen until the batch is complete.
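A sketch of the second modification, again with hypothetical names: gradient contributions for gamma and beta are accumulated one example at a time, and the parameters are only updated when the batch ends, so the new values first take effect on the following batch.

import numpy as np

class DeferredScaleShift:
    def __init__(self, n_neurons, learning_rate=1e-3):
        self.gamma = np.ones(n_neurons)   # scale
        self.beta = np.zeros(n_neurons)   # shift
        self.learning_rate = learning_rate
        self._dL_dgamma = np.zeros(n_neurons)
        self._dL_dbeta = np.zeros(n_neurons)

    def forward(self, x_hat):
        # Apply the parameters as they stood at the end of the previous batch.
        return self.gamma * x_hat + self.beta

    def backward(self, x_hat, dL_dy):
        # Accumulate per-example gradient contributions during the batch.
        self._dL_dgamma += dL_dy * x_hat
        self._dL_dbeta += dL_dy
        # Gradient with respect to the normalized input, passed upstream.
        return dL_dy * self.gamma

    def end_of_batch(self):
        # The update happens here, so it only affects the next batch.
        self.gamma -= self.learning_rate * self._dL_dgamma
        self.beta -= self.learning_rate * self._dL_dbeta
        self._dL_dgamma[:] = 0.0
        self._dL_dbeta[:] = 0.0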

As a result of these two changes, forward and backward passes are calculated using parameters from the previous batch rather than the current one. I expect shift and scale to change only gradually, so using the one-batch-delayed version of them should not introduce a significant deviation in results. The dL_d(batch mean) and dL_d(batch variance) are a different story. They may change more significantly from batch to batch. If they do, they might introduce some instabilities. It may mean that this online variant of BN can't get quite as aggressive with its learning rates as original BN.

A gratuitous modification I made was to calculate a running estimate of the population statistics, the population mean and variance, and to use these (rather than batch statistics) as the basis for normalizing inputs in the forward pass. I estimate population mean and population variance using an exponential-decay weighted average, which is elegant to compute online. This serves two purposes. First, the exponential decay emphasizes the most recent batches (which may differ from early batches) while mitigating sensitivity to any individual noisy batch. Second, it is a nice middle ground between using batch statistics during training and population statistics during evaluation. It has the important attributes of both (current relevance and long memory), so BN does not need to know whether you are training or evaluating; you can use the same forward pass for both.
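The running estimate itself is cheap to maintain. Here is a sketch of the exponentially weighted update; the decay constant shown is an assumption, not necessarily the value Cottonwood uses.

import numpy as np

def update_population_stats(pop_mean, pop_var, batch_mean, batch_var, decay=0.99):
    # Blend the newest batch statistics into the running population estimate.
    new_mean = decay * pop_mean + (1 - decay) * batch_mean
    new_var = decay * pop_var + (1 - decay) * batch_var
    return new_mean, new_var

The forward pass then normalizes against pop_mean and pop_var, whether the network is training or evaluating.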

Online Batch Normalization is implemented in Cottonwood and resides in the experimental/blocks/normalization.py module. Check it out if you'd like to learn more.

You can check out other research ideas and things that I've built here.