RunningBatchNorm made easy to use

We should really use the latest RunningBatchNorm for this, but until that version is working well and fast, here is the one from the notebook for demonstration purposes:

import math
import torch
from torch import nn, tensor

class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.nf = nf
        self.mom,self.eps = mom,eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        # running sums, sums of squares and debiasing state
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('batch', tensor(0.))
        self.register_buffer('count', tensor(0.))
        self.register_buffer('step', tensor(0.))
        self.register_buffer('dbias', tensor(0.))

    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.sums.detach_()
        self.sqrs.detach_()
        dims = (0,2,3)
        s  = x    .sum(dims, keepdim=True)
        ss = (x*x).sum(dims, keepdim=True)
        c = self.count.new_tensor(x.numel()/nc)   # elements per channel in this batch
        mom1 = 1 - (1-self.mom)/math.sqrt(bs-1)   # momentum scaled with batch size
        self.mom1 = self.dbias.new_tensor(mom1)
        self.sums .lerp_(s,  self.mom1)           # exponential moving averages
        self.sqrs .lerp_(ss, self.mom1)
        self.count.lerp_(c,  self.mom1)
        self.dbias = self.dbias*(1-self.mom1) + self.mom1
        self.batch += bs
        self.step  += 1

    def forward(self, x):
        if self.training: self.update_stats(x)
        sums = self.sums
        sqrs = self.sqrs
        c    = self.count
        if self.step<100:
            # debias the running stats early in training
            sums = sums / self.dbias
            sqrs = sqrs / self.dbias
            c    = c    / self.dbias
        means = sums/c
        vars  = (sqrs/c).sub_(means*means)
        if bool(self.batch < 20): vars.clamp_min_(0.01)
        x = (x-means).div_((vars.add_(self.eps)).sqrt())
        return x.mul_(self.mults).add_(self.adds)

    def extra_repr(self):
        return f'{self.nf}, mom={self.mom}, eps={self.eps}'
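
For reference, RunningBatchNorm is used like a plain nn.BatchNorm2d layer. A quick smoke test with made-up shapes (just an illustration, not from the notebook):

rbn = RunningBatchNorm(16)
x = torch.randn(8, 16, 32, 32)
rbn.train(); out = rbn(x)   # training mode updates the running statistics
rbn.eval();  out = rbn(x)   # eval mode reuses the accumulated statistics
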
def find_modules(module, cond, container=None, index=None):
    # Yield (parent_module, key) pairs for every submodule that satisfies cond.
    if cond(module):
        yield container, index

    for i, sub_module in enumerate(module.children()):
        try:
            module._modules[i]
        except KeyError:
            # _modules is keyed by name ('0', '1', 'bn1', ...), not position,
            # so map the positional index to the actual key
            i = list(module._modules.keys())[i]

        yield from find_modules(sub_module, cond, module, i)
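
To see what find_modules yields, here is a quick sketch with a toy model (the Sequential is only for illustration):

m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
hits = list(find_modules(m, lambda o: isinstance(o, nn.BatchNorm2d)))
# hits == [(m, '1')], so m._modules['1'] is the matching BatchNorm2d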

def runningBatchNormify(model):
    # Replace every nn.BatchNorm2d in the model with a RunningBatchNorm
    # that has the same number of features.
    check_bn = lambda o: issubclass(o.__class__, nn.BatchNorm2d)

    for m, key in find_modules(model, check_bn):
        nf = m._modules[key].num_features
        m._modules[key] = RunningBatchNorm(nf)

You can use runningBatchNormify like this:

my_model = create_cnn_model(models.resnet50, pretrained=False)
runningBatchNormify(my_model)

It replaces every nn.BatchNorm2d in the model with the new RunningBatchNorm, which makes it easier to test.
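
To confirm the swap worked, one quick sanity check (my own addition, not from the notebook) is to make sure no plain BatchNorm2d layers are left:

assert not any(isinstance(m, nn.BatchNorm2d) for m in my_model.modules())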

Hi hadus,
Could you please help me understand the following things about the batch norm implementation?

  1. self.sums.lerp_(s, self.mom1)
     What does lerp_ do here? I know it linearly interpolates toward the new value with a weight of mom1, but why do we do this here?

  2. self.count.new_tensor(x.numel()/nc)
     What are .new_tensor and numel here? What is their purpose?

  3. In the simple batch norm we hardcode the momentum to 0.1, and I'm not sure why.

I am stuck in this lesson because I am unable to understand these things.

I just copied this implementation of running batch norm from the notebook.

I don’t think it is very important to understand it at a low level. It works like normal batch norm, except that the mean and std are calculated differently: they are moving averages, so it doesn’t matter what batch size we use.
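
If it helps with question 1: a.lerp_(b, w) updates a in place to a + w*(b - a), i.e. an exponential moving average with weight w. A toy sketch (not the notebook code):

import torch

running = torch.zeros(1)
for batch_mean in [2.0, 4.0, 3.0]:
    running.lerp_(torch.tensor([batch_mean]), 0.1)  # running += 0.1 * (batch_mean - running)
# running is now a smoothed estimate of the recent batch means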

If you want to understand it line by line, I think you should start a new thread and ask it there.
Good luck :slight_smile: