fastai v1.0.54 [Linux] – U-Net with ResNet-50: lr_find()/fit() fails with 'CUDA error: an illegal memory access was encountered'

`unet_learner` fails during training when it is built with a `resnet50` backbone.

OS: Ubuntu 18.04
fastai: 1.0.54
torch.__version__: 1.1.0
torch.backends.cudnn.version(): 7501
torch.version.cuda: 10.0.130

Code to reproduce (`data` is an already-created DataBunch):

from fastai.vision.learner import unet_learner
from torchvision.models import resnet50

learn = unet_learner(data, resnet50)

learn.lr_find()

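Error stack trace. It was captured through the `arcgis.learn` wrapper, whose `lr_find()` calls `self.learn.lr_find()` internally: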

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~/.conda/envs/arcgis/lib/python3.6/site-packages/fastai/basic_train.py in fit(epochs, learn, callbacks, metrics)
    100                 xb, yb = cb_handler.on_batch_begin(xb, yb)
--> 101                 loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler)
    102                 if cb_handler.on_batch_end(loss): break

~/.conda/envs/arcgis/lib/python3.6/site-packages/fastai/basic_train.py in loss_batch(model, xb, yb, loss_func, opt, cb_handler)
     25     if not is_listy(yb): yb = [yb]
---> 26     out = model(*xb)
     27     out = cb_handler.on_loss_begin(out)

~/.conda/envs/arcgis/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():

~/.conda/envs/arcgis/lib/python3.6/site-packages/fastai/layers.py in forward(self, x)
    135             res.orig = x
--> 136             nres = l(res)
    137             # We have to remove res.orig to avoid hanging refs and therefore memory leaks

~/.conda/envs/arcgis/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():

~/.conda/envs/arcgis/lib/python3.6/site-packages/fastai/vision/models/unet.py in forward(self, up_in)
     29         print(up_in.shape)
---> 30         up_out = self.shuf(up_in)
     31         ssh = s.shape[-2:]

~/.conda/envs/arcgis/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():

~/.conda/envs/arcgis/lib/python3.6/site-packages/fastai/layers.py in forward(self, x)
    216     def forward(self,x):
--> 217         x = self.shuf(self.relu(self.conv(x)))
    218         return self.blur(self.pad(x)) if self.blur else x

~/.conda/envs/arcgis/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():

~/.conda/envs/arcgis/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
     91         for module in self._modules.values():
---> 92             input = module(input)
     93         return input

~/.conda/envs/arcgis/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():

~/.conda/envs/arcgis/lib/python3.6/site-packages/torch/nn/modules/conv.py in forward(self, input)
    337         return F.conv2d(input, self.weight, self.bias, self.stride,
--> 338                         self.padding, self.dilation, self.groups)
    339 

RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
<ipython-input-6-1d9bc8b090e2> in <module>
----> 1 model.lr_find()

~/.conda/envs/arcgis/lib/python3.6/site-packages/arcgis/learn/models/_arcgis_model.py in lr_find(self, allow_plot)
    283         self._check_requisites()
    284 
--> 285         self.learn.lr_find()
    286         from IPython.display import clear_output
    287         clear_output()

~/.conda/envs/arcgis/lib/python3.6/site-packages/fastai/train.py in lr_find(learn, start_lr, end_lr, num_it, stop_div, wd)
     30     cb = LRFinder(learn, start_lr, end_lr, num_it, stop_div)
     31     epochs = int(np.ceil(num_it/len(learn.data.train_dl)))
---> 32     learn.fit(epochs, start_lr, callbacks=[cb], wd=wd)
     33 
     34 def to_fp16(learn:Learner, loss_scale:float=None, max_noskip:int=1000, dynamic:bool=True, clip:float=None,

~/.conda/envs/arcgis/lib/python3.6/site-packages/fastai/basic_train.py in fit(self, epochs, lr, wd, callbacks)
    198         callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
    199         if defaults.extra_callbacks is not None: callbacks += defaults.extra_callbacks
--> 200         fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
    201 
    202     def create_opt(self, lr:Floats, wd:Floats=0.)->None:

~/.conda/envs/arcgis/lib/python3.6/site-packages/fastai/basic_train.py in fit(epochs, learn, callbacks, metrics)
    110         exception = e
    111         raise
--> 112     finally: cb_handler.on_train_end(exception)
    113 
    114 loss_func_name2activ = {'cross_entropy_loss': F.softmax, 'nll_loss': torch.exp, 'poisson_nll_loss': torch.exp,

~/.conda/envs/arcgis/lib/python3.6/site-packages/fastai/callback.py in on_train_end(self, exception)
    321     def on_train_end(self, exception:Union[bool,Exception])->None:
    322         "Handle end of training, `exception` is an `Exception` or False if no exceptions during training."
--> 323         self('train_end', exception=exception)
    324 
    325     @property

~/.conda/envs/arcgis/lib/python3.6/site-packages/fastai/callback.py in __call__(self, cb_name, call_mets, **kwargs)
    249         if call_mets:
    250             for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)
--> 251         for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)
    252 
    253     def set_dl(self, dl:DataLoader):

~/.conda/envs/arcgis/lib/python3.6/site-packages/fastai/callback.py in _call_and_update(self, cb, cb_name, **kwargs)
    239     def _call_and_update(self, cb, cb_name, **kwargs)->None:
    240         "Call `cb_name` on `cb` and update the inner state."
--> 241         new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
    242         for k,v in new.items():
    243             if k not in self.state_dict:

~/.conda/envs/arcgis/lib/python3.6/site-packages/fastai/callbacks/lr_finder.py in on_train_end(self, **kwargs)
     33     def on_train_end(self, **kwargs:Any)->None:
     34         "Cleanup learn model weights disturbed during LRFinder exploration."
---> 35         self.learn.load('tmp', purge=False)
     36         if hasattr(self.learn.model, 'reset'): self.learn.model.reset()
     37         for cb in self.callbacks:

~/.conda/envs/arcgis/lib/python3.6/site-packages/fastai/basic_train.py in load(self, file, device, strict, with_opt, purge, remove_module)
    265         elif isinstance(device, int): device = torch.device('cuda', device)
    266         source = self.path/self.model_dir/f'{file}.pth' if is_pathlike(file) else file
--> 267         state = torch.load(source, map_location=device)
    268         if set(state.keys()) == {'model', 'opt'}:
    269             model_state = state['model']

~/.conda/envs/arcgis/lib/python3.6/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    385         f = f.open('rb')
    386     try:
--> 387         return _load(f, map_location, pickle_module, **pickle_load_args)
    388     finally:
    389         if new_fd:

~/.conda/envs/arcgis/lib/python3.6/site-packages/torch/serialization.py in _load(f, map_location, pickle_module, **pickle_load_args)
    572     unpickler = pickle_module.Unpickler(f, **pickle_load_args)
    573     unpickler.persistent_load = persistent_load
--> 574     result = unpickler.load()
    575 
    576     deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

~/.conda/envs/arcgis/lib/python3.6/site-packages/torch/_utils.py in _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks)
    131 
    132 def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
--> 133     tensor = _rebuild_tensor(storage, storage_offset, size, stride)
    134     tensor.requires_grad = requires_grad
    135     # NB: This line exists only for backwards compatibility; the

~/.conda/envs/arcgis/lib/python3.6/site-packages/torch/_utils.py in _rebuild_tensor(storage, storage_offset, size, stride)
    126 def _rebuild_tensor(storage, storage_offset, size, stride):
    127     # first construct a tensor with the correct dtype/device
--> 128     t = torch.tensor([], dtype=storage.dtype, device=storage.device)
    129     return t.set_(storage, storage_offset, size, stride)
    130 

RuntimeError: CUDA error: an illegal memory access was encountered

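Note that there are two exceptions in this trace. The first failure is `cuDNN error: CUDNN_STATUS_INTERNAL_ERROR`, raised in `F.conv2d` during the model's forward pass. The "illegal memory access" error is raised afterwards, while handling the first one: `LRFinder.on_train_end` calls `learn.load('tmp')`, and that reload fails inside `torch.load`. What is `storage.device` in that last frame (`_rebuild_tensor`), and why does rebuilding the tensor on that device trigger the illegal memory access?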