We should use the latest RunningBatchNorm for this, but until that version is working well and fast, here is one for demonstration purposes:
class RunningBatchNorm(nn.Module):
    """Batch norm that normalizes with debiased *running* statistics.

    Instead of storing a running mean/variance directly, it keeps running
    sums and sums-of-squares (plus a running element count) and derives the
    mean/variance from those at forward time. Early in training the running
    buffers are divided by a debiasing factor (``dbias``), exactly like Adam's
    bias correction, so the first batches are not dragged toward the zero
    initialization.

    Args:
        nf:  number of feature channels (input is assumed NCHW; the stats are
             reduced over dims (0, 2, 3) — TODO confirm 4-D input with callers).
        mom: base momentum for the running statistics.
        eps: small constant added to the variance for numerical stability.
    """

    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.nf = nf
        self.mom, self.eps = mom, eps
        # Learnable affine parameters, broadcast over (N, C, H, W).
        self.mults = nn.Parameter(torch.ones(nf, 1, 1))
        self.adds = nn.Parameter(torch.zeros(nf, 1, 1))
        # Running sums / sums-of-squares / element count, plus bookkeeping.
        self.register_buffer('sums', torch.zeros(1, nf, 1, 1))
        self.register_buffer('sqrs', torch.zeros(1, nf, 1, 1))
        self.register_buffer('batch', torch.tensor(0.))  # total samples seen
        self.register_buffer('count', torch.tensor(0.))  # running elems/channel
        self.register_buffer('step', torch.tensor(0.))   # update counter
        self.register_buffer('dbias', torch.tensor(0.))  # debiasing factor

    def update_stats(self, x):
        """Fold the current batch into the running statistics (training only)."""
        bs, nc, *_ = x.shape
        # Detach so the lerp_ updates below don't accumulate autograd history.
        self.sums.detach_()
        self.sqrs.detach_()
        dims = (0, 2, 3)
        s = x.sum(dims, keepdim=True)
        ss = (x * x).sum(dims, keepdim=True)
        c = self.count.new_tensor(x.numel() / nc)
        # Effective momentum grows with batch size; max() guards bs == 1,
        # where the original divided by sqrt(0) and crashed.
        mom1 = 1 - (1 - self.mom) / math.sqrt(max(bs - 1, 1))
        self.mom1 = self.dbias.new_tensor(mom1)
        self.sums.lerp_(s, self.mom1)
        self.sqrs.lerp_(ss, self.mom1)
        self.count.lerp_(c, self.mom1)
        self.dbias = self.dbias * (1 - self.mom1) + self.mom1
        self.batch += bs
        self.step += 1

    def forward(self, x):
        if self.training:
            self.update_stats(x)
        sums, sqrs, c = self.sums, self.sqrs, self.count
        # Debias the running buffers for the first 100 steps (they start at
        # zero, so the raw lerp'd values underestimate the true stats).
        # NOTE(review): calling eval() before any training step leaves
        # dbias == 0 and this divides by zero, as in the original.
        if self.step < 100:
            sums = sums / self.dbias
            sqrs = sqrs / self.dbias
            c = c / self.dbias
        means = sums / c
        variances = (sqrs / c).sub_(means * means)
        # Very early on the variance estimate is unreliable; clamp it.
        if bool(self.batch < 20):
            variances.clamp_min_(0.01)
        x = (x - means).div_((variances.add_(self.eps)).sqrt())
        return x.mul_(self.mults).add_(self.adds)

    def extra_repr(self):
        return f'{self.nf}, mom={self.mom}, eps={self.eps}'
def find_modules(module, cond, container=None, index=None):
    """Recursively yield ``(parent, key)`` for every submodule matching ``cond``.

    ``parent._modules[key]`` is the matching module, so the caller can replace
    it in place. The root module itself is tested too; if it matches, the pair
    is ``(None, None)``.

    Args:
        module:    root ``nn.Module`` to search.
        cond:      predicate taking a module, returning truthy on a match.
        container: parent of ``module`` (used in the recursion; leave default).
        index:     key of ``module`` inside ``container`` (recursion only).
    """
    if cond(module):
        yield container, index
    # named_children() provides the _modules key directly. The original paired
    # enumerate(module.children()) with list(module._modules.keys())[i], which
    # can misalign when _modules holds None entries or duplicate modules
    # (children() skips those, the raw key list does not).
    for name, child in module.named_children():
        yield from find_modules(child, cond, module, name)
def runningBatchNormify(model):
    """Replace every ``nn.BatchNorm2d`` in ``model`` (in place) with a
    ``RunningBatchNorm`` of the same number of features.

    The model is mutated; nothing is returned. Note: if ``model`` itself is a
    ``BatchNorm2d``, find_modules yields ``(None, None)`` and this raises — the
    root cannot be replaced in place (same limitation as the original).
    """
    def check_bn(o):
        # isinstance covers subclasses, equivalent to the original
        # issubclass(o.__class__, nn.BatchNorm2d) check.
        return isinstance(o, nn.BatchNorm2d)

    for parent, key in find_modules(model, check_bn):
        nf = parent._modules[key].num_features
        parent._modules[key] = RunningBatchNorm(nf)
You can use runningBatchNormify like this:
my_model = create_cnn_model(models.resnet50, pretrained=False)
runningBatchNormify(my_model)
It replaces every nn.BatchNorm2d in the model with the new RunningBatchNorm, which makes it easier to test.