Do you mean with skipping some stats re-calculations? Or are you saying that adding:

```python
x = x.detach()
```

and not skipping any calculations has a detrimental impact on the outcome?
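As a quick sanity check of what `detach` does (a minimal standalone sketch, not the notebook's code): a detached tensor, and anything computed from it, stays off the autograd graph.

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x.detach()                 # shares data with x, but cut off from the graph
print(y.requires_grad)         # False
print((y * 2).requires_grad)   # False: ops on a detached tensor stay off the graph
```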
> However it doesn’t actually skip any computation at the moment, since you’re checking `self.batch%2`, which is always true, since `bs` is even. You should instead create and check an iteration counter - if you do that, you’ll find you still have the dreaded “Trying to backward through the graph a second time, but the buffers have already been freed” error!
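For reference, that error is easy to reproduce in isolation (a toy sketch, unrelated to the notebook's variables): calling `backward()` a second time over the same graph, without `retain_graph=True`, fails because the first pass freed the graph's buffers.

```python
import torch

w = torch.randn((), requires_grad=True)  # a scalar parameter
y = w * 2
y.backward()                 # first backward frees the graph's buffers
try:
    y.backward()             # second backward over the same graph
except RuntimeError as e:
    print(type(e).__name__)  # RuntimeError
```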
Ah, yes, good catch - thank you, Jeremy! Let’s add:

```python
self.bc = 0  # batch counter

def update_stats(self, x):
    self.bc += 1
```
So, we need to get a good understanding of what is to be in the graph and what is not.
With

```python
x = x.detach()
```

all temp results are not part of the graph. Only `factor` and `offset` are, since their calculation includes the `mults` and `adds` params. You can check:
```python
def forward(self, x):
    if self.training: self.update_stats(x)
    if self.bc < 2:
        l = "factor offset mults adds sums sqrs count means varns s ss c".split()
        print("not leaf ", list(filter(lambda x: not getattr(self,x).is_leaf, l)))
        print("want grad", list(filter(lambda x: getattr(self,x).requires_grad, l)))
    return x*self.factor + self.offset
```
```
not leaf  ['factor', 'offset']
want grad ['factor', 'offset', 'mults', 'adds']
```
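The check above relies on the `is_leaf` / `requires_grad` distinction; a generic illustration (not the notebook's code): user-created tensors are leaves, while the result of an op on a grad-requiring tensor is a non-leaf that itself requires grad.

```python
import torch

a = torch.randn(2, requires_grad=True)  # created by the user: a leaf
b = a * 3                               # created by an op on the graph: not a leaf
print(a.is_leaf, a.requires_grad)       # True True
print(b.is_leaf, b.requires_grad)       # False True
```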
I guess the refactoring to reduce broadcasting added this complication of having `factor` and `offset` be part of the graph. Perhaps this is wrong?
If you detach `sums` and `sqrs` instead of `x`, you end up with:

```
not leaf  ['factor', 'offset', 'sums', 'sqrs', 'means', 'varns', 's', 'ss']
want grad ['factor', 'offset', 'mults', 'adds', 'sums', 'sqrs', 'means', 'varns', 's', 'ss']
```
So now all those running temps will no longer do the right thing, since they will get changed by backprop and we want them to be fixed.
Moreover, you are detaching them in the wrong place. You detach them at the beginning of `update_stats`, but then you make a calculation on them which involves the undetached `x`, and they end up on the graph again! So if you don’t detach `x`, you want to detach them after all calculations are done. But as I have shown above, this is not right either, since a whole bunch of other temps are then on the graph and will be “adjusted” by the net.
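The re-attachment is easy to see in a toy sketch (hypothetical names, not the notebook's code): detaching a buffer up front and then updating it with an expression involving an undetached tensor puts the result right back on the graph.

```python
import torch

x = torch.randn(4, requires_grad=True)
sums = torch.zeros(1).detach()   # detached up front...
sums = sums + x.sum()            # ...but this op involves the undetached x
print(sums.requires_grad)        # True: sums is back on the graph
```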
Now, going back to the very original implementation as it was presented in the class (with `dbias`), we get:

```
not leaf  ['sums', 'sqrs']
want grad ['mults', 'adds', 'sums', 'sqrs']
```
So it wasn’t detaching them either!
Only after you move the detaching to the end of `update_stats` do you get:

```
not leaf  []
want grad ['mults', 'adds']
```
I’m still trying to wrap my head around this `detach` thing, so please bear with me if I’m saying something incorrect. If what I described above is correct, then you were getting good results not because of the better BN (or at least not just because of it), but because your temps were actually backpropagated: the stats weren’t calculated on fixed running averages, but on running averages that were themselves learnable variables - i.e. the network was messing (in a good way) with numbers that we intended to be fixed. Does this make sense?
And this, in a roundabout way, answers why you get the error: you tried to skip calculations on variables that are on the graph, and that’s why the error appears.
If you detach all of those other temp vars, you won’t get the error, i.e. finish `update_stats` with:

```python
l = "sums sqrs count means varns s ss c".split()
for a in l: getattr(self,a).detach_()
```
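For clarity, `detach_()` works in place (a generic sketch, not tied to the notebook's module): it turns a non-leaf tensor into a leaf that no longer requires grad.

```python
import torch

w = torch.randn(2, requires_grad=True)
t = w * 2
print(t.requires_grad, t.is_leaf)   # True False
t.detach_()                         # in-place: cut t out of the graph
print(t.requires_grad, t.is_leaf)   # False True
```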
The originally proposed:

```python
x = x.detach()
```

at the very beginning of `update_stats` does the same thing, but more efficiently, since the code doesn’t need to swap temp vars back and forth between requiring grads and not.
To conclude: decide which variables are to be fixed and controlled only by you, and which are learnable, and then use `detach` accordingly. Perhaps a significant part of the magic of `RunningBatchNorm` is a side-effect of a coding mistake.