Involving which variables specifically? Or all of them?
So RunningBatchNorm relies on calculating weighted averages and then letting the network further tweak them via backprop, i.e. it works this way by design?
And most likely, if you remove these lines from the beginning of update_stats:

```python
self.sums.detach_()
self.sqrs.detach_()
```

the outcome will be exactly the same, since, as I mentioned in the long reply, the buffers get re-attached as soon as you update them (because x is in the graph), so the two lines above don't contribute anything.
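Here's a minimal standalone sketch of that re-attachment (plain tensors, not the actual RunningBatchNorm code): an in-place lerp_ that mixes in a graph-connected tensor puts a detached buffer right back on the graph:

```python
import torch

buf = torch.zeros(3)                     # stands in for self.sums / self.sqrs
print(buf.grad_fn)                       # None: detached, not on the graph

x = torch.randn(3, requires_grad=True)   # stands in for the layer input
buf.lerp_(x, 0.1)                        # the in-place running-average update
print(buf.grad_fn)                       # <LerpBackward0 ...>: re-attached
```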
But it’s certainly true that I don’t fully understand the details of all this, and I’m sure there are things that need to be fixed to make this work with the “occasional skipping”.
I think once you sort out which vars should be learnable and which shouldn't be, you could force the optional calculations to not involve anything on the graph, and then there will be no problem skipping them at will. Does that make sense?
I just don’t know which of them you meant to be learnable, so I can’t code for it until that is known.
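But to illustrate the idea, here's a rough sketch of what I mean by keeping an optional calculation off the graph. The helper name and shapes are made up, not from the notebook, and it only applies if sums isn't meant to be learnable:

```python
import torch

def update_sums_off_graph(sums, x, mom1):
    # hypothetical helper: do the running-sum update under no_grad
    # so the buffer never acquires a grad_fn, which makes it safe
    # to skip this call on any batch
    with torch.no_grad():
        s = x.sum((0, 2, 3), keepdim=True)
        sums.lerp_(s, mom1)

x = torch.randn(2, 4, 8, 8, requires_grad=True)   # x is on the graph
sums = torch.zeros(1, 4, 1, 1)
update_sums_off_graph(sums, x, mom1=0.1)
assert sums.grad_fn is None   # stayed off the graph despite using x
```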
For example, self.count is not on the graph, so you can verify that its calculation can be skipped at will:
```python
mom1 = self.s.new_tensor(1 - (1-self.mom)/math.sqrt(self.accumulated_bs-1))
if self.bc % 2:
    self.count.lerp_(self.c, mom1)
```
gives no error.
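And if you want to check for yourself which buffers are on the graph: a tensor participates in autograd iff it carries a grad_fn. Here's a throwaway sketch with made-up tensors:

```python
import torch

def on_graph(t):
    # a tensor is part of the autograd graph iff it has a grad_fn
    return t.grad_fn is not None

x = torch.randn(4, requires_grad=True)
count = torch.tensor(0.).lerp_(torch.tensor(2.), 0.1)  # never touches x
sums  = torch.zeros(4).lerp_(x, 0.1)                   # mixes in x

print(on_graph(count))  # False -> safe to update only on some batches
print(on_graph(sums))   # True  -> skipping its update changes the graph
```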