Details of batchnorm

Hi everyone. I am implementing a version of BatchNorm3d, and the deeper I get into it, the more I realize I do not exactly understand how BatchNorm works. Can someone take a look?

Suppose it’s BatchNorm3d with a minibatch size of 8, 50 features, and an image size of 3x4x5, so the input to forward() has shape [8, 50, 3, 4, 5].

The mean and variance for the whole batch are calculated and used to update the running mean and variance.

The mean is computed per feature (50 of them), across all pixels and samples in the batch.
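
In code, I think the mean would look something like the sketch below. The names `x` and `batch_mean` are just ones I made up for illustration, and this is my understanding rather than PyTorch's actual implementation:

```python
import torch

torch.manual_seed(0)

# Example input: minibatch of 8, 50 features, 3x4x5 volume.
x = torch.randn(8, 50, 3, 4, 5)

# Reduce over the batch dimension (0) and the three spatial
# dimensions (2, 3, 4), leaving one mean per feature.
batch_mean = x.mean(dim=(0, 2, 3, 4))

print(batch_mean.shape)  # torch.Size([50])
```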

But I am confused about the variance. Is it, for each feature, the variance of its 8 per-sample means in the minibatch? This would make sense to me because, as I understand it, BatchNorm fails with a minibatch size of 1.

Or is it instead the variance per feature, taken over all of that feature’s samples and pixels in the minibatch?
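
To make the two readings concrete, here is a rough sketch of both, plus a check that could be run against `torch.nn.BatchNorm3d` in training mode. The names (`var_of_means`, `var_all`, `normalize`) are mine, and the manual formula (subtract the mean, divide by `sqrt(var + eps)`, with the biased variance) is my assumption about what the module does, not something I took from its source:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 50, 3, 4, 5)

# Per-feature mean over batch and spatial dimensions.
batch_mean = x.mean(dim=(0, 2, 3, 4))                       # shape [50]

# Reading A: variance of each feature's 8 per-sample means.
per_sample_means = x.mean(dim=(2, 3, 4))                    # shape [8, 50]
var_of_means = per_sample_means.var(dim=0, unbiased=False)  # shape [50]

# Reading B: variance over all samples and pixels of each feature.
var_all = x.var(dim=(0, 2, 3, 4), unbiased=False)           # shape [50]


def normalize(inp, mean, var, eps):
    """Hypothetical helper: normalize per feature with the given statistics."""
    shape = (1, -1, 1, 1, 1)  # broadcast [50] over [N, C, D, H, W]
    return (inp - mean.view(shape)) / torch.sqrt(var.view(shape) + eps)


# Reference output from the real module (no learnable scale/shift).
bn = torch.nn.BatchNorm3d(50, affine=False)
bn.train()
y = bn(x)

matches_a = torch.allclose(y, normalize(x, batch_mean, var_of_means, bn.eps), atol=1e-5)
matches_b = torch.allclose(y, normalize(x, batch_mean, var_all, bn.eps), atol=1e-5)
print("reading A matches:", matches_a)
print("reading B matches:", matches_b)
```

Whichever reading reproduces the module's output would be the one BatchNorm3d uses for normalization. The running variance may be updated slightly differently, so this check only covers the forward output.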

Thanks for helping me to understand.

Second question… what PyTorch’s BatchNorm3d actually does seems to be implemented in C++ code. Where can I find that code?

Thanks! :exploding_head: