I tried to do this on lesson 1 of course-v3 but it didn't work. Here's what I did:
models_resnet34 = torch.nn.DataParallel(models.resnet34)
learn = ConvLearner(data, models_resnet34, metrics=error_rate)
Error:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-17-00b7619ab140> in <module>
1 models_resnet34 = torch.nn.DataParallel(models.resnet34)
----> 2 learn = ConvLearner(data, models_resnet34, metrics=error_rate)
~/venv-py36/lib64/python3.6/site-packages/fastai/vision/learner.py in __init__(self, data, arch, cut, pretrained, lin_ftrs, ps, custom_head, split_on, **kwargs)
52 meta = model_meta.get(arch, _default_meta)
53 torch.backends.cudnn.benchmark = True
---> 54 body = create_body(arch(pretrained), ifnone(cut,meta['cut']))
55 nf = num_features(body) * 2
56 head = custom_head or create_head(nf, data.c, lin_ftrs, ps)
~/venv-py36/lib64/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
475 result = self._slow_forward(*input, **kwargs)
476 else:
--> 477 result = self.forward(*input, **kwargs)
478 for hook in self._forward_hooks.values():
479 hook_result = hook(self, input, result)
~/venv-py36/lib64/python3.6/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
140 if len(self.device_ids) == 1:
141 return self.module(*inputs[0], **kwargs[0])
--> 142 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
143 outputs = self.parallel_apply(replicas, inputs, kwargs)
144 return self.gather(outputs, self.output_device)
~/venv-py36/lib64/python3.6/site-packages/torch/nn/parallel/data_parallel.py in replicate(self, module, device_ids)
145
146 def replicate(self, module, device_ids):
--> 147 return replicate(module, device_ids)
148
149 def scatter(self, inputs, kwargs, device_ids):
~/venv-py36/lib64/python3.6/site-packages/torch/nn/parallel/replicate.py in replicate(network, devices, detach)
9 num_replicas = len(devices)
10
---> 11 params = list(network.parameters())
12 param_indices = {param: idx for idx, param in enumerate(params)}
13 param_copies = Broadcast.apply(devices, *params)
AttributeError: 'function' object has no attribute 'parameters'