RuntimeError: CUDA error: device-side assert triggered when using Transformers

Hi!
I’m trying to follow the fastai v2 transformers tutorial, but with data from the arXiv dataset on Kaggle. I get a

RuntimeError: CUDA error: device-side assert triggered

when trying to run either of

learn.lr_find(suggestions=True)

or

learn.validate()

My code is relatively simple. Starting from the dataframe `papers`, where I store the metadata of some arXiv articles, I do:

# Import model and tokenizer
from transformers import AutoTokenizer, AutoModel

# Tokenizer and weights are both loaded from the same pretrained checkpoint name.
checkpoint = 'allenai/scibert_scivocab_uncased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# NOTE(review): AutoModel presumably loads the bare encoder without a
# language-modeling head — confirm that its first output is what the loss expects.
model = AutoModel.from_pretrained(checkpoint)

# Create a transformation of the tokenizer according to the fastai tutorial above
class TransformersTokenizer(Transform):
    """Transform that wraps a tokenizer object.

    Encoding goes string -> tokens -> ids -> tensor; decoding passes the ids
    (moved to CPU and converted to NumPy) to the tokenizer's `decode` and wraps
    the result in a `TitledStr`.
    """

    def __init__(self, tokenizer):
        # Keep a reference to the wrapped tokenizer; encodes/decodes delegate to it.
        self.tokenizer = tokenizer

    def encodes(self, x):
        """Tokenize `x` and return the corresponding token ids as a tensor."""
        tokens = self.tokenizer.tokenize(x)
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        return tensor(token_ids)

    def decodes(self, x):
        """Decode the ids in `x` back to text, wrapped in a `TitledStr`."""
        ids = x.cpu().numpy()
        text = self.tokenizer.decode(ids)
        return TitledStr(text)

# Create the DataBlock (note: the tutorial uses TfmdLists here instead)
# Text block built from the 'abstract' column, flagged as language-model data.
lm_block = TextBlock.from_df(text_cols='abstract', is_lm=True)
arxiv_lm = DataBlock(
    blocks=lm_block,
    get_x=ColReader('text'),
    splitter=RandomSplitter(valid_pct=0.2, seed=None),
)

# NOTE(review): the custom tokenizer transform is handed to dataloaders() as
# tok_tfm, not to TextBlock.from_df — confirm it is actually the tokenizer
# (and vocabulary) applied to the text.
hf_tok_tfm = TransformersTokenizer(tokenizer)
dls = arxiv_lm.dataloaders(papers, bs=64, tok_tfm=hf_tok_tfm)
dls.show_batch(max_n=6)

# This selects as output the first term in the prediction, also taken from the tutorial
class DropOutput(Callback):
    """Callback that keeps only the first element of the prediction."""

    def after_pred(self):
        # Replace the learner's prediction with its first element.
        first_output = self.pred[0]
        self.learn.pred = first_output

#Learn model
# Assemble the learner: cross-entropy loss, perplexity metric, DropOutput callback.
base_learner = Learner(
    dls,
    model,
    loss_func=CrossEntropyLossFlat(),
    cbs=[DropOutput],
    metrics=Perplexity(),
)
# Switch to fp16, then run the learning-rate finder.
learn = base_learner.to_fp16()
learn.lr_find(suggestions=True)

The biggest difference I see from the tutorial is that it uses TfmdLists, whereas I use a DataBlock, which I feel more comfortable with.
I would like to know if there is some way of avoiding this error. I have also just discovered the Blurr library, whose main purpose seems to be precisely to combine Hugging Face transformers with fastai. If you need more information, you may want to check the notebook with the error. For completeness, the full error is:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153     def _with_events(self, f, event_type, ex, final=noop):
--> 154         try:       self(f'before_{event_type}')       ;f()
    155         except ex: self(f'after_cancel_{event_type}')

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _do_fit(self)
    195             self.epoch=epoch
--> 196             self._with_events(self._do_epoch, 'epoch', CancelEpochException)
    197 

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153     def _with_events(self, f, event_type, ex, final=noop):
--> 154         try:       self(f'before_{event_type}')       ;f()
    155         except ex: self(f'after_cancel_{event_type}')

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _do_epoch(self)
    189     def _do_epoch(self):
--> 190         self._do_epoch_train()
    191         self._do_epoch_validate()

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _do_epoch_train(self)
    181         self.dl = self.dls.train
--> 182         self._with_events(self.all_batches, 'train', CancelTrainException)
    183 

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153     def _with_events(self, f, event_type, ex, final=noop):
--> 154         try:       self(f'before_{event_type}')       ;f()
    155         except ex: self(f'after_cancel_{event_type}')

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in all_batches(self)
    159         self.n_iter = len(self.dl)
--> 160         for o in enumerate(self.dl): self.one_batch(*o)
    161 

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in one_batch(self, i, b)
    177         self._split(b)
--> 178         self._with_events(self._do_one_batch, 'batch', CancelBatchException)
    179 

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153     def _with_events(self, f, event_type, ex, final=noop):
--> 154         try:       self(f'before_{event_type}')       ;f()
    155         except ex: self(f'after_cancel_{event_type}')

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _do_one_batch(self)
    162     def _do_one_batch(self):
--> 163         self.pred = self.model(*self.xb)
    164         self('after_pred')

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(

/opt/conda/lib/python3.7/site-packages/transformers/modeling_bert.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, output_attentions, output_hidden_states, return_dict)
    847             output_hidden_states=output_hidden_states,
--> 848             return_dict=return_dict,
    849         )

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(

/opt/conda/lib/python3.7/site-packages/transformers/modeling_bert.py in forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions, output_hidden_states, return_dict)
    482                     encoder_attention_mask,
--> 483                     output_attentions,
    484                 )

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(

/opt/conda/lib/python3.7/site-packages/transformers/modeling_bert.py in forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions)
    401             head_mask,
--> 402             output_attentions=output_attentions,
    403         )

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(

/opt/conda/lib/python3.7/site-packages/transformers/modeling_bert.py in forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions)
    338             encoder_attention_mask,
--> 339             output_attentions,
    340         )

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(

/opt/conda/lib/python3.7/site-packages/transformers/modeling_bert.py in forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions)
    239     ):
--> 240         mixed_query_layer = self.query(hidden_states)
    241 

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/linear.py in forward(self, input)
     92     def forward(self, input: Tensor) -> Tensor:
---> 93         return F.linear(input, self.weight, self.bias)
     94 

/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1686         if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
-> 1687             return handle_torch_function(linear, tens_ops, input, weight, bias=bias)
   1688     if input.dim() == 2 and bias is not None:

/opt/conda/lib/python3.7/site-packages/torch/overrides.py in handle_torch_function(public_api, relevant_args, *args, **kwargs)
   1062         # implementations can do equality/identity comparisons.
-> 1063         result = overloaded_arg.__torch_function__(public_api, types, args, kwargs)
   1064 

/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1691     else:
-> 1692         output = input.matmul(weight.t())
   1693         if bias is not None:

RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
<ipython-input-14-35d7aa25ab99> in <module>
----> 1 learn.lr_find(suggestions=True)

/opt/conda/lib/python3.7/site-packages/fastai/callback/schedule.py in lr_find(self, start_lr, end_lr, num_it, stop_div, show_plot, suggestions)
    222     n_epoch = num_it//len(self.dls.train) + 1
    223     cb=LRFinder(start_lr=start_lr, end_lr=end_lr, num_it=num_it, stop_div=stop_div)
--> 224     with self.no_logging(): self.fit(n_epoch, cbs=cb)
    225     if show_plot: self.recorder.plot_lr_find()
    226     if suggestions:

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    203             self.opt.set_hypers(lr=self.lr if lr is None else lr)
    204             self.n_epoch = n_epoch
--> 205             self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
    206 
    207     def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    154         try:       self(f'before_{event_type}')       ;f()
    155         except ex: self(f'after_cancel_{event_type}')
--> 156         finally:   self(f'after_{event_type}')        ;final()
    157 
    158     def all_batches(self):

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in __call__(self, event_name)
    130     def ordered_cbs(self, event): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, event)]
    131 
--> 132     def __call__(self, event_name): L(event_name).map(self._call_one)
    133 
    134     def _call_one(self, event_name):

/opt/conda/lib/python3.7/site-packages/fastcore/foundation.py in map(self, f, gen, *args, **kwargs)
    177     def range(cls, a, b=None, step=None): return cls(range_of(a, b=b, step=step))
    178 
--> 179     def map(self, f, *args, gen=False, **kwargs): return self._new(map_ex(self, f, *args, gen=gen, **kwargs))
    180     def argwhere(self, f, negate=False, **kwargs): return self._new(argwhere(self, f, negate, **kwargs))
    181     def filter(self, f=noop, negate=False, gen=False, **kwargs):

/opt/conda/lib/python3.7/site-packages/fastcore/basics.py in map_ex(iterable, f, gen, *args, **kwargs)
    605     res = map(g, iterable)
    606     if gen: return res
--> 607     return list(res)
    608 
    609 # Cell

/opt/conda/lib/python3.7/site-packages/fastcore/basics.py in __call__(self, *args, **kwargs)
    595             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    596         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 597         return self.func(*fargs, **kwargs)
    598 
    599 # Cell

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _call_one(self, event_name)
    134     def _call_one(self, event_name):
    135         assert hasattr(event, event_name), event_name
--> 136         [cb(event_name) for cb in sort_by_run(self.cbs)]
    137 
    138     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in <listcomp>(.0)
    134     def _call_one(self, event_name):
    135         assert hasattr(event, event_name), event_name
--> 136         [cb(event_name) for cb in sort_by_run(self.cbs)]
    137 
    138     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

/opt/conda/lib/python3.7/site-packages/fastai/callback/core.py in __call__(self, event_name)
     42                (self.run_valid and not getattr(self, 'training', False)))
     43         res = None
---> 44         if self.run and _run: res = getattr(self, event_name, noop)()
     45         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     46         return res

/opt/conda/lib/python3.7/site-packages/fastai/callback/fp16.py in after_fit(self)
     67     run_before=TrainEvalCallback
     68     def before_fit(self): self.learn.model = convert_network(self.model, dtype=torch.float16)
---> 69     def after_fit(self): self.learn.model = convert_network(self.model, dtype=torch.float32)
     70 
     71 # Cell

/opt/conda/lib/python3.7/site-packages/fastai/fp16_utils.py in convert_network(network, dtype)
     66         if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
     67             continue
---> 68         convert_module(module, dtype)
     69         if isinstance(module, torch.nn.RNNBase) or isinstance(module, torch.nn.modules.rnn.RNNBase):
     70             module.flatten_parameters()

/opt/conda/lib/python3.7/site-packages/fastai/fp16_utils.py in convert_module(module, dtype)
     50         if param is not None:
     51             if param.data.dtype.is_floating_point:
---> 52                 param.data = param.data.to(dtype=dtype)
     53             if param._grad is not None and param._grad.data.dtype.is_floating_point:
     54                 param._grad.data = param._grad.data.to(dtype=dtype)

RuntimeError: CUDA error: device-side assert triggered