Running batch norm tweaks

My thinking: perhaps let's not worry about optimizations first, since those optimizations complicate things. Figure out how to skip the stats computation without affecting training quality - and once we're there, optimizing should be easy.

As of now, even if I go back to your class's version minus dbias, it won't train well with even a tiny amount of skipped re-calculations.

Here is pretty much that version with a few minor tweaks (making batch a plain number rather than a buffer, and adding iter); I also used bs-1 or 1 to avoid division by zero with bs=1.

import math
import torch
from torch import nn, tensor

class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom,self.eps = mom,eps
        self.mults = nn.Parameter(torch.ones(nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        # running (lerped) per-channel sums, sums of squares, and element count
        self.register_buffer('sums',  torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs',  torch.zeros(1,nf,1,1))
        self.register_buffer('count', tensor(0.))
        self.batch = 0  # total samples seen so far (plain number, not a buffer)
        self.iter  = 0  # number of update_stats calls so far

    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.batch += bs
        self.iter  += 1

        # drop any autograd history carried over from previous batches
        self.sums.detach_()
        self.sqrs.detach_()

        # after 10000 samples seen, skip every 10th stats update
        if self.batch > 10000 and not self.iter%10: return

        dims = (0,2,3)
        s  = x    .sum(dims, keepdim=True)
        ss = (x*x).sum(dims, keepdim=True)
        c = self.count.new_tensor(x.numel()/nc)

        # batch-size-dependent momentum; bs-1 or 1 avoids division by zero when bs=1
        mom1 = 1 - (1-self.mom)/math.sqrt(bs-1 or 1)
        self.mom1 = s.new_tensor(mom1)
        self.sums.lerp_ (s,  self.mom1)
        self.sqrs.lerp_ (ss, self.mom1)
        self.count.lerp_(c,  self.mom1)

    def forward(self, x):
        if self.training: self.update_stats(x)

        means = self.sums/self.count
        vars  = (self.sqrs/self.count).sub_(means.pow(2))
        if self.batch < 20: vars.clamp_min_(0.01)  # guard against tiny variances early on

        x = (x-means).div_((vars.add_(self.eps)).sqrt())
        return x.mul_(self.mults).add_(self.adds)
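For a quick smoke test of the class above (the shapes and numbers here are just made up for illustration):

import torch

torch.manual_seed(0)
rbn = RunningBatchNorm(nf=8)
rbn.train()
x = torch.randn(32, 8, 4, 4)               # (bs, nf, h, w)
y = rbn(x)
print(y.shape)                              # torch.Size([32, 8, 4, 4])
print(y.mean().item(), y.std().item())      # roughly 0 and 1 after the first batch

In training mode the first batch gets normalized with (almost) exactly its own batch statistics, because sums, sqrs and count all start at zero and get scaled by the same mom1, which cancels out in the means/vars.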

Feel free to tweak if self.batch > 10000 and not self.iter%10: return - the code doesn't do well even at skipping 1 out of 10 calcs after 10000 samples. This class's version simply skips accumulating sums and sqrs.
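If anyone wants to experiment with that line, these are the kinds of variants I have in mind (untested suggestions - pick one and swap it in for the condition inside update_stats):

# untested variants of the skip condition in update_stats (pick one):
if self.batch > 10000 and not self.iter%10: return   # current: after 10k samples, skip 1 update in 10
if self.batch > 50000 and not self.iter%10: return   # start skipping much later in training
if self.batch > 10000 and not self.iter%2:  return   # more aggressive: skip every other update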

And earlier we made attempts to always accumulate sums and sqrs, and it didn't fare any better.

So in addition to the above adjustments to the code (in particular mom1), I also tweaked the learning rate and batch size, and I'm running on a nightly version of PyTorch (well, sort of).
What led me to believe it's the usual suspects (i.e. hyperparameters) is that I checked for NaNs, either in the RBN module or just in the loss, and the number of batches it gets through before hitting NaNs varied with the hyperparameters. Also, printing the weight magnitudes after each minibatch (just print([(n,p.abs().max().item()) for n,p in learn.model.named_parameters()])) gave suspicious results.
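For reference, here is roughly the check I mean, wrapped in a small helper (check_weights and the threshold are mine, not part of the notebook; learn is a fastai Learner):

import math

def check_weights(model, thresh=1e3):
    "Print any parameter whose max abs value is NaN or suspiciously large (the threshold is arbitrary)."
    for n, p in model.named_parameters():
        m = p.abs().max().item()
        if math.isnan(m) or m > thresh:
            print(f'suspicious parameter {n}: max abs = {m}')

# e.g. after each minibatch:
# check_weights(learn.model)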

That's exciting! I think this is the first successful training of rbn with skipping! 🙂


This should replace nn.BatchNorm2d after it is done.
How would we adapt this to work in place of BatchNorm1d/3d?

If you change the shapes to be 3d or 5d, and take the mean and var over dimensions (0, 2) or (0, 2, 3, 4) respectively, it'll do the right thing. Fun fact: BatchNorm in PyTorch reduces to the “BatchNorm1d” case internally.
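To make that concrete, here's a sketch of what a 1d version could look like for (bs, nf, length) inputs - just the class from above with the parameter/buffer shapes and the reduction dims changed, and with the skipping logic left out for brevity. Untested, so treat it as a starting point:

import math
import torch
from torch import nn

class RunningBatchNorm1d(nn.Module):
    "Sketch: same running-stats idea, but reducing over dims (0, 2) for (bs, nf, length) inputs."
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom,self.eps = mom,eps
        self.mults = nn.Parameter(torch.ones(nf,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1))
        self.register_buffer('sums',  torch.zeros(1,nf,1))
        self.register_buffer('sqrs',  torch.zeros(1,nf,1))
        self.register_buffer('count', torch.tensor(0.))
        self.batch = 0

    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.batch += bs
        self.sums.detach_()
        self.sqrs.detach_()
        dims = (0,2)                              # the main change vs. the 2d version
        s  = x    .sum(dims, keepdim=True)
        ss = (x*x).sum(dims, keepdim=True)
        c  = self.count.new_tensor(x.numel()/nc)
        mom1 = s.new_tensor(1 - (1-self.mom)/math.sqrt(bs-1 or 1))
        self.sums.lerp_ (s,  mom1)
        self.sqrs.lerp_ (ss, mom1)
        self.count.lerp_(c,  mom1)

    def forward(self, x):
        if self.training: self.update_stats(x)
        means = self.sums/self.count
        varns = (self.sqrs/self.count).sub_(means.pow(2))
        if self.batch < 20: varns.clamp_min_(0.01)
        x = (x-means).div_((varns.add_(self.eps)).sqrt())
        return x.mul_(self.mults).add_(self.adds)

The 3d case would follow the same pattern with (nf,1,1,1) / (1,nf,1,1,1) shapes and dims = (0,2,3,4).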


Hi, I ran the code from the beginning of lesson 11… However, an error is raised. The only thing that I changed is:

tensor(0.) to torch.tensor(0.)

The code:

import math
import torch
from torch import nn

class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom,self.eps = mom,eps
        self.mults = nn.Parameter(torch.ones(nf,1,1))
        self.adds = nn.Parameter(torch.zeros(nf,1,1))
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('count', torch.tensor(0.))
        self.register_buffer('factor', torch.tensor(0.))
        self.register_buffer('offset', torch.tensor(0.))
        self.batch = 0

    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.sums.detach_()
        self.sqrs.detach_()
        dims = (0,2,3)
        s = x.sum(dims, keepdim=True)
        ss = (x*x).sum(dims, keepdim=True)
        c = s.new_tensor(x.numel()/nc)
        mom1 = s.new_tensor(1 - (1-self.mom)/math.sqrt(bs-1))

        self.sums.lerp_(s, mom1)
        self.sqrs.lerp_(ss,mom1)
        self.count.lerp_(c, mom1)
        self.batch += bs
        means = self.sums/self.count
        varns = (self.sqrs/self.count).sub_(means*means)
        if bool(self.batch<20):
            varns.clamp_min_(0.01)
        self.factor = self.mults/(varns+self.eps).sqrt()
        self.offset = self.adds-means*self.factor

    def forward(self, x):
        if self.training:
            self.update_stats(x)
        return x*self.factor+self.offset

The error message is something like:

one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor []] is at version 6; expected version 5 instead…
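As far as I understand, that message in general means autograd saved a tensor for the backward pass and that tensor was then modified in place before backward() ran. A tiny standalone example that triggers the same error (nothing to do with the class, just to illustrate):

import torch

w = torch.randn(3, requires_grad=True)
a = torch.randn(3)
out = (w * a).sum()   # autograd saves `a` to compute the gradient w.r.t. `w`
a.add_(1.)            # in-place change bumps a's "version" after it was saved
out.backward()        # RuntimeError: ... modified by an inplace operation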

I suspect this error happens because of this operation:

self.count.lerp_(c, mom1)

But I don't really have the full picture of this, so please check and help…

Thanks in advance and have a great day!!!