Do you mean skipping some stats re-calculations? Or are you saying that replacing:

```python
self.sums.detach_()
self.sqrs.detach_()
```

with:

```python
x = x.detach()
```

and not skipping any calculations has a detrimental impact on the outcome?
However it doesn’t actually skip any computation at the moment, since you’re checking `self.batch%2`, which is always true, since bs is even. You should instead create and check an iteration counter. If you do that, you’ll find you still get the dreaded “Trying to backward through the graph a second time, but the buffers have already been freed” error!
Ah, yes, good catch - thank you, Jeremy! Let’s add a `self.bc` counter:

```python
self.bc = 0  # batch counter

def update_stats(self, x):
    self.bc += 1
```
So, we need a good understanding of what should be in the graph and what shouldn’t.
With:

```python
x = x.detach()
```

none of the temp results are part of the graph. Only `factor` and `offset` are, since their calculation involves the `mults` and `adds` params. You can check:
```python
def forward(self, x):
    if self.training: self.update_stats(x)
    if self.bc < 2:
        l = "factor offset mults adds sums sqrs count means varns s ss c".split()
        print("not leaf ", list(filter(lambda x: not getattr(self,x).is_leaf, l)))
        print("want grad", list(filter(lambda x: getattr(self,x).requires_grad, l)))
    return x*self.factor + self.offset
```
```text
not leaf  ['factor', 'offset']
want grad ['factor', 'offset', 'mults', 'adds']
```
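The same behavior can be checked in isolation with a minimal standalone sketch (the variable names here are illustrative, not the module’s actual attributes):

```python
import torch

# Illustrative sketch: a learnable parameter, a detached input, and what
# ends up requiring grad.
mults = torch.ones(3, requires_grad=True)  # learnable, a leaf
x = torch.randn(3)
xd = x.detach()                # detached input
sums = xd.sum()                # a "stat" computed from the detached input
factor = mults * 2             # involves a learnable param

print(sums.requires_grad)      # False: the stat stays off the graph
print(factor.requires_grad)    # True: it depends on `mults`
print(factor.is_leaf)          # False: it is a computed node, not a leaf
```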
I guess the refactoring to reduce broadcasting added this complication of having `factor` and `offset` be part of the graph. Perhaps this is wrong?
If you detach `sums` and `sqrs` instead of `x`, you end up with:

```text
not leaf  ['factor', 'offset', 'sums', 'sqrs', 'means', 'varns', 's', 'ss']
want grad ['factor', 'offset', 'mults', 'adds', 'sums', 'sqrs', 'means', 'varns', 's', 'ss']
```
So now all those running temps will no longer do the right thing, since they will get changed by backprop and we want them to be fixed.
Moreover, you are detaching them in the wrong place. You detach them at the beginning of `update_stats`, but then you perform calculations on them that involve the undetached `x`, and they end up in the graph again! So if you don’t detach `x`, you want to detach them after all calculations are done. But as I have shown above this is not right either, since a whole bunch of other temps are now in the graph and will be “adjusted” by the net.
Now going back to the very original implementation as it was presented in the class (with `dbias`), we get:

```text
not leaf  ['sums', 'sqrs']
want grad ['mults', 'adds', 'sums', 'sqrs']
```
So it wasn’t detaching them either!
Only after you move them to the end of `update_stats`:

```python
self.sums.detach_()
self.sqrs.detach_()
```

do you get:

```text
not leaf  []
want grad ['mults', 'adds']
```
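The effect of the in-place `detach_()` can be seen in isolation (illustrative names again): it turns a computed tensor back into a leaf that requires no grad, matching the empty `not leaf` printout above.

```python
import torch

# Sketch: detach_() strips a computed tensor out of the graph in place.
mults = torch.ones(3, requires_grad=True)
sums = (torch.randn(3) * mults).sum()      # computed with a learnable param
print(sums.is_leaf, sums.requires_grad)    # False True: it is on the graph
sums.detach_()                             # in-place detach
print(sums.is_leaf, sums.requires_grad)    # True False: off the graph
```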
I’m still trying to wrap my head around this `detach` thing, so please bear with me if I’m saying something incorrect. If what I described above is correct, then you were getting good results not because of the better BN (or at least not just because of it), but because your temps were actually backpropagated. The stats weren’t calculated on fixed running averages, but on running averages that were themselves learnable variables; i.e. the network was messing (in a good way) with numbers that we intended to be fixed. Does this make sense?
And this, in a roundabout way, answers why you get the error: you tried to skip calculations on variables that are part of the graph, and that’s why the backward pass fails.
If you detach all of those other temp vars, you won’t get the error, i.e. finish `update_stats` with:

```python
l = "sums sqrs count means varns s ss c".split()
for a in l: getattr(self,a).detach_()
```
The originally proposed:

```python
x = x.detach()
```

at the very beginning of `update_stats` does the same thing, but more efficiently, since the code doesn’t need to swap the temp vars back and forth between requiring grads and not.
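A sketch of that equivalence (illustrative names, not the module’s code): detaching the input up front means the computed stats never require grad in the first place, so nothing needs to be detached afterwards.

```python
import torch

# Sketch: with x detached on entry, every stat derived from it is born
# off the graph, even when the caller's x involves learnable params.
mults = torch.ones(3, requires_grad=True)

def stats_detach_x(x):
    x = x.detach()             # detach once, up front
    sums = x.sum()
    sqrs = (x * x).sum()
    return sums, sqrs

sums, sqrs = stats_detach_x(torch.randn(3) * mults)
print(sums.requires_grad, sqrs.requires_grad)  # False False
```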
To conclude: decide which variables are to be fixed and controlled only by you, and which are learnable, and then use `detach` accordingly. Perhaps a significant part of the magic of `RunningBatchNorm` is a side-effect of a coding mistake.