@jeremy, looking at the latest incarnation of RunningBatchNorm, why are we recalculating everything for inference? Here is a refactored version:
#export
import math
import torch
from torch import nn, tensor

class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        # running stats, only touched during training
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('count', tensor(0.))
        # precomputed scale/shift so that inference is a single multiply-add
        self.register_buffer('factor', tensor(0.))
        self.register_buffer('offset', tensor(0.))
        self.batch = 0

    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.sums.detach_()
        self.sqrs.detach_()
        dims = (0,2,3)
        s    = x    .sum(dims, keepdim=True)
        ss   = (x*x).sum(dims, keepdim=True)
        c    = s.new_tensor(x.numel()/nc)
        # batch-size-adjusted momentum: bigger batches move the running stats more
        mom1 = s.new_tensor(1 - (1-self.mom)/math.sqrt(bs-1))
        self.sums .lerp_(s , mom1)
        self.sqrs .lerp_(ss, mom1)
        self.count.lerp_(c , mom1)
        self.batch += bs
        means = self.sums/self.count
        vars  = (self.sqrs/self.count).sub_(means*means)
        if bool(self.batch < 20): vars.clamp_min_(0.01)
        self.factor = self.mults / (vars+self.eps).sqrt()
        self.offset = self.adds - means*self.factor

    def forward(self, x):
        if self.training: self.update_stats(x)
        return x*self.factor + self.offset
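For reference, here's roughly how I'd expect it to be used - just a plain-PyTorch sanity check with made-up shapes, nothing fastai-specific:

rbn = RunningBatchNorm(nf=8)
x = torch.randn(16, 8, 4, 4)

rbn.train()
y = rbn(x)   # training: updates the running stats and recomputes factor/offset

rbn.eval()
y = rbn(x)   # inference: just x*factor + offset, nothing gets recalculated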
The only thing I can't figure out is how to get rid of the first 3 buffers (sums, sqrs, count) - they no longer need to be saved in the model and could be normal instance variables, but if I replace them with plain tensors I run into a device issue (CUDA vs. CPU), e.g. if I replace:
#self.register_buffer('sums', torch.zeros(1,nf,1,1))
self.sums = torch.zeros(1,nf,1,1)
I get:
---> 24 self.sums .lerp_(s , mom1)
25 self.sqrs .lerp_(ss, mom1)
26 self.count.lerp_(c , mom1)
RuntimeError: Expected tensor to have CPU Backend, but got tensor with CUDA Backend (while checking arguments for CPU_tensor_apply)
So I then have to do an explicit cuda() or to() when assigning a tensor to those vars, but I don't know how to do that so it works transparently regardless of the user's setup. It seems that register_buffer does the right thing, presumably because buffers get moved to the right device along with the parameters when you call model.cuda()/model.to().
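One idea that might work (just a sketch, I haven't fully thought it through) is to not create those tensors in __init__ at all and allocate them lazily from the first batch with x.new_zeros, so they inherit the input's device and dtype. A toy sketch of the pattern (LazyStats is a made-up name, not part of the notebook):

import torch
from torch import nn

class LazyStats(nn.Module):
    def __init__(self, nf, mom=0.1):
        super().__init__()
        self.nf, self.mom = nf, mom
        self.sums = None   # plain attribute: not saved in state_dict, not moved by .to()/.cuda()

    def forward(self, x):
        if self.sums is None:
            # first batch: allocate on whatever device/dtype x lives on
            self.sums = x.new_zeros(1, self.nf, 1, 1)
        self.sums.lerp_(x.sum((0,2,3), keepdim=True), self.mom)
        return x

(I've also noticed that recent PyTorch versions accept register_buffer(..., persistent=False), which keeps the device handling but leaves the tensor out of the state_dict - I haven't checked whether it's available in the version we're using.)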
p.s. RunningBatchNorm uses a variable named vars, which shadows the built-in function, so that's probably not a good idea.