Batch Renormalization

Hi,

Recently I found the batch renormalization paper while investigating training instability (fluctuating validation loss) caused by a small batch size, which is forced by memory constraints in a high-resolution segmentation task. One option is to freeze or drop the batchnorm layers, but the paper claims that batch renormalization substantially improves both training and inference performance instead. Does anyone have experience with it, or know of a readily available PyTorch implementation?
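
For context, the correction the paper proposes is roughly the following. This is just a rough sketch I put together from the paper (the module and parameter names are my own, not from any library), so treat it as a starting point rather than a tested implementation:

```python
import torch
import torch.nn as nn

class BatchRenorm2d(nn.Module):
    """Rough sketch of Batch Renormalization (Ioffe, 2017) for NCHW input.
    In the paper r_max / d_max are ramped up gradually during training."""
    def __init__(self, num_features, momentum=0.01, eps=1e-5, r_max=3.0, d_max=5.0):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_std", torch.ones(num_features))
        self.momentum, self.eps = momentum, eps
        self.r_max, self.d_max = r_max, d_max

    def forward(self, x):
        def bc(t):  # broadcast a per-channel vector over NCHW
            return t[None, :, None, None]

        if self.training:
            mean = x.mean(dim=(0, 2, 3))
            std = (x.var(dim=(0, 2, 3), unbiased=False) + self.eps).sqrt()
            # r and d pull the batch statistics towards the running ones;
            # the paper treats them as constants in the backward pass.
            r = (std.detach() / self.running_std).clamp(1.0 / self.r_max, self.r_max)
            d = ((mean.detach() - self.running_mean) / self.running_std).clamp(-self.d_max, self.d_max)
            x_hat = (x - bc(mean)) / bc(std) * bc(r) + bc(d)
            # update the running statistics
            self.running_mean += self.momentum * (mean.detach() - self.running_mean)
            self.running_std += self.momentum * (std.detach() - self.running_std)
        else:
            x_hat = (x - bc(self.running_mean)) / bc(self.running_std)
        return bc(self.weight) * x_hat + bc(self.bias)
```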

Thanks a lot !

I don't have experience with batch renormalization, but a very good alternative is to use gradient checkpointing:
Pytorch : https://pytorch.org/docs/stable/checkpoint.html
Tensorflow : https://github.com/openai/gradient-checkpointing

This lets you use larger batches on higher-resolution images, which compensates for batch norm's instability with small batches. Very useful in medical imaging in general.
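
A minimal sketch of how it's used in PyTorch (the model and shapes are made up, just to show the call):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Toy encoder just to illustrate: checkpoint_sequential splits the model into
# segments and recomputes the intermediate activations during backward instead
# of storing them, trading extra compute for a much smaller memory footprint.
model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
    nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
    nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
)

x = torch.randn(8, 3, 512, 512, requires_grad=True)  # made-up batch / resolution
out = checkpoint_sequential(model, 3, x)  # 3 = number of checkpointed segments
out.mean().backward()
```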

Group norm (https://arxiv.org/abs/1803.08494) is also a great alternative in theory, but I've never found ImageNet-pretrained weights that use group norm.
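
If you want to try it anyway, swapping it in is straightforward; a minimal sketch (the block layout is just an example):

```python
import torch.nn as nn

# Conv block using GroupNorm instead of BatchNorm: statistics are computed per
# sample over channel groups, so they don't depend on the batch size.
# Note: num_groups must divide the number of output channels.
def conv_gn_block(in_ch, out_ch, num_groups=32):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.GroupNorm(num_groups=num_groups, num_channels=out_ch),
        nn.ReLU(inplace=True),
    )
```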


Thank you so much

This looks like an interesting avenue; it's being used by the #1 on the leaderboard of the current Kaggle Airbus competition. https://github.com/mapillary/inplace_abn/blob/master/README.md


I’ve seen you on the leaderboard :slight_smile: Let me check that too, thanks a lot!

I’ve tried their WideResNet DeepLab v3 model, but I can fit even fewer samples per batch compared to the resnet_{18, 34, 50} U-Net variants.

Did you get good results with the InPlace-ABN model in that GitHub repo? My best model is still a single resnet18 Dynamic Unet, LB 0.731 with TTA, but unfortunately I couldn’t get DeepLab v3 to give good results.

Shame it didn’t work. No such luck here either.