Mixed precision training

Continuing the documentation of the fastai_v1 development, here is a brief piece about mixed precision training. A very nice and clear introduction to it is this video from NVIDIA.

What’s half precision?

In neural nets, all the computations are usually done in single precision, which means all the floats in all the arrays that represent inputs, activations, weights… are 32-bit floats (FP32 in the rest of this post). An idea to reduce memory usage (and avoid those annoying cuda errors) has been to try and do the same thing in half-precision, which means using 16-bits floats (or FP16 in the rest of this post). By definition, they take half the space in RAM, and in theory could allow you to double the size of your model and double your batch size.

Another very nice feature is that NVIDIA developed its latest GPUs (the Volta generation) to take full advantage of half-precision tensors. Basically, if you give them half-precision tensors, they'll stack them so that each core can do more operations at the same time, which theoretically gives an 8x speed-up (sadly, just in theory).

So training at half precision is better for your memory usage, and way faster if you have a Volta GPU (still a tiny bit faster if you don't, since the computations are simpler). How do we do it? Super easily in pytorch: we just have to put .half() everywhere, on the inputs of our model and on all the parameters. The problem is that you usually won't see the same accuracy in the end (it does happen sometimes, but not reliably), because half-precision is… well… not as precise ;).
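
For example, a pure FP16 setup in pytorch would look roughly like this (a minimal sketch with a made-up toy model, not fastai code):

import torch

model = torch.nn.Linear(10, 2).cuda().half()      # all the parameters become FP16
x = torch.randn(32, 10).cuda().half()             # the inputs too
y = torch.randint(0, 2, (32,)).cuda()

loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()                                   # the gradients are FP16 as well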

Problems with half-precision:

To understand the problems with half precision, let’s look briefly at what an FP16 looks like (more information here).

(figure: bit layout of a half-precision float: 1 sign bit, 5 exponent bits, 10 fraction bits)

The sign bit gives us +1 or -1, then we have 5 bits to code an exponent between -14 and 15, while the fraction part has the remaining 10 bits. Compared to FP32, we have a smaller range of possible values (roughly 2^-14 to 2^15, compared to 2^-126 to 2^127 for FP32), but also a much coarser spacing between consecutive values, since only 10 bits are left for the fraction.

For instance, between 1 and 2, the FP16 format only represents the numbers 1, 1+2^-10, 1+2*2^-10… which means that 1 + 0.0001 = 1 in half precision. That's what will cause a certain number of problems, specifically three that can occur and mess up your training.
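
You can check this directly in pytorch (just a quick sketch):

import torch

print(torch.finfo(torch.float16).eps)                  # 0.0009765625 = 2^-10, the spacing just above 1
print(torch.tensor(1., dtype=torch.float16) + 0.0001)  # tensor(1., dtype=torch.float16): the addition is lost
print(torch.tensor(1., dtype=torch.float32) + 0.0001)  # tensor(1.0001): fine in FP32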

  1. The weight update is imprecise: inside your optimizer, you basically do w = w - lr * w.grad for each weight of your network. The problem with performing this operation in half precision is that very often, w.grad is several orders of magnitude below w, and the learning rate is also small. The situation where w=1 and lr*w.grad is 0.0001 (or lower) is therefore very common, but the update does nothing in those cases (1 - 0.0001 is just 1 again in FP16).
  2. Your gradients can underflow. In FP16, your gradients can easily be replaced by 0 because they are too low.
  3. Your activations or loss can overflow. The opposite problem from the gradients: it's easier to hit nan (or infinity) in FP16 precision, and your training might more easily diverge. (Both underflow and overflow are illustrated in the sketch just after this list.)
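
A quick illustration of points 2 and 3 (again just a sketch in plain pytorch):

import torch

print(torch.tensor(1e-8).half())      # tensor(0., dtype=torch.float16): a small gradient underflows to zero
print(torch.tensor(70000.).half())    # tensor(inf, dtype=torch.float16): FP16 overflows above ~65504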

The solution: mixed precision training

To address those three problems, we don't fully train in FP16 precision. As the name mixed precision training implies, some of the operations will be done in FP16, others in FP32. This mainly takes care of the first problem listed above. For the next two there are additional tricks.

The main idea is that we want to do the forward pass and the gradient computation in half precision (to go fast) but the update in single precision (to be more precise). It’s okay if w and grad are both half floats, but when we do the operation w = w - lr * grad, we need to compute it in FP32. That way our 1 + 0.0001 is going to be 1.0001.

This is why we keep a copy of the weights in FP32 (called the master model). Then, our training loop will look like this (a code sketch follows the list):

  1. compute the output with the FP16 model, then the loss
  2. back-propagate the gradients in half-precision.
  3. copy the gradients in FP32 precision
  4. do the update on the master model (in FP32 precision)
  5. copy the master model in the FP16 model.
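
Here is what that loop could look like in plain pytorch (a simplified sketch with a toy model, not the actual fastai callback):

import torch

model = torch.nn.Linear(10, 2).cuda().half()                              # the FP16 model
master_params = [p.detach().clone().float() for p in model.parameters()]  # FP32 master copy
for mp in master_params: mp.requires_grad_(True)
opt = torch.optim.SGD(master_params, lr=1e-2)                             # the optimizer works on the FP32 copy
x = torch.randn(32, 10).cuda().half()
y = torch.randint(0, 2, (32,)).cuda()

# 1. compute the output with the FP16 model, then the loss
loss = torch.nn.functional.cross_entropy(model(x), y)
# 2. back-propagate the gradients in half precision
loss.backward()
# 3. copy the gradients in FP32 precision
for mp, p in zip(master_params, model.parameters()):
    mp.grad = p.grad.detach().float()
# 4. do the update on the master model (in FP32 precision)
opt.step()
opt.zero_grad()
model.zero_grad()
# 5. copy the master model in the FP16 model
with torch.no_grad():
    for mp, p in zip(master_params, model.parameters()):
        p.copy_(mp)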

Note that we lose precision during step 5, and that the 1.0001 in one of the weights will go back to 1. But if the next update corresponds to adding 0.0001 again, since the optimizer step is done on the master model, the 1.0001 will become 1.0002, and if we keep going like this up to 1.0005, the FP16 model will eventually be able to tell the difference.

That takes care of problem 1. For the second problem, we use something called gradient scaling: to avoid the gradients getting zeroed by the FP16 precision, we multiply the loss by a scale factor (scale=512 is a good value in our experiments). That way we push the gradients away from the FP16 underflow region, so they don't get flushed to zero.

Of course we don't want those 512-scaled gradients to be in the weight update, so after converting them to FP32, we divide them by the scale factor (once they no longer risk becoming 0). This changes the loop to the following (only the changed steps are sketched in code after the list):

  1. compute the output with the FP16 model, then the loss.
  2. multiply the loss by scale then back-propagate the gradients in half-precision.
  3. copy the gradients in FP32 precision then divide them by scale.
  4. do the update on the master model (in FP32 precision).
  5. copy the master model in the FP16 model.
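
Compared to the sketch after the first list, only steps 2 and 3 change, roughly like this (using scale=512 as suggested above):

scale = 512

# 2. multiply the loss by scale then back-propagate the gradients in half precision
(loss * scale).backward()
# 3. copy the gradients in FP32 precision then divide them by scale
for mp, p in zip(master_params, model.parameters()):
    mp.grad = p.grad.detach().float() / scale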

For the last problem, the tricks offered by NVIDIA are to leave the batchnorm layers in single precision (they don't have many weights so it's not a big memory challenge) and to compute the loss in single precision (which means converting the last output of the model to single precision before passing it to the loss).
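
Sketched in code, those two tricks could look like this (the bn_to_float helper is made up for illustration, it's not a fastai or pytorch function):

import torch

def bn_to_float(module):
    "Put the batchnorm layers of a half-precision model back in single precision."
    if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
        module.float()
    for child in module.children():
        bn_to_float(child)
    return module

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, padding=1),
    torch.nn.BatchNorm2d(8),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(8, 2),
).cuda()
model = bn_to_float(model.half())                 # FP16 model, but batchnorm stays FP32

x = torch.randn(4, 3, 16, 16).cuda().half()
y = torch.randint(0, 2, (4,)).cuda()
out = model(x)                                    # the output is still FP16
loss = torch.nn.functional.cross_entropy(out.float(), y)   # loss computed in single precision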

Implementing all of this in the new callback system of fastai_v1 is surprisingly easy. You can see it in notebook 004a: it all fits in one callback whose code is simple to read. In practice, mixed-precision training gives roughly a 2x speed boost. To take full advantage of it, check that

  • you are on a Volta GPU (e.g. an AWS p3 instance)
  • you aren't bottlenecked by the CPU and data augmentation (this happens sooner than you might think)
  • the weight matrices have dimensions that are all multiples of 8 (that's to please the GPU)

For now, this gets a full training of CIFAR10 to 94% with AdamW and 1cycle in 7min20s in a notebook (the baseline is 6min45s, the fastai DawnBench entry). There are a few other tricks to add to get to that baseline, but compared to 13-14min in single precision, it's still a huge improvement!


This is awesome and I must say one of the main reasons to jump into fastai.

I am using learn = learn.to_fp16() on a 1080 Ti. Based on this https://www.pcper.com/reviews/Graphics-Cards/What-GP102-Could-Mean-NVIDIA I understand that FP16 performance on GP102 (1080 Ti, etc.) should be extremely poor (1:32 or 1:64 of FP32), but I am not seeing that. The speed I'm getting is the same as in fp32, and I am able to fit a 2x batch size in memory.

Has anybody explored 1080 Ti FP16 performance in fastai?
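
For anyone who wants to check the raw FP16 vs FP32 matmul throughput on their card, here's a quick timing sketch I'd use (nothing fastai-specific, just a rough comparison, not a rigorous benchmark):

import time
import torch

def bench(dtype, n=4096, iters=50):
    a = torch.randn(n, n, device='cuda', dtype=dtype)
    b = torch.randn(n, n, device='cuda', dtype=dtype)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        a @ b                      # matrix multiply in the given precision
    torch.cuda.synchronize()
    return time.time() - start

print('fp32:', bench(torch.float32))
print('fp16:', bench(torch.float16))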


Thanks for the awesome explanation. What does your overall memory consumption look like when using mixed precision vs single precision? Are you able to increase your batch sizes in practice, or does keeping a single-precision copy of your weights cancel out the half-precision gains with respect to overall memory usage? I'm sure it varies based on which model you use, but any insights on what you've actually seen in practice would be interesting to know.

I didn't try any of this; I didn't have a lot of time to experiment with mixed precision while developing the library.
If anyone has useful findings to share, I'd be interested too!

So, I was thinking of trying out the FP16 support in fastai today, but didn't quite manage. I'm running on a V100 on GCP, and from what I can see it does have support for FP16.

If I understand the docs correctly, all I need to do is call the to_fp16() method on the Learner object (ConvLearner in my case), as documented here: http://docs.fast.ai/callbacks.fp16.html#to_fp16
However, the kernel just dies when calling the fit variants on the learner object after converting it to fp16. Not only that, but the notebook cannot initialise the kernel on reloading; I'm forced to restart the whole jupyter process to resume work.

Is this something you've run into before? There isn't much in the jupyter logs either. How can I go about getting some more debug info on this?

I've never seen that, but I'm not sure anyone has tried on GCP before. I wonder if it's a CUDA version issue. Show us the output of fastai.show_install(0) please.

I tried running the notebook again, and I ran into the same issue. The show_install(0) output and jupyter logs are below.

Now, I can see that the notebook cannot reconnect to the kernel after crashing, because of (zeromq) resources not being released properly. But I don't think much can be done about that, as it's deeper in the jupyter codebase.

But, I still don’t see anything relevant to CUDA/cuDNN causing the crash.

=== Software === 
python version  : 3.7.0
fastai version  : 1.0.14
torch version   : 1.0.0.dev20181025
nvidia driver   : 396.54
torch cuda ver  : 9.2.148
torch cuda is   : available
torch cudnn ver : 7104
torch cudnn is  : enabled

=== Hardware === 
nvidia gpus     : 1
torch available : 1
  - gpu0        : 16160MB | Tesla V100-SXM2-16GB

=== Environment === 
platform        : Linux-4.15.0-1023-gcp-x86_64-with-debian-buster-sid
distro          : #24-Ubuntu SMP Wed Oct 10 13:28:59 UTC 2018
conda env       : Unknown
python          : /home/suvash/.conda/envs/fastai-course/bin/python
sys.path        : 
/work/course-v3/nbs/dl1
/home/suvash/.conda/envs/fastai-course/lib/python37.zip
/home/suvash/.conda/envs/fastai-course/lib/python3.7
/home/suvash/.conda/envs/fastai-course/lib/python3.7/lib-dynload
/home/suvash/.conda/envs/fastai-course/lib/python3.7/site-packages
/home/suvash/.conda/envs/fastai-course/lib/python3.7/site-packages/IPython/extensions
/home/suvash/.ipython
Oct 27 11:19:02 suvash-1 jupyter-notebook[3695]: [W 11:19:02.201 NotebookApp] Notebook course-v3/nbs/dl1/dhcd-resnet32-version-1.ipynb is not trusted
Oct 27 11:19:03 suvash-1 jupyter-notebook[3695]: [I 11:19:03.003 NotebookApp] Kernel started: bdb044ff-db4f-4a61-b5f7-0fdd7d458c72
Oct 27 11:19:03 suvash-1 jupyter-notebook[3695]: [I 11:19:03.400 NotebookApp] Adapting to protocol v5.1 for kernel bdb044ff-db4f-4a61-b5f7-0fdd7d458c72
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]: [I 11:20:33.000 NotebookApp] KernelRestarter: restarting kernel (1/5), keep random ports
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]: WARNING:root:kernel bdb044ff-db4f-4a61-b5f7-0fdd7d458c72 restarted
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]: Traceback (most recent call last):
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:   File "/home/suvash/.conda/envs/fastai-course/lib/python3.7/runpy.py", line 193, in _run_module_as_main
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:     "__main__", mod_spec)
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:   File "/home/suvash/.conda/envs/fastai-course/lib/python3.7/runpy.py", line 85, in _run_code
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:     exec(code, run_globals)
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:   File "/home/suvash/.conda/envs/fastai-course/lib/python3.7/site-packages/ipykernel_launcher.py", line 16, in <module>
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:     app.launch_new_instance()
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:   File "/home/suvash/.conda/envs/fastai-course/lib/python3.7/site-packages/traitlets/config/application.py", line 657, in launch_instance
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:     app.initialize(argv)
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:   File "<decorator-gen-123>", line 2, in initialize
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:   File "/home/suvash/.conda/envs/fastai-course/lib/python3.7/site-packages/traitlets/config/application.py", line 87, in catch_config_error
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:     return method(app, *args, **kwargs)
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:   File "/home/suvash/.conda/envs/fastai-course/lib/python3.7/site-packages/ipykernel/kernelapp.py", line 469, in initialize
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:     self.init_sockets()
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:   File "/home/suvash/.conda/envs/fastai-course/lib/python3.7/site-packages/ipykernel/kernelapp.py", line 238, in init_sockets
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:     self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:   File "/home/suvash/.conda/envs/fastai-course/lib/python3.7/site-packages/ipykernel/kernelapp.py", line 180, in _bind_socket
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:     s.bind("tcp://%s:%i" % (self.ip, port))
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:   File "zmq/backend/cython/socket.pyx", line 547, in zmq.backend.cython.socket.Socket.bind
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]:   File "zmq/backend/cython/checkrc.pxd", line 25, in zmq.backend.cython.checkrc._check_rc
Oct 27 11:20:33 suvash-1 jupyter-notebook[3695]: zmq.error.ZMQError: Address already in use
Oct 27 11:20:36 suvash-1 jupyter-notebook[3695]: [I 11:20:36.009 NotebookApp] KernelRestarter: restarting kernel (2/5), keep random ports
Oct 27 11:20:36 suvash-1 jupyter-notebook[3695]: WARNING:root:kernel bdb044ff-db4f-4a61-b5f7-0fdd7d458c72 restarted

(As you can see, I'm not running the fastai image, but that shouldn't be an issue either. I prebuilt an image with the 396 drivers and run preemptible instances based on it.)

How do I do predictions with a learner after doing to_fp16?
If I try to do:

learn.get_preds()
>>---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-166-b1155c30714e> in <module>
----> 1 p,t = learn.get_preds()

~/anaconda3/envs/fastai/lib/python3.6/site-packages/fastai/basic_train.py in get_preds(self, is_test)
    180     def get_preds(self, is_test:bool=False) -> List[Tensor]:
    181         "Return predictions and targets on the valid or test set, depending on `is_test`."
--> 182         return get_preds(self.model, self.data.holdout(is_test), cb_handler=CallbackHandler(self.callbacks, []))
    183 
    184     def validate(self, dl=None, callbacks=None, metrics=None):

~/anaconda3/envs/fastai/lib/python3.6/site-packages/fastai/basic_train.py in get_preds(model, dl, pbar, cb_handler)
     34 def get_preds(model:Model, dl:DataLoader, pbar:Optional[PBar]=None, cb_handler:Optional[CallbackHandler]=None) -> List[Tensor]:
     35     "Predict the output of the elements in the dataloader."
---> 36     return [torch.cat(o).cpu() for o in zip(*validate(model, dl, pbar=pbar, cb_handler=cb_handler, average=False))]
     37 
     38 def validate(model:Model, dl:DataLoader, loss_func:OptLossFunc=None,

~/anaconda3/envs/fastai/lib/python3.6/site-packages/fastai/basic_train.py in validate(model, dl, loss_func, cb_handler, pbar, average)
     45         for xb,yb in progress_bar(dl, parent=pbar, leave=(pbar is not None)):
     46             if cb_handler: xb, yb = cb_handler.on_batch_begin(xb, yb, train=False)
---> 47             val_losses.append(loss_batch(model, xb, yb, loss_func, cb_handler=cb_handler))
     48             if not is_listy(yb): yb = [yb]
     49             nums.append(yb[0].shape[0])

~/anaconda3/envs/fastai/lib/python3.6/site-packages/fastai/basic_train.py in loss_batch(model, xb, yb, loss_func, opt, cb_handler)
     16     if not is_listy(xb): xb = [xb]
     17     if not is_listy(yb): yb = [yb]
---> 18     out = model(*xb)
     19     out = cb_handler.on_loss_begin(out)
     20 

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

~/anaconda3/envs/fastai/lib/python3.6/site-packages/fastai/vision/models/darknet.py in forward(self, x)
     31         self.layers = nn.Sequential(*layers)
     32 
---> 33     def forward(self, x): return self.layers(x)

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
     90     def forward(self, input):
     91         for module in self._modules.values():
---> 92             input = module(input)
     93         return input
     94 

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
     90     def forward(self, input):
     91         for module in self._modules.values():
---> 92             input = module(input)
     93         return input
     94 

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/conv.py in forward(self, input)
    311     def forward(self, input):
    312         return F.conv2d(input, self.weight, self.bias, self.stride,
--> 313                         self.padding, self.dilation, self.groups)
    314 
    315 

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same

I experience this too. As a workaround, I've been recreating the learner without mixed precision and then reloading the weights, which allows get_preds to work. It looks like a bug though.

How do you do that? I am unable to. The problem appears to come from the validate() function.

Just do the same steps without the to_fp16() call. It's a workaround…

train:

learn = create_cnn(args)
learn = learn.to_fp16()
learn.fit(args)
learn.save('weights')

predict:

learn = create_cnn(args)
learn.load('weights')
preds,y = learn.get_preds()

That does not work for me, I get the same error:
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same
It's kind of logical: I am loading fp16 weights and trying to evaluate with fp32 validation data.

Hmm, OK, interesting. I had (possibly mistakenly) assumed the master weights themselves would be kept in fp32 while the resource-hungry calculations would be done in fp16 (EDIT: the first post seems to say as much). It works just fine in my use case, but I'll keep an eye on this thread while I try to understand mixed precision a little more.

This seems to work:

learn = Learner(data128, darknet53, metrics=[accuracy_thresh, f1])
learn.to_fp16()
learn.load('darknet53_128')

Then you add this

learn.data.valid_dl.add_tfm(to_half)

and then only a warning:

p_v,t_v = learn.get_preds()
>>/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py:1126: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
  warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")

And the output is fp32:

p_v.dtype
>>torch.float32

@suvash Were you able to solve this? I have been facing something similar: the notebook dies when trying to use fp16 on a V100, and I am also on GCP.

You’re using CUDA 9.2, I think that’s the issue.

I’ve been doing mixed-precision tests on GCP V100s as well as PaperSpace V100s, both without a hitch.

I made this ansible script: https://vxlabs.com/2018/11/21/a-simple-ansible-script-to-convert-a-clean-ubuntu-18-04-to-a-cuda-10-pytorch-1-0rc-fastai-miniconda3-deep-learning-machine/

… if you follow the instructions carefully (the most important part is to download CUDNN locally before running), it will provision your clean Ubuntu 18.04 GCE or PaperSpace node with CUDA 10, PyTorch 1.0 dev (a Python 3.7 build I made), fastai, and a starter conda environment within which mixed-precision training with fastai works like a charm.


This does look interesting. You're talking about the 410 drivers for CUDA 10, right?
I'm hoping there'll be an official pytorch build with CUDA 10 support soon. This looks like your own custom build?

I also had this issue.
I removed pytorch completely and installed it again. Then everything worked fine.
Even though I was able to increase the batch size, it didn't give me any speed improvement.

Tangent question: has anyone tried mixed precision training on the newer RTX cards? While they don't have the full capacity of a Volta, supposedly their FP16 isn't crippled like it is on the 10-series cards, so I'm curious to see the results.

It is indeed using my custom build of PyTorch, currently from the November 24 PyTorch master.