I can’t figure out the purpose of this code in `RunningBatchNorm.forward` (07_batchnorm.ipynb):
```python
if self.step<100:
    sums = sums / self.dbias
    sqrs = sqrs / self.dbias
    c    = c    / self.dbias
means = sums/c
vars  = (sqrs/c).sub_(means*means)
```
`sums/c` and `sqrs/c` cancel out the first four lines in the snippet above (dividing `sums`, `sqrs` and `c` by the same `self.dbias` leaves their ratios unchanged), and the debiased values aren’t used again.
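A quick numeric sanity check of that cancellation (made-up values, not from the notebook):

```python
import torch

sums, sqrs, c = torch.tensor(12.), torch.tensor(30.), torch.tensor(6.)
dbias = torch.tensor(0.75)

means1, vars1 = sums/c, sqrs/c - (sums/c)**2
s, q, cc = sums/dbias, sqrs/dbias, c/dbias  # debias all three by the same factor
means2, vars2 = s/cc, q/cc - (s/cc)**2

# the common factor cancels in both ratios
assert torch.allclose(means1, means2) and torch.allclose(vars1, vars2)
```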
which also renders `self.dbias` useless, other than as the tensor that `mom1` is modeled after:

```python
self.mom1 = self.dbias.new_tensor(mom1)
```

But why do we need `self.mom1`? It’s a temporary calculation.
And `self.step` is no longer needed either.
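As far as I can tell, `new_tensor` there only serves to put `mom1` on the buffer’s device/dtype. A minimal sketch of what I understand it to be equivalent to (assuming a float buffer):

```python
import torch

buf  = torch.zeros(1)   # stands in for self.dbias
mom1 = 0.9

a = buf.new_tensor(mom1)                                    # what the notebook does
b = torch.tensor(mom1, dtype=buf.dtype, device=buf.device)  # same result
assert a == b and a.dtype == b.dtype and a.device == b.device
```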
Here is a cleaned-up version:

```python
import math
import torch
import torch.nn as nn
from torch import tensor

class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom,self.eps = mom,eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        self.register_buffer('sums',  torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs',  torch.zeros(1,nf,1,1))
        self.register_buffer('batch', tensor(0.))
        self.register_buffer('count', tensor(0.))

    def update_stats(self, x):
        bs,nc,*_ = x.shape
        # detach the buffers so running stats don't accumulate autograd history
        self.sums.detach_()
        self.sqrs.detach_()
        dims = (0,2,3)
        s  = x    .sum(dims, keepdim=True)
        ss = (x*x).sum(dims, keepdim=True)
        c  = self.count.new_tensor(x.numel()/nc)
        # batch-size-dependent momentum: larger batches move the stats faster
        mom1 = 1 - (1-self.mom)/math.sqrt(bs-1)
        self.sums .lerp_(s,  mom1)
        self.sqrs .lerp_(ss, mom1)
        self.count.lerp_(c,  mom1)
        self.batch += bs

    def forward(self, x):
        if self.training: self.update_stats(x)
        means = self.sums/self.count
        vars  = (self.sqrs/self.count).sub_(means*means)
        # variance estimates are unreliable for the first few batches
        if bool(self.batch < 20): vars.clamp_min_(0.01)
        x = (x-means).div_((vars.add_(self.eps)).sqrt())
        return x.mul_(self.mults).add_(self.adds)
```
I’m getting exactly the same accuracy results with this version, but that doesn’t mean it holds in the generic case.
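For reference, the kind of quick check I mean is along these lines (hypothetical shapes, just a smoke test, not the notebook’s training loop):

```python
import torch

torch.manual_seed(0)
rbn = RunningBatchNorm(8)

rbn.train()                          # stats get updated in training mode
y = rbn(torch.randn(4, 8, 14, 14))
print(y.shape, float(rbn.batch))     # torch.Size([4, 8, 14, 14]) 4.0

rbn.eval()                           # eval mode uses the running stats only
y = rbn(torch.randn(2, 8, 14, 14))
print(y.mean().item(), y.std().item())
```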
Unless the intention was to save the temp results in `forward`, and then this will be needed:

```python
if self.step<100:
    sums.div_(self.dbias)
    sqrs.div_(self.dbias)
    c   .div_(self.dbias)
means = sums/c
vars  = (sqrs/c).sub_(means*means)
```
Since after doing:

```python
sums = self.sums
sums = sums / self.dbias
```

`sums` is no longer an alias to `self.sums`.
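One way to see exactly when the alias breaks is to compare storage pointers; a tiny sketch:

```python
import torch

buf  = torch.ones(3)   # stands in for self.sums
sums = buf             # plain assignment: still an alias
print(sums.data_ptr() == buf.data_ptr())        # True

sums = sums / 2        # rebinds sums to a new tensor; buf is untouched
print(sums.data_ptr() == buf.data_ptr(), buf)   # False tensor([1., 1., 1.])

sums = buf
sums.div_(2)           # in-place op: modifies buf through the alias
print(sums.data_ptr() == buf.data_ptr(), buf)   # True tensor([0.5000, 0.5000, 0.5000])
```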
Just had to clarify for myself when `a = b` stops being an alias in pytorch:
```python
import torch

def dump(a, b, note): print(f"{note}\na={a}\nb={b}")

a = torch.ones(5)
b = a                  # plain assignment: b aliases a
dump(a, b, "init")

b = b + 1              # rebinds b to a new tensor; a is untouched
dump(a, b, "+ new var")

b = a
b += 1                 # in-place add: modifies the shared tensor
dump(a, b, "self referring +")

b = a
b.add_(1)              # in-place method: also modifies the shared tensor
dump(a, b, "add_")
```
gives:

```
init
a=tensor([1., 1., 1., 1., 1.])
b=tensor([1., 1., 1., 1., 1.])
+ new var
a=tensor([1., 1., 1., 1., 1.])
b=tensor([2., 2., 2., 2., 2.])
self referring +
a=tensor([2., 2., 2., 2., 2.])
b=tensor([2., 2., 2., 2., 2.])
add_
a=tensor([3., 3., 3., 3., 3.])
b=tensor([3., 3., 3., 3., 3.])
```
So `b = b+1` is not affecting `a`.