I think to_fp16 behaves in a weird way internally

Not sure if it’s on purpose, but it has caused me major headaches.

The way to_fp16 is implemented, I think you actually get the following op graph:

  • A fp32 model (leaf nodes)
  • An fp16 copy of the model. Gradients flow back to the fp32 model from this.
  • An fp32 copy of the fp16 batchnorm layers in the model. But of course, even if the calculations on the batch are done in fp16, the gradients still flow to the fp16 layers and then to the original fp32 layers.
  • The optimizer only has access to the original fp32 parameters. Since the gradient flows through the graph, it kind of does the correct thing, as they are the leaf nodes.

All three copies of the batchnorm layers are on the GPU!

If you don’t believe that this is what happens, try this:

# assuming fastai v1 imports and an existing DataBunch called data
model = nn.Sequential(...)  # any model will do here
learn = Learner(data, model).to_fp16()

learn.opt  # not found (AttributeError): it's only created later, at fit time

learn.fit(1)  # this creates learn.opt

learn.opt.param_groups[0]['params'][0].dtype
# returns torch.float32, and so does every other one (i.e. change the 0's to other indices)

I’ve been working on something for the past three days and I finally found the bug: my original fp32 batchnorm layers were falling out of sync with the OTHER fp32 batchnorm layers, since I was modifying them directly.

Now that I think about it, @sgugger, if I’m not mistaken, (true) weight decay is not currently working correctly when training fp16 models…

In fastai, true_wd just multiplies the original parameters (leaf nodes) by (1-wd*lr). Check out OptimWrapper’s step function:

def step(self)->None:
    "Set weight decay and step optimizer."
    # weight decay outside of optimizer step (AdamW)
    if self.true_wd:
        for lr,wd,pg1,pg2 in zip(self._lr,self._wd,self.opt.param_groups[::2],self.opt.param_groups[1::2]):
            for p in pg1['params']: p.data.mul_(1 - wd*lr)
            if self.bn_wd:
                for p in pg2['params']: p.data.mul_(1 - wd*lr)
        self.set_val('weight_decay', listify(0, self._wd))
    self.opt.step()

Remember, self.opt.param_groups points to the original fp32 leaf nodes. And so if you simply modify their data, I think they will slowly fall out of sync with the fp16 model.

I’m not very familiar with mixed-precision training, so I could be off here, but a few comments:

an fp32 copy of the fp16 batchnorm layers in the model

There aren’t fp16 batchnorm layers in the model from what I see; batchnorm layers are left as fp32. Well, actually they are converted to fp16 and then converted back: fastai’s model2half(mdl) calls PyTorch’s mdl.half(), then goes through and calls float() on all the batchnorm layers. Are you perhaps observing the effects of this when the model is created? After that it doesn’t look to be doing anything special with batchnorm layers; they are fp32 the whole time. So:

>>> from torch import nn
>>> from fastai.torch_core import model2half  # fastai v1
>>> mdl = nn.Sequential(nn.Linear(3,4), nn.BatchNorm2d(4))
>>> bn_w = mdl[1].weight
>>> bn_d = bn_w.data
>>> mdl[0].weight.dtype, bn_w.dtype, bn_d
(torch.float32, torch.float32, tensor([1., 1., 1., 1.]))

>>> mh = model2half(mdl) # Called by to_fp16 to convert model
>>> bn_w2 = mh[1].weight
>>> bn_d2 = bn_w2.data
>>> mh[0].weight.dtype, bn_w2.dtype, bn_d2
(torch.float16, torch.float32, tensor([1., 1., 1., 1.]))
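
Roughly, the logic is just this (my paraphrase, not the exact fastai source):

def bn2float(module):
    "Recursively put any batchnorm layers (weights and running stats) back in fp32."
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): module.float()
    for child in module.children(): bn2float(child)
    return module

def model2half(model):
    "Convert the whole model to fp16, then restore the batchnorm layers to fp32."
    return bn2float(model.half())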

I also note that calling to_fp16 will replace the weight.data tensors, so any references taken to the data before this will be invalid. The Parameter (weight) remains unchanged, so you could reference that. This is how PyTorch’s Module.to() does it, not a fastai-specific thing. So:

>>> bn_w is bn_w2, bn_d is bn_d2 # Only Parameter is the same, not data:
(True, False)
>>> bn_d += 1
>>> bn_d, bn_d2
(tensor([2., 2., 2., 2.]), tensor([1., 1., 1., 1.]))

So if you are updating the batchnorm weights yourself, you’d need to be careful of this. I do wonder why it is done this way; it seems like you could fairly easily reverse the logic in model2half and just call half() on the non-batchnorm layers instead. Though it seems unlikely to matter: if your BN stats are so sensitive that the fp32->fp16->fp32 round trip reduces performance, you’re probably in trouble anyway. And if you are taking references to parameters, then to work with fastai you should probably do this in a callback, as there may well be other callbacks messing around with the model, so you shouldn’t assume it is unchanged after passing it to Learner.
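
By “reverse the logic” I mean something roughly like this (just a sketch of the idea, untested against fastai):

bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def model2half_alt(model):
    "Convert only the non-batchnorm leaf modules to fp16, so BN weights never pass through fp16 at all."
    for m in model.modules():
        if len(list(m.children())) == 0 and not isinstance(m, bn_types): m.half()
    return model
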
I also note fastai only does this special handling for batchnorm layers, not for other types of normalisation layers. Is there any reason they shouldn’t also get this treatment? Looking at how APEX does it, they seem to keep all normalisation layers in fp32.

The optimizer has only access to the original fp32 parameters

Not sure if you’d seen it, but while the optimiser only works on fp32, the MixedPrecision callback has both the fp32 master parameters (which the optimiser works on) and the fp16 model parameters, and it handles explicit updates between the two, copying from master to model in on_step_end.
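
The overall shape is the standard mixed-precision recipe, roughly like this (a generic sketch of the pattern, not the callback’s actual code; it needs a GPU for the fp16 ops):

import torch
from torch import nn

model = nn.Linear(4, 2).cuda().half()                 # fp16 model
model_params = list(model.parameters())
master_params = [p.detach().clone().float() for p in model_params]  # fp32 master copy
for mp in master_params: mp.requires_grad_(True)
opt = torch.optim.SGD(master_params, lr=1e-2)         # the optimiser only ever sees the fp32 copy
loss_scale = 512.0

x = torch.randn(8, 4).cuda().half()
loss = model(x).float().pow(2).mean()                 # dummy loss
(loss * loss_scale).backward()                        # scaled backward; grads land on the fp16 params
for mp, p in zip(master_params, model_params):
    mp.grad = p.grad.detach().float() / loss_scale    # copy + unscale grads: model -> master
opt.step()                                            # update the fp32 master weights
for mp, p in zip(master_params, model_params):
    p.data.copy_(mp.data)                             # master -> model, i.e. what on_step_end does
for p in model_params: p.grad = None                  # clear model grads for the next step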

Thank you for your reply. Your answer seems to be spot on, which means I have no idea what’s going on with my model or why it stops working when I explicitly modify the parameters :frowning: The training loss keeps improving normally, but the validation loss just explodes (even after one epoch).

It seems you are right. Torch has some special handling for modules.

If A is an fp32 leaf tensor, then B = A.half() creates a new fp16 tensor in the graph, and when you do backward on something C that uses B, the grad flows through C, then B, and then back to A.

But if A is a module of type BatchNorm2d, then A.half() does change A.weight to “half”, in place. Not so if A is of type nn.Parameter. Torch modules are magic then :slight_smile: A quick check of both cases is below.
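
I mean something like this minimal check (A, B and C are just placeholder names):

import torch
from torch import nn

# Case 1: .half() on a tensor/Parameter creates a new graph node; grads flow back to the fp32 leaf.
A = nn.Parameter(torch.randn(3))  # fp32 leaf
B = A.half()                      # new fp16 tensor in the graph; A itself stays fp32
C = (B.float() ** 2).sum()
C.backward()
print(A.dtype, A.grad.dtype)      # torch.float32 torch.float32

# Case 2: .half() on a module converts its parameters' data in place.
bn = nn.BatchNorm2d(4)
bn.half()
print(bn.weight.dtype)            # torch.float16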

Although “bn_w is bn_w2” being True is incredibly weird to me! I thought “is” means “pointing to the exact same thing”. If they are the same thing, how can bn_w.data be different from bn_w2.data??

However, if I modify bn_d, bn_d2 gets modified too! So they are not the same thing, but they get updated together?

Now, the mixed precision callback I hadn’t seen. Seems you are right about this too.

Okay, so a bunch more hours of debugging ahead… :frowning:

I’ll have to have a closer look later to understand the other things, but you’re misunderstanding the point of those is tests. bn_w is bn_w2, so bn_w.data and bn_w2.data are the same underlying data; it just isn’t what bn_d is referencing anymore. When you call half() (as to_fp16 does), bn_w.data is changed to point to a new tensor. So if you take a reference to it before calling, that reference will not be to the weight data the batchnorm is using. If the reference is to the Parameter, not the .data tensor, you’re fine.
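
In other words, continuing the earlier bn_w / bn_d example, you can see it like this:

>>> bn_w is bn_w2                        # same Parameter object
True
>>> bn_w.data_ptr() == bn_w2.data_ptr()  # so they point at the same underlying storage
True
>>> bn_w.data_ptr() == bn_d.data_ptr()   # but bn_d still points at the old, pre-conversion tensor
False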

Aaah, that makes sense. Thank you.

On your broader issues, given that the MixedPrecision callback is overwriting weights, any updates you make outside of this will be overwritten at on_step_end. So you would likely want to update the optimiser’s copy of the weights, not the model’s copy. Or, probably better, write a custom optimiser (which can just wrap an existing one, of course) and pass that into your learner; then the MixedPrecision callback should handle keeping the model weights in sync with the copy your optimiser deals with.
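
A very rough sketch of the wrapping idea in plain PyTorch (the clamp is just a stand-in for whatever custom update you need, and wiring it into fastai’s opt_func would take a bit more care):

import torch
from torch import nn

class CustomStepOpt:
    "Wrap an existing optimiser so the custom update runs inside step(), on the same (fp32) params."
    def __init__(self, inner):
        self.inner, self.param_groups = inner, inner.param_groups

    def zero_grad(self): self.inner.zero_grad()

    def step(self):
        self.inner.step()
        with torch.no_grad():
            for group in self.param_groups:
                for p in group['params']: p.clamp_(-1.0, 1.0)  # stand-in for the custom modification

model = nn.Linear(3, 1)
opt = CustomStepOpt(torch.optim.SGD(model.parameters(), lr=0.1))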

Though this might depend a bit on what you’re doing. I think the core thing is whether your operation affects the final loss. If it does, then operations on the model weights should be part of the computational graph and so included in the backward pass, which is used to derive the gradient w.r.t. each model parameter, which will then be used to update the fp32 optimiser copy of the parameter, which will then overwrite the model copy.
But then I could be off here; I still don’t wholly understand the whole backward process conceptually, even though I get a lot of the parts.
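
For example, if the modification is supposed to affect the loss, it would live in the forward pass so it gets recorded in the graph (a toy illustration, nothing fastai-specific):

import torch
from torch import nn
import torch.nn.functional as F

class ScaledLinear(nn.Module):
    "Toy layer: the op on the weight happens in forward, so it is part of the graph and the backward pass."
    def __init__(self, in_f, out_f):
        super().__init__()
        self.lin = nn.Linear(in_f, out_f)
    def forward(self, x):
        w = self.lin.weight * 0.5            # differentiable op on the weight, recorded in the graph
        return F.linear(x, w, self.lin.bias)

m = ScaledLinear(4, 2)
loss = m(torch.randn(8, 4)).pow(2).mean()
loss.backward()                              # gradients account for the *0.5 applied in forward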