Not sure if it’s on purpose, but it has caused me major headaches.
Because of the way to_fp16 is implemented, I think you actually end up with the following op graph:
- An fp32 model (the leaf nodes).
- An fp16 copy of the model. Gradients flow back from it to the fp32 model.
- An fp32 copy of the fp16 batchnorm layers in the model. But of course, even if the calculations on the batch are done in fp16, the gradients still flow to the fp16 layers and then to the original fp32 layers.
- The optimizer only has access to the original fp32 parameters. Since the gradient flows through the graph, it more or less does the correct thing, as they are the leaf nodes.

All three copies of the batchnorm layers are on the GPU!
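To make the picture concrete, here is a minimal toy sketch of that kind of graph in plain PyTorch (my own illustration of the pattern described above, not fastai's actual code): an fp16 working copy is a non-leaf node derived from an fp32 leaf, so the gradient ends up on the fp32 tensor.

```python
import torch

# Toy sketch (not fastai code): an fp16 copy derived from an fp32 leaf.
w32 = torch.randn(4, 4, requires_grad=True)   # fp32 leaf, what the optimizer would hold
w16 = w32.half()                              # fp16 copy, a non-leaf node in the graph

x = torch.randn(4, 4).half()
loss = (w16 * x).sum().float()
loss.backward()

print(w16.is_leaf, w32.is_leaf)   # False True -> only the fp32 tensor is a leaf
print(w32.grad.dtype)             # torch.float32 -> the gradient lands on the fp32 leaf
```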
If you don’t believe me that this is what happens, try this:
```python
model = nn.Sequential(whatever)
learn = Learner(data, model).to_fp16()
learn.opt  # not found: it's apparently created later
learn.fit(1)  # this creates opt
learn.opt.param_groups[0]['params'][0].dtype
# returns torch.float32, and so does every other parameter
# (i.e. replace the 0's with other indices)
```
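As a complementary check (my own addition; this assumes the default to_fp16 arguments, which keep batchnorm in fp32), you can also look at the dtypes inside the model itself:

```python
# Non-batchnorm weights should show torch.float16, batchnorm weights torch.float32.
for name, p in learn.model.named_parameters():
    print(name, p.dtype)
```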
I’ve been working on something for the past three days and I finally found the bug: my original fp32 batchnorm layers were falling out of sync with the OTHER fp32 batchnorm layers, since I was modifying them directly.
Now that I think about it, @sgugger, if I’m not mistaken, (true) weight decay is not currently working correctly when training fp16 models…
In fastai, true_wd just multiplies the original parameters (leaf nodes) by (1-wd*lr). Check out OptimWrapper’s step function:
```python
def step(self)->None:
    "Set weight decay and step optimizer."
    # weight decay outside of optimizer step (AdamW)
    if self.true_wd:
        for lr,wd,pg1,pg2 in zip(self._lr,self._wd,self.opt.param_groups[::2],self.opt.param_groups[1::2]):
            for p in pg1['params']: p.data.mul_(1 - wd*lr)
            if self.bn_wd:
                for p in pg2['params']: p.data.mul_(1 - wd*lr)
        self.set_val('weight_decay', listify(0, self._wd))
    self.opt.step()
```
Remember, self.opt.param_groups points to the original fp32 leaf nodes. So if you simply modify their .data in place like this, I think they will slowly fall out of sync with the fp16 model.
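To illustrate the concern with a self-contained toy example (hypothetical, not fastai's actual code path): if only the fp32 master tensor's data is decayed in place, a separate fp16 working copy never sees the decay unless something explicitly copies the master back.

```python
import torch

# Hypothetical toy reproduction of the concern above, not fastai code.
master = torch.ones(3, requires_grad=True)   # fp32 leaf, the tensor the optimizer's param_groups hold
working = master.detach().half()             # fp16 copy that the model would actually use

lr, wd = 1e-1, 1e-1
master.data.mul_(1 - wd * lr)                # true_wd shrinks the master weights in place

print(master.data)   # tensor([0.9900, 0.9900, 0.9900]) -> decayed
print(working)       # tensor([1., 1., 1.], dtype=torch.float16) -> unchanged until re-synced
```

In other words, the decayed values only ever reach the half-precision weights if they are copied back after the update.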