I’ve spent a fair amount of effort trying to understand the implementation of the
RunningBatchNorm algorithm in Lesson 10, including carefully examining the code and watching/listening several times to the relevant part of the video near the end of Lesson 10.
I found parts of the code confusing. In particular, it would be great to have a clear explanation of the logic behind the equations for
self.dbias and their use in the subsequent code. Also, I don’t see how an exponentially weighted average of the number of samples in a batch is useful, since batch sizes are usually all the same except for the last batch.
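To make the EWMA part of the question concrete, here is a minimal pure-Python sketch of how I understand torch's `lerp_` (the 0.1 weight and the per-batch statistics below are made-up numbers, not from the notebook):

```python
# My understanding: a.lerp_(b, w) updates a in place to a + w*(b - a),
# i.e. (1 - w)*a + w*b. Repeating this gives an exponentially weighted
# moving average (EWMA).
def lerp(a, b, w):
    return a + w * (b - a)

running = 0.0                                  # the buffers start at zero
for batch_stat in [10.0, 12.0, 11.0, 13.0]:    # made-up per-batch statistics
    running = lerp(running, batch_stat, 0.1)

# after only a few batches the EWMA is far below every value it has seen,
# because it was initialized at zero
print(round(running, 6))                       # 3.991
```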
Below is the code for RunningBatchNorm from the Lesson 10 notebook 07_batchnorm.ipynb that is currently in the /fastai/course-3 git repo. I added comments to help document the code, along with questions (marked by Q) that I haven’t been able to answer.
```python
# note: the imports used below (torch, nn, tensor, math) come from the notebook
class RunningBatchNorm(nn.Module):
    # initialize
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        # constants
        self.mom,self.eps = mom,eps
        # add scale and offset parameters to the model
        # note: nf is the number of channels
        # Q1: shouldn't self.mults and self.adds have size [1,nf,1,1]?
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        # register_buffer adds a persistent buffer to the module,
        # usually used for a tensor that is not a model parameter
        self.register_buffer('sums',  torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs',  torch.zeros(1,nf,1,1))
        self.register_buffer('batch', tensor(0.))
        self.register_buffer('count', tensor(0.))
        self.register_buffer('step',  tensor(0.))
        self.register_buffer('dbias', tensor(0.))

    # compute updates to the buffered tensors
    def update_stats(self, x):
        # batch size, number of channels
        bs,nc,*_ = x.shape
        # note: for a tensor t, t.detach_() detaches t from the computation
        # graph, i.e. stops tracking its gradients; the '_' suffix means
        # "in place"
        # Q2: why don't we also use .detach_() for self.batch, self.count,
        # self.step, and self.dbias?
        self.sums.detach_()
        self.sqrs.detach_()
        # the input x is a four-dimensional tensor: dimensions 0, 2, 3 refer
        # to batch samples, activation-map rows, and activation-map columns,
        # respectively; dimension 1 refers to channels
        dims = (0,2,3)
        # compute s and ss, the sums of the activations and of the squared
        # activations over dimensions (0,2,3) for this batch; s and ss each
        # hold one number per channel, and because keepdim=True they are each
        # of size [1,nf,1,1]
        s  = x    .sum(dims, keepdim=True)
        ss = (x*x).sum(dims, keepdim=True)
        # notes:
        # - x.numel() is the number of elements in the 4-D tensor x, i.e. the
        #   total number of activations in the batch
        # - x.numel()/nc is the number of activations per channel in the batch
        # - y = t.new_tensor(x) is equivalent to y = x.clone().detach()
        #   (with t's dtype and device); the latter is the preferred way to
        #   copy a tensor
        # - c is a scalar tensor equal to the number of activations per
        #   channel for this batch; this depends on the number of samples in
        #   the batch, and not all batches have the same number of samples
        c = self.count.new_tensor(x.numel()/nc)
        # momentum
        # mom1 is the 'weight' used in lerp_() to compute the EWMAs
        # -- see the pytorch documentation for lerp_
        # if mom is 0.1 and the batch size is 2,  then mom1 = 1 - 0.9/1 = 0.1
        # if mom is 0.1 and the batch size is 64, then mom1 ~ 1 - 0.9/7.9 ~ 0.89
        # in general, mom1 increases with batch size
        # Q3: what's the logic behind the following formula for mom1?
        mom1 = 1 - (1-self.mom)/math.sqrt(bs-1)
        # self.mom1 is a scalar tensor whose value is mom1
        self.mom1 = self.dbias.new_tensor(mom1)
        # update the EWMAs of sums and sqrs, which, like s and ss, have size
        # [1,nf,1,1]
        self.sums.lerp_(s,  self.mom1)
        self.sqrs.lerp_(ss, self.mom1)
        # update the EWMA of count: self.count tracks the EWMA of c, the
        # number of activations per channel for a batch
        # Q4: why do we need the EWMA of c? aren't batch sizes always the
        # same, except for the last batch?
        self.count.lerp_(c, self.mom1)
        # Q5: what is the logic behind the following formula for dbias?
        self.dbias = self.dbias*(1-self.mom1) + self.mom1
        # update the total number of samples processed so far, i.e. the
        # number of samples in this batch and all previous batches
        self.batch += bs
        # update the total number of batches processed so far
        self.step  += 1

    # apply a forward pass to the current batch
    def forward(self, x):
        # main idea of RunningBatchNorm:
        # - in training mode, normalize the batch using the EWMAs accumulated
        #   in the buffers up to this step (batch) and the *current* fitted
        #   values of the parameters mults and adds
        # - in validation mode, use the final values of the EWMAs accumulated
        #   during training and the final fitted values of mults and adds
        if self.training: self.update_stats(x)
        # get the current values of the EWMAs of sums, sqrs, and count
        sums = self.sums
        sqrs = self.sqrs
        c    = self.count
        # if fewer than 100 batches have been processed, scale the EWMAs by
        # 1/self.dbias
        # Q6: why?
        if self.step<100:
            sums = sums / self.dbias
            sqrs = sqrs / self.dbias
            c    = c    / self.dbias
        # scale sums by 1/c to get the mean of the activations
        means = sums/c
        # scale sqrs by 1/c to get the mean of the squared activations, then
        # subtract the square of the mean: the 'computationally efficient'
        # formula for the variance that we've seen before
        vars = (sqrs/c).sub_(means*means)
        # if fewer than 20 samples have been seen so far, clamp vars at 0.01
        # (in case any of them becomes very small)
        if bool(self.batch < 20): vars.clamp_min_(0.01)
        # normalize the batch in the usual way: subtract the mean and divide
        # by the standard deviation
        # Q7: why do we need to add eps, when we've already clamped vars
        # to 0.01?
        x = (x-means).div_((vars.add_(self.eps)).sqrt())
        # return a scaled and offset version of the normalized batch, where
        # the scale factors (self.mults) and offsets (self.adds) are model
        # parameters
        # note the size difference: self.mults and self.adds have size
        # [nf,1,1], while x has size [bs,nf,h,w] (see Q1)
        return x.mul_(self.mults).add_(self.adds)
```
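Regarding Q5 and Q6, my current reading (which I'd be glad to have confirmed or corrected) is that self.dbias tracks the total weight the EWMA has accumulated so far, so dividing by it removes the bias from starting the buffers at zero. A small pure-Python check of that recurrence, using made-up weights and values:

```python
def lerp(a, b, w):
    return a + w * (b - a)

# `running` follows the same update as self.sums (an EWMA started at 0);
# `dbias` follows the self.dbias recurrence, which is the identical EWMA
# update applied to the constant 1.0
running, dbias = 0.0, 0.0
for value, w in [(5.0, 0.1), (5.0, 0.3), (5.0, 0.2)]:  # made-up weights
    running = lerp(running, value, w)
    dbias = dbias * (1 - w) + w   # same update as self.dbias

# every observed value was 5.0, so after dividing by dbias the
# zero-initialization bias is gone and we recover 5.0 exactly
print(round(running / dbias, 12))   # 5.0
```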
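For the vars line, here is a quick pure-Python check, on made-up numbers, that the one-pass formula used in forward() (mean of the squares minus the square of the mean) agrees with the definitional two-pass variance:

```python
xs = [1.0, 2.0, 4.0]   # made-up activations for one channel
c = len(xs)

mean = sum(xs) / c
# definitional (two-pass) variance
var_two_pass = sum((x - mean) ** 2 for x in xs) / c
# the 'computationally efficient' one-pass form used in forward():
# mean of the squares minus the square of the mean
var_one_pass = sum(x * x for x in xs) / c - mean * mean

print(abs(var_two_pass - var_one_pass) < 1e-12)   # True
```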