I am trying to figure out how to use multiple GPUs to speed up training for my segmentation model.
I looked at PyTorch’s documentation (nn.DataParallel) and this link. However, I have not had success so far.
My first attempt was something like this:
```python
if torch.cuda.device_count() > 1:
    wrapped_model = nn.DataParallel(learner.model)
    learner.model = wrapped_model.module
```
This does not have the intended effect. I only see 1 GPU being used.
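For what it’s worth, I think the issue in the snippet above is that assigning `wrapped_model.module` back just restores the original unwrapped model. The usual pattern (a minimal sketch with a toy model, not fastai-specific) is to keep the `DataParallel` wrapper and call it directly:

```python
import torch
import torch.nn as nn

# Toy model as a stand-in for the learner's model.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())

if torch.cuda.device_count() > 1:
    # Keep the wrapper: DataParallel replicates the module across GPUs
    # on each forward pass and scatters the batch along dim 0.
    model = nn.DataParallel(model)

if torch.cuda.is_available():
    model = model.cuda()

x = torch.randn(4, 3, 32, 32)
if torch.cuda.is_available():
    x = x.cuda()

out = model(x)  # shape (4, 8, 32, 32); padding=1 preserves spatial size
```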
I also saw the documentation here, but from what I can tell, unet_learner does not have the parallel_ctx context manager.
The other thing I tried doing was:

```python
callbacks = [
    # …
]
learner.fine_tune(20, freeze_epochs=2, wd=0.01, base_lr=0.0006, cbs=callbacks)
```
This is more promising, but I end up with the following error message:

```
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_batch_norm)
```
cc @sgugger if you have any advice.
Also, I opened an issue in case that makes the discussion easier.
Are you running in a notebook or in a script?
Can you share a minimal (non-)working example, preferably on one of the datasets available in fastai? I have a multi-GPU setup, so I can try running it and see what I get.
Let me create a sample script. Thanks for the help!
DynamicUNets and parallel learners don’t work well together, so you should be using DistributedDataParallel instead. That works well.
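Not fastai-specific, but for reference, the bare-bones DistributedDataParallel pattern in plain PyTorch looks roughly like this. It is a sketch only: the toy model, port, and single-process gloo fallback are illustrative assumptions; in real multi-GPU training a launcher starts one process per GPU, each with its own rank and device:

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def run(rank: int, world_size: int):
    # Every process joins the same group; rendezvous happens via env vars.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)

    model = nn.Linear(10, 2)  # toy model standing in for the UNet
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
        model = DDP(model.cuda(rank), device_ids=[rank])
    else:
        model = DDP(model)  # CPU/gloo fallback, for illustration only

    x = torch.randn(8, 10)
    if torch.cuda.is_available():
        x = x.cuda(rank)
    out = model(x)
    out.sum().backward()  # gradients are all-reduced across processes

    dist.destroy_process_group()
    return out.shape

# A launcher (torchrun, or fastai's launch module) would invoke
# run(rank, world_size) once per GPU; e.g. run(0, 1) for one process.
```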
@rahulrav I am trying to use distributed training for unet_learner, but I am facing some issues. Will you be able to guide me through the process?
Problem in Distributed Training (fastai - V1) for UNet Image Super Resolution
DistributedDataLearner is the way to go. You see an approximately linear improvement in training speed: learning with 2 GPUs is ~2x as fast, with 4 it is about ~3.7x as fast, etc.
Let me upload a sample with DDL. There are a bunch of samples already, but they are a bit hard to discover in the repo.
fastai has a neat launcher script that makes the setup pretty simple and has nice rank0 helpers. I will also send a PR to improve the docs around DDL.
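If I remember right, the launcher is invoked as a module wrapper around your training script, something like the following (the script name is a placeholder, and the exact flags may differ between fastai versions):

```shell
# Launch one Python process per visible GPU; fastai's launch module
# sets up the rank/world-size environment for torch.distributed.
python -m fastai.launch train_script.py

# Restrict training to specific GPUs with the usual CUDA env var.
CUDA_VISIBLE_DEVICES=0,1 python -m fastai.launch train_script.py
```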
Also look at the train_imagenette.py example and ignore the parts that support the DistributedLearner.