 # Questions about RunningBatchNorm in Lesson 10

I’ve spent a fair amount of effort trying to understand the implementation of the `RunningBatchNorm` algorithm in Lesson 10, including carefully examining the code and watching the relevant part of the video near the end of Lesson 10 several times.

I found parts of the code confusing. In particular, it would be great to have a clear explanation of the logic behind the equations for `mom1` and `self.dbias`, and their use in the subsequent code. Also, I don’t see how an exponentially weighted average of the number of samples in a batch is useful, since batch sizes are usually all the same except for the last batch.

Below is the code for `RunningBatchNorm` from the Lesson 10 notebook `07_batchnorm.ipynb` that is currently in the /fastai/course-3 git repo.

I also added questions (marked by `Q`) that I haven’t been able to answer.

Could @jeremy, @stas, or @Sylvain (or whoever wrote this code) please help clarify the `RunningBatchNorm` algorithm by addressing these questions?

==============================================

```python
import math
import torch
from torch import nn, tensor

class RunningBatchNorm(nn.Module):

    # initialize
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()

        # constants
        self.mom,self.eps = mom,eps

        # add scale and offset parameters to the model
        # note: nf is the number of channels
        # Q1: shouldn't self.mults and self.adds have size [1,nf,1,1]?
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))

        # register_buffer adds a persistent buffer to the module, usually used for a non-model parameter
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('batch', tensor(0.))
        self.register_buffer('count', tensor(0.))
        self.register_buffer('step', tensor(0.))
        self.register_buffer('dbias', tensor(0.))

    # compute updates to buffered tensors
    def update_stats(self, x):

        # batch size, number of channels
        bs,nc,*_ = x.shape

        # Note: for a tensor t, t.detach_() means detach t from the computation graph, i.e. don't keep track of its gradients;
        #     the '_' suffix means do it "in place"
        # Q2: why don't we also use .detach_() for self.batch, self.count, self.step, and self.dbias?
        self.sums.detach_()
        self.sqrs.detach_()

        # the input x is a four-dimensional tensor of activations:
        #    dimensions 0, 2, 3 refer to batch samples, height (rows), and width (columns), respectively
        #    dimension 1 refers to channels
        dims = (0,2,3)

        # compute s and ss, which are the sum of the activations and the sum of the squares of the activations
        #     over dimensions (0,2,3) for this batch. s and ss each consist of one number for each channel;
        #     because keepdim=True, s and ss are each of size [1,nf,1,1]
        s = x.sum(dims, keepdim=True)
        ss = (x*x).sum(dims, keepdim=True)

        # Notes:
        #   x.numel() is the number of elements in the 4-D tensor x,
        #       which is the total number of activations in the batch
        #   x.numel()/nc is the number of activations per channel in the batch
        #   t.new_tensor(data) builds a new tensor from data with the same dtype and device as t;
        #       when data is itself a tensor x it is equivalent to x.clone().detach(),
        #       the latter being the preferred way to make a copy of a tensor
        #   c is a zero-dimensional (scalar) tensor with a value equal to the number of activations per channel
        #       for this batch; note that this number depends on the number of samples in the
        #       batch, and not all batches have the same number of samples
        c = self.count.new_tensor(x.numel()/nc)

        # momentum
        # mom1 is the 'weight' to be used in lerp_() to compute EWMAs
        #     -- see pytorch documentation for lerp_
        # if mom is 0.1 and batch size is 2, then mom1 = 1 - 0.9/sqrt(1) = 0.1
        # if mom is 0.1 and batch size is 64, then mom1 = 1 - 0.9/sqrt(63) ~ 0.89;
        #     in general, mom1 increases with batch size
        # Q3: What's the logic behind the following formula for mom1?
        mom1 = 1 - (1-self.mom)/math.sqrt(bs-1)
        # self.mom1 is a scalar tensor, with a value equal to mom1
        self.mom1 = self.dbias.new_tensor(mom1)

        # update EWMAs of sums, sqrs, which, like s and ss, have size [1,nf,1,1]
        self.sums.lerp_(s, self.mom1)
        self.sqrs.lerp_(ss, self.mom1)

        # update EWMA of count
        # self.count keeps track of the EWMA of c,
        #     which is the number of activations per channel for a batch
        # Q4: why do we need the EWMA of c?  Aren't batch sizes always the same, except for the last batch?
        self.count.lerp_(c, self.mom1)

        # Q5: what is the logic behind the following formula for dbias?
        self.dbias = self.dbias*(1-self.mom1) + self.mom1

        # update the total number of samples that have been processed up till now,
        #     i.e. the number of samples in this batch and all previous batches so far
        self.batch += bs

        # update the total number of batches that have been processed
        self.step += 1

    # apply a forward pass to the current batch
    def forward(self, x):

        # main idea of RunningBatchNorm:
        #     to normalize the batch:
        #     in training mode, use the current EWMAs accumulated in the buffers at this step (batch),
        #         and the *current* fitted values of the model parameters mults and adds at this step
        #     in validation mode, use the final values of the EWMAs accumulated in the buffers after training,
        #         and the final fitted values of mults and adds
        if self.training: self.update_stats(x)

        # get the current values of the EWMAs of sums, sqrs and count from the buffers
        sums = self.sums
        sqrs = self.sqrs
        c = self.count

        # if the current batch number is less than 100, scale the EWMAs by 1/self.dbias
        # Q6: Why?
        if self.step<100:
            sums = sums / self.dbias
            sqrs = sqrs / self.dbias
            c    = c    / self.dbias

        # scale sums by 1/c to get the mean of the activations
        means = sums/c

        # scale sqrs by 1/c to get the mean of the squared activations,
        #     then subtract the square of the mean from the mean of the squares
        # note: we recognize this as the 'computationally efficient' formula for the variance that we've seen before
        vars = (sqrs/c).sub_(means*means)

        # if fewer than 20 samples have been seen so far, clamp vars to a minimum of 0.01 (in case any of them becomes very small)
        if bool(self.batch < 20): vars.clamp_min_(0.01)

        # normalize the batch in the usual way, i.e. subtract the mean and divide by std
        # Q7: but why do we need to add eps, when we've already clamped the vars to 0.01?
        x = (x-means).div_((vars.add_(self.eps)).sqrt())

        # return a scaled and offset version of the normalized batch, where the
        #     scale factors (self.mults) and offsets (self.adds) are parameters in the model
        #     Note: the shapes differ: self.mults and self.adds have size [nf,1,1],
        #         while means and vars have size [1,nf,1,1] and x has size [bs,nf,h,w]
        return x.mul_(self.mults).add_(self.adds)
```

These are all good questions! Perhaps you can do some experiments and see what happens if you change some of these things, and make some guesses for each?

Would be a great learning project for others in the community to help with too. I can chime in later if there are some things that don’t get solved.
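
As a possible starting point for such an experiment, here is a minimal sketch in plain Python (no PyTorch) that probes Q5 and Q6. It assumes a constant weight `mom1` and a constant input value, neither of which holds in real training, and the helper names `lerp` and `run` are made up for illustration. It tracks an EWMA that starts at zero, as the `sums` buffer does, alongside the `dbias` recursion from `update_stats`.

```python
def lerp(start, end, weight):
    # Same formula as torch.lerp: start + weight * (end - start)
    return start + weight * (end - start)

def run(value, mom1, n_steps):
    """Feed the same `value` into a zero-initialised EWMA for n_steps."""
    ewma, dbias = 0.0, 0.0
    history = []
    for _ in range(n_steps):
        ewma = lerp(ewma, value, mom1)        # like self.sums.lerp_(s, self.mom1)
        dbias = dbias * (1 - mom1) + mom1     # like the self.dbias update
        history.append((ewma, dbias, ewma / dbias))
    return history

history = run(value=5.0, mom1=0.1, n_steps=100)
for step in (1, 2, 10, 100):
    ewma, dbias, corrected = history[step - 1]
    print(f"step {step:3d}: ewma={ewma:.4f}  dbias={dbias:.4f}  corrected={corrected:.4f}")
```

In this toy setting the raw EWMA starts far below the true value because the buffer was initialised to zero, `dbias` equals `1 - (1 - mom1)**n` after `n` steps, and `ewma / dbias` stays at 5.0 (up to floating-point error) from the first step. That suggests, without proving it, that the `step<100` cutoff simply skips a correction that has become negligible. It is also worth noticing that `forward` divides `sums`, `sqrs`, and `c` by the same `dbias`, so the factor cancels in `sums/c` and `sqrs/c`.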

This is what I understood.
We are updating the weight (`mom1`) that we pass into the linear interpolation formula (`lerp`). `mom1` depends on `mom` and `bs`.
When `bs` = 2, `mom1` = `mom` = 0.1 (the default value of `mom` is 0.1).
When `bs` = 512, `mom1` ≈ 0.96.
So as `bs` increases, `mom1` increases.
Now let us look at `lerp`: old_value * (1-wt) + new_value * wt, where wt = `mom1`.
So as `bs` increases, `mom1` increases, and the weight given to the current batch increases.
As your batch size increases, the statistics computed from the batch are more “trustworthy”, so you can weigh them more heavily in the linear interpolation.
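
These numbers are easy to check. The sketch below just evaluates the `mom1` formula from `update_stats` for a few batch sizes and applies one interpolation step. The helper names `mom1_for` and `lerp` are mine, not fastai's, and `lerp` uses the same formula as `torch.lerp`.

```python
import math

def mom1_for(bs, mom=0.1):
    # Formula from update_stats; undefined for bs=1 (division by zero)
    return 1 - (1 - mom) / math.sqrt(bs - 1)

def lerp(old_value, new_value, wt):
    # old_value*(1-wt) + new_value*wt, written the way torch.lerp defines it
    return old_value + wt * (new_value - old_value)

for bs in (2, 64, 512):
    wt = mom1_for(bs)
    # one update step from a running value of 0.0 towards a batch statistic of 1.0
    print(f"bs={bs:4d}  mom1={wt:.3f}  after one step: {lerp(0.0, 1.0, wt):.3f}")
```

With the default `mom=0.1`, the weight on the new batch goes from 0.1 at `bs=2` to roughly 0.89 at `bs=64` and 0.96 at `bs=512`, so for typical batch sizes the running statistics are dominated by the most recent batch.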
Q4: The batch size varies when the number of samples is not a multiple of your batch size. Say you have 1030 samples with `bs`=512. Then batch 1 and batch 2 will have 512 samples each, and the last batch will have only 6 samples.
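
To see what that smaller last batch does to the running statistics, here is a toy sketch under strong assumptions: one channel, one value per sample, and every input equal to 1.0, so the true mean is exactly 1. The names `lerp`, `mom1_for`, and `simulate` are hypothetical helpers, not fastai code. It compares dividing the running sum by the running `count` against dividing by a fixed full-batch count.

```python
import math

def lerp(start, end, weight):
    # Same formula as torch.lerp
    return start + weight * (end - start)

def mom1_for(bs, mom=0.1):
    # Formula from update_stats
    return 1 - (1 - mom) / math.sqrt(bs - 1)

def simulate(batch_sizes, full_bs):
    sums, count = 0.0, 0.0
    for bs in batch_sizes:
        s = float(bs)   # sum of bs inputs that are all 1.0
        c = float(bs)   # elements per channel in this batch
        w = mom1_for(bs)
        sums = lerp(sums, s, w)     # EWMA of the per-batch sum
        count = lerp(count, c, w)   # EWMA of the per-batch count
    return sums / count, sums / full_bs

mean_running, mean_fixed = simulate([512, 512, 6], full_bs=512)
print(f"sums / running count = {mean_running:.4f}")
print(f"sums / fixed 512     = {mean_fixed:.4f}")
```

In this toy run `sums/count` stays at 1, while the fixed-count version drifts well away from 1 once the 6-sample batch arrives. That is at least consistent with averaging `count` in the same way as `sums`. Note too that `c` counts elements per channel (batch size times height times width), so it would also change if the input resolution varied.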

Debiasing is not present in the simplified `RunningBatchNorm`:

```python
import math
import torch
from torch import nn, tensor

class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('count', tensor(0.))
        self.register_buffer('factor', tensor(0.))
        self.register_buffer('offset', tensor(0.))
        self.batch = 0

    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.sums.detach_()
        self.sqrs.detach_()
        dims = (0,2,3)
        s    = x    .sum(dims, keepdim=True)
        ss   = (x*x).sum(dims, keepdim=True)
        c    = s.new_tensor(x.numel()/nc)
        mom1 = s.new_tensor(1 - (1-self.mom)/math.sqrt(bs-1))
        self.sums .lerp_(s , mom1)
        self.sqrs .lerp_(ss, mom1)
        self.count.lerp_(c , mom1)
        self.batch += bs
        means = self.sums/self.count
        varns = (self.sqrs/self.count).sub_(means*means)
        if bool(self.batch < 20): varns.clamp_min_(0.01)
        # precompute scale and shift so that forward is one multiply and one add
        self.factor = self.mults / (varns+self.eps).sqrt()
        self.offset = self.adds - means*self.factor

    def forward(self, x):
        if self.training: self.update_stats(x)
        return x*self.factor + self.offset
```

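
A quick sanity check on the `factor`/`offset` form, in plain Python with scalar stand-ins for a single channel. It assumes `offset` is defined as `adds - means*factor`; with that definition, `x*factor + offset` is the usual normalize-then-scale expression rearranged. All the numbers are made up.

```python
import math

# made-up scalar stand-ins for one channel
x, means, varns = 3.0, 1.5, 4.0
mults, adds, eps = 2.0, 0.5, 1e-5

# usual form: normalize, then scale and shift
usual = (x - means) / math.sqrt(varns + eps) * mults + adds

# precomputed form: fold everything into one multiply and one add
factor = mults / math.sqrt(varns + eps)
offset = adds - means * factor      # assumed definition of offset
precomputed = x * factor + offset

print(usual, precomputed)
```

The two values agree up to floating-point error, so the simplified version does the same normalization while doing less work per forward pass.
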
Jeremy discusses the ‘simplified’ version of `RunningBatchNorm` at the start of the Lesson 11 video, but I can’t find it in the course-3 git repo. Notebook `08_batchnorm.ipynb` still has the version presented in the Lesson 10 video. (@stas, @Sylvain: just so you know.)