Questions about RunningBatchNorm in Lesson 10

I’ve spent a fair amount of effort trying to understand the implementation of the RunningBatchNorm algorithm in Lesson 10, including carefully examining the code and watching/listening to the relevant part of the video near the end of Lesson 10 several times.

I found parts of the code confusing. In particular, it would be great to have a clear explanation of the logic behind the equations for mom1 and self.dbias, and their use in the subsequent code. Also, I don’t see how an exponentially weighted average of the number of samples in a batch is useful, since batch sizes are usually all the same except for the last batch.

Below is the code for RunningBatchNorm from the Lesson 10 notebook 07_batchnorm.ipynb that is currently in the fastai/course-v3 git repo.

I added comments to help document the code.

I also added questions (marked by Q) that I haven’t been able to answer.

Could @jeremy, @stas, or @Sylvain (or whoever wrote this code) please help clarify the RunningBatchNorm algorithm by addressing these questions?

==============================================

class RunningBatchNorm(nn.Module):

# initialize
def __init__(self, nf, mom=0.1, eps=1e-5):
    super().__init__()
    
    # constants
    self.mom,self.eps = mom,eps
    
    # add scale and offset parameters to the model
    # note: nf is the number of channels
    # Q1: shouldn't self.mults and self.adds have size [1,nf,1,1]?
    self.mults = nn.Parameter(torch.ones (nf,1,1))
    self.adds = nn.Parameter(torch.zeros(nf,1,1))
    
    # register_buffer adds a persistent buffer to the module, usually used for a non-model parameter
    self.register_buffer('sums', torch.zeros(1,nf,1,1))
    self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
    self.register_buffer('batch', tensor(0.))
    self.register_buffer('count', tensor(0.))
    self.register_buffer('step', tensor(0.))
    self.register_buffer('dbias', tensor(0.))

# compute updates to buffered tensors
def update_stats(self, x):
    
    # batchsize, number of channels
    bs,nc,*_ = x.shape
    
    # Note: for a tensor t, t.detach_() means detach t from the computation graph, i.e. don't keep track of its gradients;
    #     the trailing '_' means do it "in place"
    # Q2: why don't we also use .detach_() for self.batch, self.count, self.step, and self.dbias?
    self.sums.detach_()
    self.sqrs.detach_()
    
    # the input x is a four-dimensional tensor of activations:
    #    dimension 0 refers to batch samples, and dimensions 2 and 3 refer to the spatial
    #    height and width of the feature maps; dimension 1 refers to channels
    dims = (0,2,3)
    
    # compute s and ss, which are the sum of the activations and the sum of the squared activations
    #     over dimensions (0,2,3) for this batch. s and ss each consist of one number for each channel;
    #     because keepdim=True, s and ss are each of size [1,nf,1,1]
    s = x.sum(dims, keepdim=True)
    ss = (x*x).sum(dims, keepdim=True)
    
    # Notes: 
    #   x.numel() is the number of elements in the 4-D tensor x,
    #       which is the total number of activations in the batch
    #   x.numel()/nc is the number of activations per channel in the batch
    #   t.new_tensor(data) creates a new tensor from data, with the same dtype and device as t
    #       (for copying an existing tensor x, the docs recommend x.clone().detach() instead)
    #   c is a scalar tensor whose value is the number of activations per channel for this batch;
    #       note that this depends on the number of samples in the batch,
    #       and not all batches have the same number of samples
    c = self.count.new_tensor(x.numel()/nc)
    
    # momentum
    # mom1 is the 'weight' to be used in lerp_() to compute EWMAs
    #     -- see the PyTorch documentation for lerp_
    # if mom is 0.1 and batch size is 2, then mom1 = 1 - 0.9/sqrt(1) = 0.1
    # if mom is 0.1 and batch size is 64, then mom1 = 1 - 0.9/sqrt(63) ~ 0.89;
    #     in general, mom1 increases with batch size
    # Q3: What's the logic behind the following formula for mom1?
    mom1 = 1 - (1-self.mom)/math.sqrt(bs-1)
    # self.mom1 is a scalar tensor whose value is mom1
    self.mom1 = self.dbias.new_tensor(mom1)
    
    # update EWMAs of sums, sqrs, which, like s and ss, have size [1,nf,1,1]
    self.sums.lerp_(s, self.mom1)
    self.sqrs.lerp_(ss, self.mom1)

    # update EWMA of count
    # self.count keeps track of the EWMA of c,
    #     which is the number of activations per channel for a batch
    # Q4: why do we need the EWMA of c?  Aren't batch sizes always the same, except for the last batch?
    self.count.lerp_(c, self.mom1)
    
 
    # Q5: what is the logic behind the following formula for dbias?
    self.dbias = self.dbias*(1-self.mom1) + self.mom1
    
    # update the total number of samples that have been processed up till now, 
    #     i.e. the number of samples in this batch and all previous batches so far
    self.batch += bs
    
    # update the total number of batches that have been processed
    self.step += 1

# apply a forward pass to the current batch
def forward(self, x):
    
    # main idea of RunningBatchNorm:
    #     to normalize the batch:
    #     in training mode, use the current EWMAs accumulated in the buffers at this step (batch), 
    #         and the *current* fitted values of the model parameters mults and adds at this step
    #     in validation mode, use the final values of the EWMAs accumulated in the buffers after training,
    #         and the final fitted values of mults and adds
    if self.training: self.update_stats(x)
        
    # get the current values of the EWMAs of sums, sqrs and count from the buffers
    sums = self.sums
    sqrs = self.sqrs
    c = self.count
    
    # if fewer than 100 batches (steps) have been processed so far, scale the EWMAs by 1/self.dbias
    # Q6: Why?
    if self.step<100:
        sums = sums / self.dbias
        sqrs = sqrs / self.dbias
        c    = c    / self.dbias
        
    # scale sums by 1/c to get the per-channel mean of the activations
    means = sums/c
    
    # scale sqrs by 1/c to get the mean of the squared activations,
    #     then subtract the square of the mean from the mean of the squares
    # note: we recognize this as the 'computationally efficient' formula for the variance that we've seen before
    vars = (sqrs/c).sub_(means*means)
    
    # if fewer than 20 samples have been seen so far, clamp vars to a minimum of 0.01 (in case any of them is very small)
    if bool(self.batch < 20): vars.clamp_min_(0.01)
        
    # normalize the batch in the usual way, i.e. subtract the mean and divide by std
    # Q7: but why do we need to add eps, when we've already clamped the vars to 0.01? 
    x = (x-means).div_((vars.add_(self.eps)).sqrt())
    
    # return a scaled and offset version of the normalized batch, where the
    #     scale factors (self.mults) and offsets (self.adds) are parameters in the model
    #     Note: self.mults and self.adds have size [nf,1,1], while the buffers (and means and vars) have size [1,nf,1,1]; see Q1 above and the quick check after the code
    return x.mul_(self.mults).add_(self.adds)
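
Regarding Q1 and the size note just above: as far as I can tell this comes down to ordinary PyTorch broadcasting, which aligns trailing dimensions, so [nf,1,1] and [1,nf,1,1] behave identically against a [bs,nf,h,w] activation tensor. A quick check (my own toy example, not from the notebook):

import torch

bs, nf, h, w = 8, 16, 4, 4
x = torch.randn(bs, nf, h, w)

m1 = torch.randn(nf, 1, 1)       # the shape used for self.mults / self.adds
m2 = m1.view(1, nf, 1, 1)        # the shape Q1 asks about (also the shape of the buffers)

# broadcasting expands both shapes to (bs, nf, h, w), giving identical results
print(torch.equal(x * m1, x * m2))   # True
print((x * m1).shape)                # torch.Size([8, 16, 4, 4])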

These are all good questions! Perhaps you can do some experiments and see what happens if you change some of these things, and make some guesses for each?

Would be a great learning project for others in the community to help with too. I can chime in later if there are some things that don’t get solved.


This is what I understood.
We are updating the weight (mom1) that we pass into the linear interpolation formula (lerp_). mom1 depends on mom and bs.
When bs=2, mom1 = mom = 0.1 (the default value of mom is 0.1).
When bs=512, mom1 ≈ 0.96.
So as bs increases, mom1 increases.
Now let us look at lerp: old_value * (1-wt) + new_value * wt, where wt = mom1.
So as bs increases, mom1 increases, and therefore the weight given to the current batch increases.
As your batch size increases, the statistics computed from the batch are more “trustworthy”, so you can weight them more heavily in the linear interpolation.
Q4: this matters when the number of samples is not a multiple of your batch size. Say you have 1030 samples with bs=512: batches 1 and 2 will have 512 samples each, but the last batch will only have 6 samples.
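
Here is a small sketch (my own toy numbers, just to illustrate the formula, not code from the notebook) showing how mom1 grows with bs and what lerp_ then does with it:

import math, torch

def mom1(mom, bs):
    # same formula as in update_stats()
    return 1 - (1 - mom) / math.sqrt(bs - 1)

for bs in (2, 64, 512):
    print(bs, round(mom1(0.1, bs), 3))       # 2 -> 0.1, 64 -> 0.887, 512 -> 0.96

# lerp_: running <- running*(1 - wt) + new*wt
running, new = torch.tensor(10.), torch.tensor(20.)
running.lerp_(new, mom1(0.1, 512))
print(running)   # tensor(19.60...): a big batch pulls the running stat almost all the way to the new value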


Debiasing is not present in the simplified RunningBatchNorm (see the sketch after the code for what the dbias term was doing):

class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('count', tensor(0.))
        self.register_buffer('factor', tensor(0.))
        self.register_buffer('offset', tensor(0.))
        self.batch = 0
        
    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.sums.detach_()
        self.sqrs.detach_()
        dims = (0,2,3)
        s    = x    .sum(dims, keepdim=True)
        ss   = (x*x).sum(dims, keepdim=True)
        c    = s.new_tensor(x.numel()/nc)
        mom1 = s.new_tensor(1 - (1-self.mom)/math.sqrt(bs-1))
        self.sums .lerp_(s , mom1)
        self.sqrs .lerp_(ss, mom1)
        self.count.lerp_(c , mom1)
        self.batch += bs
        means = self.sums/self.count
        varns = (self.sqrs/self.count).sub_(means*means)
        if bool(self.batch < 20): varns.clamp_min_(0.01)
        self.factor = self.mults / (varns+self.eps).sqrt()
        self.offset = self.adds - means*self.factor
        
    def forward(self, x):
        if self.training: self.update_stats(x)
        return x*self.factor + self.offset
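
The refactoring into factor and offset is just algebra: (x - means)/sqrt(varns + eps) * mults + adds equals x*factor + offset when factor = mults/sqrt(varns + eps) and offset = adds - means*factor. A quick numerical check (my own toy tensors, not from the notebook):

import torch

bs, nf, h, w = 4, 3, 5, 5
x     = torch.randn(bs, nf, h, w)
means = torch.randn(1, nf, 1, 1)
varns = torch.rand(1, nf, 1, 1) + 0.1
mults = torch.randn(nf, 1, 1)
adds  = torch.randn(nf, 1, 1)
eps   = 1e-5

# the "long" form used in the original forward()
y1 = (x - means) / (varns + eps).sqrt() * mults + adds

# the precomputed form used in the simplified forward()
factor = mults / (varns + eps).sqrt()
offset = adds - means * factor
y2 = x * factor + offset

print(torch.allclose(y1, y2, atol=1e-6))   # True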
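And for anyone wondering what the dbias term in the original version was doing (this is just my understanding, illustrated with a toy sketch, not code from the notebook): dbias applies the same lerp recursion to a constant stream of 1s, starting from 0. Since sums, sqrs and count also start at 0, dividing them by dbias undoes the bias that the zero initialization causes during the first few steps:

import math, torch

mom, bs = 0.1, 64
mom1 = 1 - (1 - mom) / math.sqrt(bs - 1)

sums  = torch.tensor(0.)     # stand-in for one per-channel running sum
dbias = torch.tensor(0.)
true_sum = torch.tensor(5.)  # pretend every batch has the same per-channel sum

for step in range(1, 6):
    sums.lerp_(true_sum, mom1)
    dbias = dbias * (1 - mom1) + mom1      # same update as in the original update_stats()
    print(step, round(float(sums), 3), round(float(dbias), 3), round(float(sums / dbias), 3))

# sums alone starts well below 5, but sums/dbias is 5.0 from the very first step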

Jeremy discusses the ‘simplified’ version of RunningBatchNorm at the start of the Lesson 11 video, but I can’t find it in the course-v3 git repo. Notebook 07_batchnorm.ipynb still has the version presented in the Lesson 10 video. (@stas, @Sylvain: just so’s ya know)


Hey, I can’t answer a few of the questions you posed above, and wanted to ask if you figured out the answers to the rest (besides Q3 and Q4)?
Any help is appreciated. Thanks!