My thinking: perhaps let's not worry about the optimizations first, since they complicate things. Let's figure out how to skip the stat computation without affecting training quality - and once that works, optimizing should be easy.
As of now, even if I go back to your class's version minus dbias, it won't train well with even a tiny amount of skipped recalculations.
Here is pretty much that version with a few minor tweaks (making `batch` a plain attribute rather than a variable, and adding `iter`); I also added `bs-1 or 1` to avoid division by zero when `bs=1`.
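(In case the `bs-1 or 1` idiom is unclear: `or` binds looser than `-`, so it evaluates as `(bs - 1) or 1`, falling back to 1 only when `bs == 1`:)

```python
bs = 1
print(bs - 1 or 1)  # (1-1) or 1 -> 0 or 1 -> 1, so we never divide by sqrt(0)
bs = 64
print(bs - 1 or 1)  # 63
```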
```python
import math
import torch
from torch import nn, tensor

class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom,self.eps = mom,eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        # running accumulators: per-channel sum, sum of squares, and count
        self.register_buffer('sums',  torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs',  torch.zeros(1,nf,1,1))
        self.register_buffer('count', tensor(0.))
        self.batch = 0  # total samples seen (plain attribute, not a buffer)
        self.iter  = 0  # number of training steps seen

    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.batch += bs
        self.iter  += 1
        # detach so the accumulators don't keep growing the autograd graph
        self.sums.detach_()
        self.sqrs.detach_()
        # after 10k samples, skip every 10th stats update
        if self.batch > 10000 and not self.iter%10: return
        dims = (0,2,3)
        s  = x    .sum(dims, keepdim=True)
        ss = (x*x).sum(dims, keepdim=True)
        c = self.count.new_tensor(x.numel()/nc)  # elements per channel
        # dampen momentum for small batches; `bs-1 or 1` avoids dividing by sqrt(0) at bs=1
        mom1 = 1 - (1-self.mom)/math.sqrt(bs-1 or 1)
        self.mom1 = s.new_tensor(mom1)
        self.sums.lerp_ (s,  self.mom1)
        self.sqrs.lerp_ (ss, self.mom1)
        self.count.lerp_(c,  self.mom1)

    def forward(self, x):
        if self.training: self.update_stats(x)
        means = self.sums/self.count
        vars  = (self.sqrs/self.count).sub_(means.pow(2))
        # early on the variance estimate is noisy, so clamp it away from zero
        if self.batch < 20: vars.clamp_min_(0.01)
        x = (x-means).div_((vars.add_(self.eps)).sqrt())
        return x.mul_(self.mults).add_(self.adds)
```
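Here's a quick smoke test (my own, not from the original notebook) showing the layer in use - after the first batch the running stats reduce to the exact batch stats, so the output should already come out roughly standardized:

```python
torch.manual_seed(0)
rbn = RunningBatchNorm(nf=8)
x = torch.randn(32, 8, 4, 4) * 3 + 5    # deliberately off-center, high-variance input
rbn.train()
y = rbn(x)
print(y.mean().item(), y.std().item())  # roughly 0 and 1
rbn.eval()                              # eval mode reuses the running stats
y_eval = rbn(x)                         # no stats update happens here
```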
Feel free to tweak `if self.batch > 10000 and not self.iter%10: return`
- the code doesn't do well even when skipping 1 out of 10 calculations after 10,000 samples. This class's version simply skips accumulating `sums` and `sqrs` on those iterations.
And earlier we made attempts that always accumulated `sums` and `sqrs`, and those didn't fare any better.
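If you do want to experiment with the skip condition, here's a purely illustrative variant (the thresholds and the `should_skip` helper are made up, not something we tested) that ramps the skipping up gradually instead of switching it on all at once:

```python
def should_skip(self):
    # Hypothetical schedule (made-up thresholds): never skip for the first
    # 10k samples, then skip 1-in-10 updates, then 1-in-2 past 50k samples.
    if self.batch <= 10000: return False
    if self.batch <= 50000: return self.iter % 10 == 0
    return self.iter % 2 == 0
```

This would replace the inline `if self.batch > 10000 and not self.iter%10: return` check in `update_stats` with `if self.should_skip(): return`.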