Multi-gpu runtime error

I have been experimenting with fastai multi-GPU training.

I spun up a multi-GPU instance on Jarvislabs.ai with fastai 2.5.0 and two RTX 5000 GPUs.

I then followed the multi-GPU training tutorial in the docs and put the code in a script (since I can't run parallel training from a notebook), which is as follows:

from fastai.vision.all import *
from fastai.distributed import *

block = DataBlock(
    blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    batch_tfms=aug_transforms(mult=2., do_flip=False))

path = untar_data(URLs.MNIST_SAMPLE)
loaders = block.dataloaders(path/"train")

learn = cnn_learner(loaders, resnet34, metrics=[accuracy])

# Stage 1: pick a learning rate, then train with the pretrained body frozen
suggestion = learn.lr_find()
with learn.distrib_ctx():
    learn.fit_one_cycle(5, lr_max=suggestion.valley)

# Stage 2: unfreeze all layers, pick a new learning rate, and fine-tune
learn.unfreeze()
suggestion = learn.lr_find()
with learn.distrib_ctx():
    learn.fit_one_cycle(5, lr_max=suggestion.valley)

The only difference from the tutorial is that I am using the smaller, simpler MNIST dataset with single-channel input, so I changed the input block accordingly. The script works on a single GPU, but with multiple GPUs I get the following error (full stack trace below):

root@dac36d87ac8b:~/customAugmentationExperiment/scripts# python -m fastai.launch parallel_training.py --gpus 0
epoch     train_loss  valid_loss  accuracy  time    ███████████--------------------------------------| 57.79% [89/154 00:07<00:05 2.0509]
Traceback (most recent call last):-------------------------------------------------------------------| 0.00% [0/77 00:00<00:00]
  File "parallel_training.py", line 22, in <module>
    learn.fit_one_cycle(5, lr_max = suggestion.valley)
  File "/opt/conda/lib/python3.8/site-packages/fastai/callback/schedule.py", line 113, in fit_one_cycle
    self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 221, in fit
    self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 163, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 212, in _do_fit
    self._with_events(self._do_epoch, 'epoch', CancelEpochException)
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 163, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 206, in _do_epoch
    self._do_epoch_train()
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 198, in _do_epoch_train
    self._with_events(self.all_batches, 'train', CancelTrainException)
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 163, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 169, in all_batches
    for o in enumerate(self.dl): self.one_batch(*o)
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 194, in one_batch
    self._with_events(self._do_one_batch, 'batch', CancelBatchException)
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 163, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 172, in _do_one_batch
    self.pred = self.model(*self.xb)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 799, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/activation.py", line 102, in forward
    return F.relu(input, inplace=self.inplace)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 1294, in relu
    return handle_torch_function(relu, (input,), input, inplace=inplace)
  File "/opt/conda/lib/python3.8/site-packages/torch/overrides.py", line 1252, in handle_torch_function
    result = overloaded_arg.__torch_function__(public_api, types, args, kwargs)
  File "/opt/conda/lib/python3.8/site-packages/fastai/torch_core.py", line 340, in __torch_function__
    res = super().__torch_function__(func, types, args=args, kwargs=kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/_tensor.py", line 1023, in __torch_function__
    ret = func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 1296, in relu
    result = torch.relu_(input)
RuntimeError: Output 0 of SyncBatchNormBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can remove this warning by cloning the output of the custom Function.
Traceback (most recent call last):
  File "parallel_training.py", line 22, in <module>
    learn.fit_one_cycle(5, lr_max = suggestion.valley)
  File "/opt/conda/lib/python3.8/site-packages/fastai/callback/schedule.py", line 113, in fit_one_cycle
    self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 221, in fit
    self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 163, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 212, in _do_fit
    self._with_events(self._do_epoch, 'epoch', CancelEpochException)
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 163, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 206, in _do_epoch
    self._do_epoch_train()
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 198, in _do_epoch_train
    self._with_events(self.all_batches, 'train', CancelTrainException)
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 163, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 169, in all_batches
    for o in enumerate(self.dl): self.one_batch(*o)
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 194, in one_batch
    self._with_events(self._do_one_batch, 'batch', CancelBatchException)
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 163, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/opt/conda/lib/python3.8/site-packages/fastai/learner.py", line 172, in _do_one_batch
    self.pred = self.model(*self.xb)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 799, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/activation.py", line 102, in forward
    return F.relu(input, inplace=self.inplace)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 1294, in relu
    return handle_torch_function(relu, (input,), input, inplace=inplace)
  File "/opt/conda/lib/python3.8/site-packages/torch/overrides.py", line 1252, in handle_torch_function
    result = overloaded_arg.__torch_function__(public_api, types, args, kwargs)
  File "/opt/conda/lib/python3.8/site-packages/fastai/torch_core.py", line 340, in __torch_function__
    res = super().__torch_function__(func, types, args=args, kwargs=kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/_tensor.py", line 1023, in __torch_function__
    ret = func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 1296, in relu
    result = torch.relu_(input)
RuntimeError: Output 0 of SyncBatchNormBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can remove this warning by cloning the output of the custom Function.

Can someone help me get past this? I tried killing all previous processes with pkill -9 python, and I also restarted the instance and ran the script afresh once it was up.

This is the hardware configuration of the GPU machine:

Thu Oct  7 10:06:00 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 470.42.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Quadro RTX 5000     Off  | 00000000:1E:00.0 Off |                  Off |
| 34%   39C    P8    25W / 230W |      0MiB / 16125MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Quadro RTX 5000     Off  | 00000000:3D:00.0 Off |                  Off |
| 33%   36C    P8    13W / 230W |      0MiB / 16125MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |

Thanks!

I am also having issues running this. Try running the nbs/examples/distrib.py first. I created an issue here: https://github.com/fastai/fastai/issues/3498
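
In the meantime, one workaround that may be worth trying follows from the error text itself: the failure comes from an in-place ReLU (torch.relu_) writing into the output of SyncBatchNorm. Below is a minimal sketch, not a confirmed fix, that switches every in-place activation in the model to its out-of-place form before training. The helper name disable_inplace_activations is made up for illustration; it relies only on the model exposing modules() (as torch.nn.Module does) and on activation layers such as nn.ReLU carrying a boolean inplace attribute.

```python
def disable_inplace_activations(model):
    """Set inplace=False on every submodule whose `inplace` flag is True.

    Hypothetical helper, not part of fastai or PyTorch. It works by duck
    typing: `model` only needs a `modules()` method that yields itself and
    all of its submodules, as torch.nn.Module.modules() does.
    Returns the number of modules that were changed.
    """
    changed = 0
    for module in model.modules():
        # nn.ReLU(inplace=True) stores the flag as a plain attribute;
        # layers without the attribute (conv, batchnorm, ...) are skipped.
        if getattr(module, "inplace", False) is True:
            module.inplace = False
            changed += 1
    return changed


# Intended use in the script above (assumption: called once, right after
# the learner is created and before entering learn.distrib_ctx()):
#
#     n = disable_inplace_activations(learn.model)
#     print(f"switched {n} in-place activations to out-of-place")
```

Out-of-place activations use somewhat more memory, which should not matter for MNIST_SAMPLE on these cards. Separately, if the installed version of distrib_ctx accepts a sync_bn argument (please check its signature, I have not verified this for 2.5.0), calling learn.distrib_ctx(sync_bn=False) may also sidestep the error, at the cost of each GPU keeping its own batchnorm statistics.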