"learn.predict" doesn't seem to handle arguments set to None

So, I’m attempting to use huggingface transformer models that sometimes require None to be passed in. For example, the forward() function to the XLMForQuestionAnswering models looks like this:

input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None, lengths=None, cache=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None, is_impossible=None, cls_index=None, p_mask=None

I need to be able to pass in None to everything but input_ids, attention_mask, token_type_ids, cls_index, and p_mask. I can do this and everything trains fine, BUT

… when I call learn.predict('I really like everything') I get the following error:

TypeError                                 Traceback (most recent call last)
~/development/projects/blurr/_libs/fastai2/fastai2/torch_core.py in to_concat(xs, dim)
    216     #   in this case we return a big list
--> 217     try:    return retain_type(torch.cat(xs, dim=dim), xs[0])
    218     except: return sum([L(retain_type(o_.index_select(dim, tensor(i)).squeeze(dim), xs[0])

TypeError: expected Tensor as element 0 in argument 0, but got NoneType

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-24-986299fba05f> in <module>
----> 1 learn.predict('I really liked the movie')

~/development/projects/blurr/_libs/fastai2/fastai2/learner.py in predict(self, item, rm_type_tfms, with_input)
    240     def predict(self, item, rm_type_tfms=None, with_input=False):
    241         dl = self.dls.test_dl([item], rm_type_tfms=rm_type_tfms, num_workers=0)
--> 242         inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
    243         i = getattr(self.dls, 'n_inp', -1)
    244         inp = (inp,) if i==1 else tuplify(inp)

~/development/projects/blurr/_libs/fastai2/fastai2/learner.py in get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, **kwargs)
    227             for mgr in ctx_mgrs: stack.enter_context(mgr)
    228             self(event.begin_epoch if inner else _before_epoch)
--> 229             self._do_epoch_validate(dl=dl)
    230             self(event.after_epoch if inner else _after_epoch)
    231             if act is None: act = getattr(self.loss_func, 'activation', noop)

~/development/projects/blurr/_libs/fastai2/fastai2/learner.py in _do_epoch_validate(self, ds_idx, dl)
    183             with torch.no_grad(): self.all_batches()
    184         except CancelValidException:                         self('after_cancel_validate')
--> 185         finally:                                             self('after_validate')
    187     @log_args(but='cbs')

~/development/projects/blurr/_libs/fastai2/fastai2/learner.py in __call__(self, event_name)
    132     def ordered_cbs(self, event): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, event)]
--> 134     def __call__(self, event_name): L(event_name).map(self._call_one)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)

~/development/projects/blurr/_libs/fastcore/fastcore/foundation.py in map(self, f, *args, **kwargs)
    374              else f.format if isinstance(f,str)
    375              else f.__getitem__)
--> 376         return self._new(map(g, self))
    378     def filter(self, f, negate=False, **kwargs):

~/development/projects/blurr/_libs/fastcore/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    325     @property
    326     def _xtra(self): return None
--> 327     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    328     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    329     def copy(self): return self._new(self.items.copy())

~/development/projects/blurr/_libs/fastcore/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
     45             return x
---> 47         res = super().__call__(*((x,) + args), **kwargs)
     48         res._newchk = 0
     49         return res

~/development/projects/blurr/_libs/fastcore/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    316         if items is None: items = []
    317         if (use_list is not None) or not _is_array(items):
--> 318             items = list(items) if use_list else _listify(items)
    319         if match is not None:
    320             if is_coll(match): match = len(match)

~/development/projects/blurr/_libs/fastcore/fastcore/foundation.py in _listify(o)
    252     if isinstance(o, list): return o
    253     if isinstance(o, str) or _is_array(o): return [o]
--> 254     if is_iter(o): return list(o)
    255     return [o]

~/development/projects/blurr/_libs/fastcore/fastcore/foundation.py in __call__(self, *args, **kwargs)
    218             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    219         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 220         return self.fn(*fargs, **kwargs)
    222 # Cell

~/development/projects/blurr/_libs/fastai2/fastai2/learner.py in _call_one(self, event_name)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    139     def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)

~/development/projects/blurr/_libs/fastai2/fastai2/learner.py in <listcomp>(.0)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    139     def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)

~/development/projects/blurr/_libs/fastai2/fastai2/callback/core.py in __call__(self, event_name)
     22         _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
     23                (self.run_valid and not getattr(self, 'training', False)))
---> 24         if self.run and _run: getattr(self, event_name, noop)()
     25         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit

~/development/projects/blurr/_libs/fastai2/fastai2/callback/core.py in after_validate(self)
     93     def after_validate(self):
     94         "Concatenate all recorded tensors"
---> 95         if self.with_input:     self.inputs  = detuplify(to_concat(self.inputs, dim=self.concat_dim))
     96         if not self.save_preds: self.preds   = detuplify(to_concat(self.preds, dim=self.concat_dim))
     97         if not self.save_targs: self.targets = detuplify(to_concat(self.targets, dim=self.concat_dim))

~/development/projects/blurr/_libs/fastai2/fastai2/torch_core.py in to_concat(xs, dim)
    211 def to_concat(xs, dim=0):
    212     "Concat the element in `xs` (recursively if they are tuples/lists of tensors)"
--> 213     if is_listy(xs[0]): return type(xs[0])([to_concat([x[i] for x in xs], dim=dim) for i in range_of(xs[0])])
    214     if isinstance(xs[0],dict):  return {k: to_concat([x[k] for x in xs], dim=dim) for k in xs[0].keys()}
    215     #We may receives xs that are not concatenatable (inputs of a text classifier for instance),

~/development/projects/blurr/_libs/fastai2/fastai2/torch_core.py in <listcomp>(.0)
    211 def to_concat(xs, dim=0):
    212     "Concat the element in `xs` (recursively if they are tuples/lists of tensors)"
--> 213     if is_listy(xs[0]): return type(xs[0])([to_concat([x[i] for x in xs], dim=dim) for i in range_of(xs[0])])
    214     if isinstance(xs[0],dict):  return {k: to_concat([x[k] for x in xs], dim=dim) for k in xs[0].keys()}
    215     #We may receives xs that are not concatenatable (inputs of a text classifier for instance),

~/development/projects/blurr/_libs/fastai2/fastai2/torch_core.py in to_concat(xs, dim)
    217     try:    return retain_type(torch.cat(xs, dim=dim), xs[0])
    218     except: return sum([L(retain_type(o_.index_select(dim, tensor(i)).squeeze(dim), xs[0])
--> 219                           for i in range_of(o_)) for o_ in xs], L())
    221 # Cell

~/development/projects/blurr/_libs/fastai2/fastai2/torch_core.py in <listcomp>(.0)
    217     try:    return retain_type(torch.cat(xs, dim=dim), xs[0])
    218     except: return sum([L(retain_type(o_.index_select(dim, tensor(i)).squeeze(dim), xs[0])
--> 219                           for i in range_of(o_)) for o_ in xs], L())
    221 # Cell

~/development/projects/blurr/_libs/fastcore/fastcore/utils.py in range_of(x)
    170 def range_of(x):
    171     "All indices of collection `x` (i.e. `list(range(len(x)))`)"
--> 172     return list(range(len(x)))
    174 # Cell

TypeError: object of type 'NoneType' has no len()

Is there something I need to do to make this work? Or does something need to be updated with the library to support this use case?

Thanks - wg

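I don’t see how to adapt predict to make this work in this case. fastai is not designed to leave None arguments alone: when collecting inputs, `to_concat` tries to concatenate every element (via `torch.cat`, then `len`), so a None input fails. Supporting this would require quite a bit of work.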