Issue on PyTorch’s GitHub.
I added a few lines to Learner and one_batch, which set up the gradient scaler (GradScaler) and updated the training process to use AMP without compiling Apex. Here's the tutorial on how to do this.
The code is here: new pytorch AMP in fastai2
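For reference, this is the core native-AMP pattern from the PyTorch tutorial that the changes follow. It's a minimal sketch; model, opt, loss_func and dl here are just placeholders, not fastai names:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()                  # created once, e.g. in Learner.__init__
for xb, yb in dl:
    opt.zero_grad()
    with autocast():                   # forward pass and loss run in mixed precision
        pred = model(xb)
        loss = loss_func(pred, yb)
    scaler.scale(loss).backward()      # backward on the scaled loss
    scaler.step(opt)                   # unscales grads; skips the step on inf/NaN
    scaler.update()                    # adjusts the scale factor for the next iteration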
Here’s how one_batch looks:
from torch.cuda.amp import autocast, GradScaler

# self.scaler = GradScaler() -- defined in Learner.__init__
def one_batch(self, i, b):
    self.iter = i
    try:
        self._split(b); self('begin_batch')
        # AMP - run the forward pass (and the loss) with autocasting
        with autocast():
            self.pred = self.model(*self.xb); self('after_pred')
            if len(self.yb) == 0: return
            self.loss = self.loss_func(self.pred, *self.yb); self('after_loss')
        if not self.training: return
        # Scales the loss and calls backward() on the scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended:
        # backward ops run in the same precision that autocast used for the corresponding forward ops.
        self.scaler.scale(self.loss).backward(); self('after_backward')
        # self.loss.backward(); self('after_backward')
        # self.scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called;
        # otherwise, optimizer.step() is skipped.
        self.scaler.step(self.opt); self('after_step')
        # self.opt.step(); self('after_step')
        # Updates the scale for the next iteration.
        self.scaler.update()
        self.opt.zero_grad()
    except CancelBatchException: self('after_cancel_batch')
    finally: self('after_batch')
This hit several RuntimeErrors, which I cleaned up by sending the data to the device directly instead of letting PyTorch fail on the asserts here.
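That change was roughly along these lines (a sketch, not the exact code; batch_to_cuda is a hypothetical helper that just moves every tensor in the batch onto the GPU before _split runs):

import torch

def batch_to_cuda(b):
    # assumption: b is a tuple of tensors coming from the DataLoader
    return tuple(t.cuda() if isinstance(t, torch.Tensor) else t for t in b)

# inside one_batch, before self._split(b):
#     b = batch_to_cuda(b)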
After that, training finally fails at this point:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-21-94707bfc8cc1> in <module>
4 xb,yb = learn.dls.one_batch()
5 init_loss = learn.loss_func(learn.model(xb), yb)
----> 6 learn.fit(6)
7 assert learn.loss < init_loss
<ipython-input-16-7932a3f1e613> in fit(self, n_epoch, lr, wd, cbs, reset_opt)
131 try:
132 self.epoch=epoch; self('begin_epoch')
--> 133 self._do_epoch_train()
134 self._do_epoch_validate()
135 except CancelEpochException: self('after_cancel_epoch')
<ipython-input-16-7932a3f1e613> in _do_epoch_train(self)
104 try:
105 self.dl = self.dls.train; self('begin_train')
--> 106 self.all_batches()
107 except CancelTrainException: self('after_cancel_train')
108 finally: self('after_train')
<ipython-input-16-7932a3f1e613> in all_batches(self)
70 def all_batches(self):
71 self.n_iter = len(self.dl)
---> 72 for o in enumerate(self.dl): self.one_batch(*o)
73
74 def one_batch(self, i, b):
<ipython-input-16-7932a3f1e613> in one_batch(self, i, b)
90 # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
91 # otherwise, optimizer.step() is skipped.
---> 92 self.scaler.step(self.opt); self('after_step')
93 # self.opt.step(); self('after_step')
94 # Updates the scale for next iteration.
~/anaconda3/envs/pytorch-nightly/lib/python3.7/site-packages/torch/cuda/amp/grad_scaler.py in step(self, optimizer, *args, **kwargs)
278
279 if optimizer_state["stage"] == self.READY:
--> 280 self.unscale_(optimizer)
281
282 assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
~/anaconda3/envs/pytorch-nightly/lib/python3.7/site-packages/torch/cuda/amp/grad_scaler.py in unscale_(self, optimizer)
229 found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
230
--> 231 optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
232 optimizer_state["stage"] = self.UNSCALED
233
~/anaconda3/envs/pytorch-nightly/lib/python3.7/site-packages/torch/cuda/amp/grad_scaler.py in _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16)
180 torch._amp_non_finite_check_and_unscale_(param.grad,
181 per_device_found_inf.get(param.grad.device),
--> 182 per_device_inv_scale.get(param.grad.device))
183
184 return per_device_found_inf._per_device_tensors
RuntimeError: Could not run 'aten::_amp_non_finite_check_and_unscale_' with arguments from the 'CPUTensorId' backend. 'aten::_amp_non_finite_check_and_unscale_' is only available for these backends: [CUDATensorId, VariableTensorId].
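The error says the unscale kernel only has a CUDA implementation, so at least one gradient handed to the scaler is still a CPU tensor. A quick way to see which parameters (or their gradients) are still on the CPU, as a sketch:

for name, p in learn.model.named_parameters():
    if not p.is_cuda or (p.grad is not None and not p.grad.is_cuda):
        print(name, p.device, None if p.grad is None else p.grad.device)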