How to deal with BatchNorm and batch size of 1?

Hi all. I want to further train (fine-tune) a pretrained model that contains many nn.BatchNorm3d layers between its other layers. It was originally trained with bs=8. Unfortunately, with my GPU and the new samples, only bs=1 will fit in memory.

Anyway, I ran into a bizarre training issue; see
https://forums.fast.ai/t/model-train-gives-much-lower-loss-than-model-eval/83211

I suspect the problem is related to bs=1. It causes many problems for BatchNorm, because the variance of each feature becomes 0. There are various attempts to make BatchNorm work with bs=1, but I am not sure any of them are good.

How would you handle this situation of training a pre-existing model with bs=1?

  • Should I simply freeze the BatchNorm’s?
  • Is there any way to collect the mean & variance across the whole epoch and feed them to the BatchNorm’s?
  • Any other pointers are welcome!

Are you running gradient accumulation? I think that could help.

Grad accum is a good idea for getting a more stable optimisation, but it will not fix the BatchNorm issue.
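
For reference, gradient accumulation just sums the gradients from several bs=1 forward/backward passes before each optimizer step. A bare-bones sketch in plain PyTorch (model, optimizer, loss_fn and dataloader are assumed to already exist; fastai v2 also has a GradientAccumulation callback, if I remember right):

    accum_steps = 8  # pretend batch size of 8 built from bs=1 batches

    optimizer.zero_grad()
    for i, (xb, yb) in enumerate(dataloader):
        loss = loss_fn(model(xb), yb) / accum_steps  # scale so the summed grads average out
        loss.backward()
        if (i + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

Note that each forward pass still sees a single sample, so the statistics BatchNorm computes are unchanged.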

One solution could be to replace the BatchNorms with GroupNorm or LayerNorm. Other quick ideas are to reduce the model or input size, or to use 16-bit precision so you can fit more than one item at a time.
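
If you go the replacement route, swapping the BatchNorm3d layers for GroupNorm could look something like the sketch below (untested; note that GroupNorm has its own affine parameters, so the pretrained BatchNorm statistics and weights are thrown away and those layers are effectively retrained from scratch):

    import math
    import torch.nn as nn

    def replace_bn_with_gn(module, num_groups=8):
        """Recursively replace nn.BatchNorm3d layers with nn.GroupNorm."""
        for name, child in module.named_children():
            if isinstance(child, nn.BatchNorm3d):
                g = math.gcd(num_groups, child.num_features)  # groups must divide channels
                setattr(module, name, nn.GroupNorm(g, child.num_features))
            else:
                replace_bn_with_gn(child, num_groups)
        return module

    # usage (hypothetical): modelR = replace_bn_with_gn(modelR)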


Thanks for your suggestions. I will need a few days to check them out, because life.

One idea I found via a search is to freeze the running mean and SD found by the pretraining, so that they are always used unchanged, while gamma and beta continue to be learned.

Stack Overflow claims that calling eval() on nn.BatchNorm will freeze the running mean and std, while still allowing gamma and beta to be learned.
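
If that is the behaviour you want, doing it by hand in plain PyTorch is just a matter of putting the BatchNorm modules (and only them) into eval mode after model.train() is called, e.g. at the start of each epoch. A minimal sketch (model is a placeholder name):

    import torch.nn as nn

    def freeze_bn_stats(model):
        # eval mode makes BatchNorm normalise with its stored running mean/var
        # instead of the batch statistics; gamma (weight) and beta (bias)
        # still get gradients and keep training.
        for m in model.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                m.eval()

    model.train()           # puts everything, including BatchNorm, into train mode
    freeze_bn_stats(model)  # then flip just the BatchNorms back to eval

I believe fastai's BnFreeze callback does something along these lines, but check its source to be sure.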

I have never used training callbacks, and am not confident with them. Will this do what I am asking for above?

learn = Learner(data, modelR, loss_func=lossfn, cbs=BnFreeze)

Thanks for sharing your expertise.

I recall Jeremy discussing “runningBatchNorm” in some detail in the 2019 lectures to try to solve this type of issue. No idea if it made it into fastai v2, but I would assume so.

Even better than Jeremy’s implementation of “runningBatchNorm”, which was an IIR filter on the mean and variance, would be a variant of Welford’s online algorithm https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford’s_online_algorithm with a cap on the count value. Better still would be to fix some constant “maxcount”, initialise with (mean=0, M2=maxcount), and then simply assume count=maxcount at every step. This has the effect of treating the first few batches as if mean=0, stdev=1 (which should be the result of a good initialisation of random parameters); the algorithm then slowly forgets those initial values and remembers roughly the most recent maxcount values. Setting maxcount to ~256 should give a smoothly updating BatchNorm, as though you had a batch size of 256, which one would expect to yield reasonably smooth statistics.
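
To make that concrete, here is a rough sketch of a capped Welford accumulator for per-channel statistics (my own interpretation: the count is pinned at maxcount, and I also decay M2 each step so the variance stays bounded; none of this comes from an existing library):

    import torch

    class CappedWelford:
        """Welford-style running mean/variance with the count pinned at
        `maxcount`, so older samples are gradually forgotten (sketch only)."""
        def __init__(self, num_features, maxcount=256):
            self.maxcount = maxcount
            # start as if we had already seen maxcount samples with mean 0, var 1,
            # which is what a well-initialised network should produce
            self.mean = torch.zeros(num_features)
            self.m2 = torch.full((num_features,), float(maxcount))

        def update(self, x):
            # x: (N, C, *spatial) -> stream of per-channel scalar values
            vals = x.detach().transpose(0, 1).reshape(x.shape[1], -1)
            for v in vals.unbind(dim=1):   # one Welford step per value; slow but clear
                delta = v - self.mean
                self.mean = self.mean + delta / self.maxcount
                self.m2 = self.m2 * (self.maxcount - 1) / self.maxcount + delta * (v - self.mean)

        @property
        def var(self):
            return self.m2 / self.maxcount

Wiring this into an actual norm layer (and doing it without the Python loop) is left as an exercise.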

I’m not sufficiently up to speed with the internals of the library to make these sorts of modifications myself, but it looks like you guys could take this idea and run with it.

Hi hushitz. Thanks for your insights. I have actually made some progress on the issue, but have not yet posted about it. Basically, it is to freeze the running means and stds as found by the pretraining, and let gamma and beta continue to be learned. Empirically, this eliminates the training anomalies. I’ll look at your ideas and post my code later. :slightly_smiling_face:

Hi @Pomo, I’m having the same issue. Do you have an update on this? It’s greatly appreciated.

It has been many months since I looked at the issue. If I recall right, I could only fit a batch size of 1. The pretrained model used BatchNorm between all layers. So I wrote a BatchNorm variant that normalized using the running mean and std, and kept tracking them as batches came in. Then I wrote some code to walk the pretrained model and replace the true BatchNorms. It worked well enough. If the code would actually be helpful to you, I will look for it.
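
In the meantime, here is a rough sketch from memory of the idea (not the original code): a BatchNorm3d subclass that always normalises with its running statistics (so bs=1 is harmless) while still letting those statistics track incoming batches, plus a walk over the model to swap it in:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class RunningStatsBatchNorm3d(nn.BatchNorm3d):
        """Normalise with the running mean/var, but keep updating them
        from incoming batches. gamma/beta train as usual."""
        def forward(self, x):
            if self.training:
                with torch.no_grad():
                    dims = (0, 2, 3, 4)                      # N, D, H, W
                    m = self.momentum if self.momentum is not None else 0.1
                    self.running_mean.mul_(1 - m).add_(x.mean(dim=dims), alpha=m)
                    self.running_var.mul_(1 - m).add_(x.var(dim=dims, unbiased=False), alpha=m)
            return F.batch_norm(x, self.running_mean, self.running_var,
                                self.weight, self.bias, training=False, eps=self.eps)

    def swap_batchnorms(module):
        """Walk the model and replace each nn.BatchNorm3d, keeping its
        pretrained weights and running statistics."""
        for name, child in module.named_children():
            if isinstance(child, nn.BatchNorm3d):
                new = RunningStatsBatchNorm3d(child.num_features, eps=child.eps,
                                              momentum=child.momentum, affine=child.affine)
                new.load_state_dict(child.state_dict())
                setattr(module, name, new)
            else:
                swap_batchnorms(child)
        return module

    # usage (hypothetical): modelR = swap_batchnorms(modelR)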

Another option is to use a different norm, like LayerNorm or InstanceNorm, that does not depend on the batch size. The whole normalization business seems handwavy to me. I’d try to avoid BatchNorm if possible, though that may not be an option when you are starting from a pretrained model.

HTH :slightly_smiling_face: