I’ve spent a fair amount of effort trying to understand the implementation of the RunningBatchNorm algorithm in Lesson 10, including carefully examining the code and watching/listening to the relevant part of the video near the end of Lesson 10 several times.
I found parts of the code confusing. In particular, it would be great to have a clear explanation of the logic behind the equations for mom1 and self.dbias, and their use in the subsequent code. Also, I don’t see how an exponentially weighted average of the number of samples in a batch is useful, since batch sizes are usually all the same except for the last batch.
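To make the mom1 question concrete, here is a minimal sketch that evaluates the formula from update_stats for a few batch sizes. The helper name mom1_for and the explicit guard for bs < 2 are my own additions for illustration; they are not part of the notebook.

```python
import math

def mom1_for(bs: int, mom: float = 0.1) -> float:
    """Evaluate mom1 = 1 - (1 - mom) / sqrt(bs - 1), the formula used in update_stats."""
    # Hypothetical guard (not in the notebook): bs = 1 would divide by zero.
    if bs < 2:
        raise ValueError("bs must be at least 2, otherwise sqrt(bs - 1) is zero")
    return 1 - (1 - mom) / math.sqrt(bs - 1)

# Print the lerp weight for a range of batch sizes with the default mom = 0.1.
for bs in (2, 4, 16, 64, 256):
    print(f"bs={bs:4d}  mom1={mom1_for(bs):.3f}")
```

With mom = 0.1, mom1 equals mom at bs = 2 and rises to about 0.89 at bs = 64, so larger batches put more weight on the current batch statistics.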
Below is the code for RunningBatchNorm from the Lesson 10 notebook 07_batchnorm.ipynb that is currently in the /fastai/course-3 git repo.
I added comments to help document the code.
I also added questions (marked by Q) that I haven’t been able to answer.
Could @jeremy, @stas, or @Sylvain (or whoever wrote this code) please help clarify the RunningBatchNorm algorithm by addressing these questions?
==============================================
# imports needed to run this class on its own
import math
import torch
from torch import nn, tensor

class RunningBatchNorm(nn.Module):
    # initialize
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        # constants
        self.mom,self.eps = mom,eps
        # add scale and offset parameters to the model
        # note: nf is the number of channels
        # Q1: shouldn't self.mults and self.adds have size [1,nf,1,1]?
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds = nn.Parameter(torch.zeros(nf,1,1))
        # register_buffer adds a persistent buffer to the module, typically used for
        # state that is not a model parameter
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('batch', tensor(0.))
        self.register_buffer('count', tensor(0.))
        self.register_buffer('step', tensor(0.))
        self.register_buffer('dbias', tensor(0.))

    # compute updates to buffered tensors
    def update_stats(self, x):
        # batch size, number of channels
        bs,nc,*_ = x.shape
        # Note: for a tensor t, t.detach_() detaches t from the computation graph,
        # i.e. gradients are no longer tracked through it;
        # the trailing '_' means the operation is done in place
        # Q2: why don't we also use .detach_() for self.batch, self.count, self.step, and self.dbias?
        self.sums.detach_()
        self.sqrs.detach_()
        # the input x is a four-dimensional tensor of activations:
        # dimensions 0, 2, 3 refer to batch samples, feature-map height, and feature-map width, respectively
        # dimension 1 refers to channels
        dims = (0,2,3)
        # compute s and ss, which are the sum of the activations and the sum of the squared activations
        # over dimensions (0,2,3) for this batch. s and ss each consist of one number for each channel;
        # because keepdim=True, s and ss are each of size [1,nf,1,1]
        s = x.sum(dims, keepdim=True)
        ss = (x*x).sum(dims, keepdim=True)
        # Notes:
        # x.numel() is the number of elements in the 4-D tensor x,
        # which is the total number of activations in the batch
        # x.numel()/nc is the number of activations per channel in the batch
        # self.count.new_tensor(v) builds a new tensor from v with the same dtype and device as self.count
        # (when v is itself a tensor, this is equivalent to v.clone().detach(),
        # which is the preferred way to make a copy of a tensor)
        # c is a zero-dimensional (scalar) tensor equal to the number of activations per channel for this batch
        # note that the number of activations per channel of a batch depends on the number of samples in the
        # batch; not all batches have the same number of samples
        c = self.count.new_tensor(x.numel()/nc)
        # momentum
        # mom1 is the 'weight' to be used in lerp_() to compute EWMAs
        # (see the PyTorch documentation for lerp_)
        # if mom is 0.1 and batch size is 2, then mom1 = 1 - 0.9/1 = 0.1
        # if mom is 0.1 and batch size is 64, then mom1 ~ 1 - 0.9/7.9 ~ 0.89;
        # in general, mom1 increases with batch size
        # Q3: What's the logic behind the following formula for mom1?
        mom1 = 1 - (1-self.mom)/math.sqrt(bs-1)
        # self.mom1 is a zero-dimensional (scalar) tensor, with a value equal to mom1
        self.mom1 = self.dbias.new_tensor(mom1)
        # update EWMAs of sums, sqrs, which, like s and ss, have size [1,nf,1,1]
        self.sums.lerp_(s, self.mom1)
        self.sqrs.lerp_(ss, self.mom1)
        # update EWMA of count
        # self.count keeps track of the EWMA of c,
        # which is the number of activations per channel for a batch
        # Q4: why do we need the EWMA of c? Aren't batch sizes always the same, except for the last batch?
        self.count.lerp_(c, self.mom1)
        # Q5: what is the logic behind the following formula for dbias?
        self.dbias = self.dbias*(1-self.mom1) + self.mom1
        # update the total number of samples that have been processed up till now,
        # i.e. the number of samples in this batch and all previous batches so far
        self.batch += bs
        # update the total number of batches that have been processed
        self.step += 1

    # apply a forward pass to the current batch
    def forward(self, x):
        # main idea of RunningBatchNorm:
        # to normalize the batch:
        # in training mode, use the current EWMAs accumulated in the buffers at this step (batch),
        # and the *current* fitted values of the model parameters mults and adds at this step
        # in validation mode, use the final values of the EWMAs accumulated in the buffers after training,
        # and the final fitted values of mults and adds
        if self.training: self.update_stats(x)
        # get the current values of the EWMAs of sums, sqrs and count from the buffers
        sums = self.sums
        sqrs = self.sqrs
        c = self.count
        # if fewer than 100 batches have been processed, scale the EWMAs by 1/self.dbias
        # Q6: Why?
        if self.step<100:
            sums = sums / self.dbias
            sqrs = sqrs / self.dbias
            c = c / self.dbias
        # scale sums by 1/c to get the mean of the activations
        means = sums/c
        # scale sqrs by 1/c to get the mean of the squared activations,
        # then subtract the square of the mean from the mean of the squares
        # note: we recognize this as the 'computationally efficient' formula for the variance that we've seen before
        vars = (sqrs/c).sub_(means*means)
        # if fewer than 20 samples have been seen so far, clamp vars to a minimum of 0.01 (in case any of them becomes very small)
        if bool(self.batch < 20): vars.clamp_min_(0.01)
        # normalize the batch in the usual way, i.e. subtract the mean and divide by std
        # Q7: but why do we need to add eps, when we've already clamped the vars to 0.01?
        x = (x-means).div_((vars.add_(self.eps)).sqrt())
        # return a scaled and offset version of the normalized batch, where the
        # scale factors (self.mults) and offsets (self.adds) are parameters in the model
        # Note: the sizes differ: self.mults and self.adds have size [nf,1,1], while x has size [bs,nf,h,w] (see Q1)
        return x.mul_(self.mults).add_(self.adds)
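
==============================================
One possible reading of Q5 and Q6, which I have not seen confirmed by the authors: an EWMA that starts at zero underestimates its input for the first few steps, and the dbias recurrence may be tracking the total weight accumulated so far, so that dividing by it removes that underestimate. The sketch below is a scalar, pure-Python imitation of the lerp_ and dbias updates with a fixed mom1. The function name ewma_with_dbias and the constant input of 5.0 are made up for illustration.

```python
def ewma_with_dbias(values, mom1):
    """Track an EWMA started at zero together with the dbias recurrence from update_stats."""
    avg, dbias = 0.0, 0.0
    history = []
    for v in values:
        # scalar equivalent of avg.lerp_(v, mom1): move avg a fraction mom1 toward v
        avg = avg + mom1 * (v - avg)
        # same recurrence as self.dbias = self.dbias*(1-self.mom1) + self.mom1
        dbias = dbias * (1 - mom1) + mom1
        history.append((avg, dbias, avg / dbias))
    return history

# Feed a constant input so the bias of the raw EWMA is easy to see.
mom1 = 0.1
history = ewma_with_dbias([5.0] * 10, mom1)
for step, (avg, dbias, corrected) in enumerate(history, start=1):
    print(f"step={step:2d}  avg={avg:.4f}  dbias={dbias:.4f}  avg/dbias={corrected:.4f}")
```

In this sketch, dbias after t steps equals 1 - (1 - mom1)^t, so the raw average starts far below 5.0 while avg/dbias matches the input from the first step. Since dbias approaches 1 as t grows, the division matters less and less, which might be why forward only applies it while self.step < 100.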