Superres Multi GPU

I am trying to train the superres notebook from the 2019 course, lesson 7, on multiple GPUs. I tried using the following code:

learn.model = torch.nn.DataParallel(learn.model, device_ids=[0, 1])

But I hit an error when calling fit:

RuntimeError Traceback (most recent call last)
in <module>()
----> 1 do_fit('1a', slice(lr*10))

in do_fit(save_name, lrs, pct_start)
1 def do_fit(save_name, lrs=slice(lr), pct_start=0.9):
----> 2 learn.fit_one_cycle(10, lrs, pct_start=pct_start)
3 learn.save(save_name)
4 learn.show_results(rows=1, imgsize=5)

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/fastai/train.py in fit_one_cycle(learn, cyc_len, max_lr, moms, div_factor, pct_start, final_div, wd, callbacks, tot_epochs, start_epoch)
20 callbacks.append(OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor, pct_start=pct_start,
21 final_div=final_div, tot_epochs=tot_epochs, start_epoch=start_epoch))
---> 22 learn.fit(cyc_len, max_lr, wd=wd, callbacks=callbacks)
23
24 def lr_find(learn:Learner, start_lr:Floats=1e-7, end_lr:Floats=10, num_it:int=100, stop_div:bool=True, wd:float=None):

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/fastai/basic_train.py in fit(self, epochs, lr, wd, callbacks)
197 callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
198 if defaults.extra_callbacks is not None: callbacks += defaults.extra_callbacks
--> 199 fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
200
201 def create_opt(self, lr:Floats, wd:Floats=0.)->None:

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/fastai/basic_train.py in fit(epochs, learn, callbacks, metrics)
99 for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
100 xb, yb = cb_handler.on_batch_begin(xb, yb)
--> 101 loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler)
102 if cb_handler.on_batch_end(loss): break
103

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/fastai/basic_train.py in loss_batch(model, xb, yb, loss_func, opt, cb_handler)
24 if not is_listy(xb): xb = [xb]
25 if not is_listy(yb): yb = [yb]
---> 26 out = model(*xb)
27 out = cb_handler.on_loss_begin(out)
28

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
150 return self.module(*inputs[0], **kwargs[0])
151 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 152 outputs = self.parallel_apply(replicas, inputs, kwargs)
153 return self.gather(outputs, self.output_device)
154

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in parallel_apply(self, replicas, inputs, kwargs)
160
161 def parallel_apply(self, replicas, inputs, kwargs):
--> 162 return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
163
164 def gather(self, outputs, output_device):

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
81 output = results[i]
82 if isinstance(output, Exception):
---> 83 raise output
84 outputs.append(output)
85 return outputs

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py in _worker(i, module, input, kwargs, device)
57 if not isinstance(input, (list, tuple)):
58 input = (input,)
---> 59 output = module(*input, **kwargs)
60 with lock:
61 results[i] = output

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/fastai/layers.py in forward(self, x)
153 for l in self.layers:
154 res.orig = x
--> 155 nres = l(res)
156 # We have to remove res.orig to avoid hanging refs and therefore memory leaks
157 res.orig = None

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/fastai/vision/models/unet.py in forward(self, up_in)
32 if ssh != up_out.shape[-2:]:
33 up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
---> 34 cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
35 return self.conv2(self.conv1(cat_x))
36

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py in forward(self, input)
81 input, self.running_mean, self.running_var, self.weight, self.bias,
82 self.training or not self.track_running_stats,
---> 83 exponential_average_factor, self.eps)
84
85 def extra_repr(self):

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/functional.py in batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
1695 return torch.batch_norm(
1696 input, weight, bias, running_mean, running_var,
-> 1697 training, momentum, eps, torch.backends.cudnn.enabled
1698 )
1699

RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 2 does not equal 0 (while checking arguments for cudnn_batch_norm)
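
For what it's worth, DataParallel itself works fine on a plain module, so the problem seems specific to the dynamic U-Net. A quick sanity check along these lines (SmallNet is a made-up toy model, and this assumes two visible GPUs):

import torch
import torch.nn as nn

# Made-up toy model, only to check that DataParallel itself works on this machine
class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU())
    def forward(self, x):
        return self.body(x)

model = SmallNet().cuda()                          # parameters start on GPU 0
model = nn.DataParallel(model, device_ids=[0, 1])  # replicated onto GPUs 0 and 1 each forward pass
out = model(torch.randn(8, 3, 64, 64).cuda())      # batch is split along dim 0, outputs gathered on GPU 0
print(out.shape, out.device)                       # torch.Size([8, 8, 64, 64]) cuda:0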

Same problem here… I'm not able to parallelize CamVid segmentation because of this :frowning:

Same problem for me; it looks like U-Nets don't like the DataParallel() wrapper much. People on the forum suggest parallelizing through to_distributed() instead, but as far as I understand that approach is not viable inside a Jupyter notebook: you have to convert your notebook into a proper Python script and run it from the command line.
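
If it helps, here is roughly what the converted script would look like, based on my reading of the fastai v1 distributed docs. Treat it as a sketch: the file name train.py, the learning rate, and the elided data/learner setup are placeholders, not tested code.

# train.py -- rough sketch; launched once per GPU by the PyTorch launcher
import argparse
import torch
from fastai.vision import *
from fastai.distributed import *

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # filled in by torch.distributed.launch
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')

# ... build `data` and `learn` exactly as in the lesson 7 superres notebook ...
learn = learn.to_distributed(args.local_rank)  # fastai v1 wraps the model in DistributedDataParallel
learn.fit_one_cycle(10, slice(1e-3), pct_start=0.9)

and then run it with something like:

python -m torch.distributed.launch --nproc_per_node=2 train.py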