I had an idea to try to solve the small-batch problem in BatchNorm: buffer up enough mini-batches, then calculate the stats and apply them once every N mini-batches. It didn't work too well. It requires a small lr, and it needs more GPU RAM to buffer up the inputs, which is a problem since people usually use a small bs precisely because they are short on GPU RAM in the first place. But I thought I'd share, in case someone has some creative ideas to improve upon my attempts:
```python
import torch
from torch import nn

# Based on the BatchNorm implementation from 07_batchnorm.ipynb.
# This version buffers mini-batches until enough samples are gathered to get a good variance estimate.
# With a high learning rate this doesn't work, since delaying normalization even by 2 passes
# often makes the activations explode or vanish.
class AccBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom,self.eps = mom,eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        self.register_buffer('vars',  torch.ones(1,nf,1,1))
        self.register_buffer('means', torch.zeros(1,nf,1,1))
        self.x = None      # buffered inputs (detached)
        self.bs_acc = 0    # number of samples buffered so far
        self.bs_goal = 0   # target effective batch size, set on the first training call

    def forward(self, x):
        if self.training:
            with torch.no_grad(): m,v,ok = self.update_stats(x)
            # still buffering: skip normalization, apply only the affine part
            if not ok: return x*self.mults + self.adds
        else: m,v = self.means,self.vars
        x = (x-m) / (v+self.eps).sqrt()
        return x*self.mults + self.adds

    def update_stats(self, x):
        bs,nc,*_ = x.shape
        if not self.bs_goal:
            proportion = 4  # at most 4 forward passes without a stats update (8 when bs=1)
            bs_min,bs_max = 8,256
            self.bs_goal = min(max(bs*proportion, bs_min), bs_max)
            print(f"got bs={bs}, use bs_goal={self.bs_goal}")
        if bs < self.bs_goal:
            # detach so the buffer doesn't keep earlier autograd graphs alive
            x = x.detach()
            self.x = x if self.x is None else torch.cat([self.x, x])
            self.bs_acc += bs
            if self.bs_acc < self.bs_goal: return None, None, False
            m = self.x.mean((0,2,3), keepdim=True)
            v = self.x.var ((0,2,3), keepdim=True)
            # reset buffers
            self.bs_acc = 0
            self.x = None
        else:
            m = x.mean((0,2,3), keepdim=True)
            v = x.var ((0,2,3), keepdim=True)
        self.means.lerp_(m, self.mom)
        self.vars.lerp_ (v, self.mom)
        return m, v, True
```
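One way to cut the extra GPU RAM mentioned above is to not store the inputs at all. For the mean and unbiased variance over the concatenated buffer, per-channel running sums of `x` and `x²` plus an element count are enough. This is a minimal sketch, assuming NCHW inputs and the same unbiased `var` that `AccBatchNorm` uses; the class name `StatsAccumulator` is hypothetical. It only addresses memory, not the instability caused by delayed normalization.

```python
import torch

class StatsAccumulator:
    """Hypothetical helper: per-channel running sums that give the mean and
    unbiased variance of all mini-batches added so far, without storing them."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.n = 0        # elements per channel seen so far (sum of bs*H*W)
        self.s = None     # per-channel sum of x, shape (1,C,1,1)
        self.ss = None    # per-channel sum of x**2, shape (1,C,1,1)
        self.dtype = None

    @torch.no_grad()
    def add(self, x):
        self.dtype = x.dtype
        x = x.detach().double()  # float64 limits cancellation in ss - n*m**2
        s  = x.sum((0,2,3), keepdim=True)
        ss = (x*x).sum((0,2,3), keepdim=True)
        self.s  = s  if self.s  is None else self.s  + s
        self.ss = ss if self.ss is None else self.ss + ss
        self.n += x.numel() // x.shape[1]

    def stats(self):
        m = self.s / self.n
        v = (self.ss - self.n * m * m) / (self.n - 1)  # unbiased, like torch.var's default
        return m.to(self.dtype), v.to(self.dtype)
```

Inside `update_stats`, `acc.add(x)` would take the place of the `torch.cat` call and `acc.stats()` the two `mean`/`var` lines, so the buffer costs O(channels) memory instead of growing with every buffered mini-batch.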
I tried a bunch of different workarounds, so if you are inspired to make suggestions, please try them in action first. Just put it in the 07_batchnorm notebook and run it, substituting `BatchNorm` with `AccBatchNorm` inside `conv_rbn`.
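If editing `conv_rbn` is inconvenient, a generic helper can do the substitution on an already-built model. A minimal sketch; `swap_modules` is a hypothetical name. The test swaps `nn.BatchNorm2d` for `nn.GroupNorm` only because both ship with PyTorch. For the notebook you would pass its `BatchNorm` class and something like `lambda bn: AccBatchNorm(bn.mults.shape[0])`, assuming that class keeps `mults` of shape `(nf,1,1)` as `AccBatchNorm` does.

```python
import torch
from torch import nn

def swap_modules(model, old_cls, make_new):
    """Recursively replace every child module of type `old_cls` inside `model`.

    `make_new` receives the old module and returns its replacement, so it can
    read e.g. the channel count from it. `model` itself is never replaced.
    """
    for name, child in model.named_children():
        if isinstance(child, old_cls):
            setattr(model, name, make_new(child))  # re-registers under the same name
        else:
            swap_modules(child, old_cls, make_new)
    return model
```

Do the swap before creating the optimizer, since the replacement modules bring new parameters.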
And if you don’t want to read through it, the TLDR version is:
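```python
def update_stats(self, x):
    if self.bs_acc + len(x) < self.bs_goal:   # not enough buffered yet
        self.x = x if self.x is None else torch.cat([self.x, x])
        self.bs_acc += len(x)
        return None, None, False              # return early: no normalization this pass
    else:
        ...                                   # as in the original BN: compute stats, lerp the running stats
```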
Another idea I had is to calculate the variance and keep buffering mini-batches until it's good enough (according to some threshold). But I think it'd have the same problem as this version: while data is being buffered and BN is delayed, the gradients tend to blow up or vanish within a few mini-batches of inaction. So a very low lr is required.
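The threshold idea could be sketched as follows. "Good enough" is read here as: the per-channel variance of the buffer changes by less than a relative `tol` when the latest mini-batch is added, with a hard cap of `max_batches`. That criterion, the class name `AdaptiveBuffer`, and the default values are all assumptions, untested in the notebook. Like `AccBatchNorm`, it still delays normalization while buffering, so the instability above would still apply.

```python
import torch

class AdaptiveBuffer:
    """Hypothetical sketch: buffer mini-batches until the per-channel variance
    estimate stops moving (relative change < tol), or max_batches is reached."""
    def __init__(self, tol=0.05, max_batches=8):
        self.tol, self.max_batches = tol, max_batches
        self.reset()

    def reset(self):
        self.buf, self.prev_v = [], None

    @torch.no_grad()
    def add(self, x):
        """Add a mini-batch; return (mean, var) once the estimate is stable, else None."""
        self.buf.append(x.detach())
        xs = torch.cat(self.buf)
        m = xs.mean((0,2,3), keepdim=True)
        v = xs.var((0,2,3), keepdim=True)
        # compare with the estimate from before this mini-batch was added
        stable = (self.prev_v is not None and
                  ((v - self.prev_v).abs() / (self.prev_v + 1e-5)).max() < self.tol)
        if stable or len(self.buf) >= self.max_batches:
            self.reset()
            return m, v
        self.prev_v = v
        return None
```

The first mini-batch never returns stats, since there is nothing to compare against; a smaller `tol` trades more buffering (and more delay) for a steadier estimate.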