ShowGraphCallback and learn.predict do not work together

The code:

learner = language_model_learner(dls, AWD_LSTM, cbs=[ShowGraphCallback])
learner.predict("this movie was", 10, temperature=0.75) 

throws the following error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-44-9c6dc26dc508> in <module>
----> 1 learner.predict("this movie was", 10, temperature=0.75)

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastai/text/learner.py in predict(self, text, n_words, no_unk, temperature, min_p, no_bar, decoder, only_last_word)
    161         if no_unk: unk_idx = self.dls.vocab.index(UNK)
    162         for _ in (range(n_words) if no_bar else progress_bar(range(n_words), leave=False)):
--> 163             with self.no_bar(): preds,_ = self.get_preds(dl=[(idxs[None],)])
    164             res = preds[0][-1]
    165             if no_unk: res[unk_idx] = 0.

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastai/text/learner.py in get_preds(self, concat_dim, **kwargs)
    179 
    180     @delegates(Learner.get_preds)
--> 181     def get_preds(self, concat_dim=1, **kwargs): return super().get_preds(concat_dim=1, **kwargs)
    182 
    183 # Cell

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastai/learner.py in get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)
    232         ctx_mgrs = self.validation_context(cbs=L(cbs)+[cb], inner=inner)
    233         if with_loss: ctx_mgrs.append(self.loss_not_reduced())
--> 234         with ContextManagers(ctx_mgrs):
    235             self._do_epoch_validate(dl=dl)
    236             if act is None: act = getattr(self.loss_func, 'activation', noop)

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/utils.py in __enter__(self)
    645     "Wrapper for `contextlib.ExitStack` which enters a collection of context managers"
    646     def __init__(self, mgrs): self.default,self.stack = L(mgrs),ExitStack()
--> 647     def __enter__(self): self.default.map(self.stack.enter_context)
    648     def __exit__(self, *args, **kwargs): self.stack.__exit__(*args, **kwargs)
    649 

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in map(self, f, *args, **kwargs)
    270              else f.format if isinstance(f,str)
    271              else f.__getitem__)
--> 272         return self._new(map(g, self))
    273 
    274     def filter(self, f, negate=False, **kwargs):

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    216     @property
    217     def _xtra(self): return None
--> 218     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    219     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    220     def copy(self): return self._new(self.items.copy())

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
    197     def __call__(cls, x=None, *args, **kwargs):
    198         if not args and not kwargs and x is not None and isinstance(x,cls): return x
--> 199         return super().__call__(x, *args, **kwargs)
    200 
    201 # Cell

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    207         if items is None: items = []
    208         if (use_list is not None) or not _is_array(items):
--> 209             items = list(items) if use_list else _listify(items)
    210         if match is not None:
    211             if is_coll(match): match = len(match)

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in _listify(o)
    114     if isinstance(o, list): return o
    115     if isinstance(o, str) or _is_array(o): return [o]
--> 116     if is_iter(o): return list(o)
    117     return [o]
    118 

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in __call__(self, *args, **kwargs)
    177             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    178         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 179         return self.fn(*fargs, **kwargs)
    180 
    181 # Cell

~/miniconda3/envs/fastai2/lib/python3.7/contextlib.py in enter_context(self, cm)
    425         _cm_type = type(cm)
    426         _exit = _cm_type.__exit__
--> 427         result = _cm_type.__enter__(cm)
    428         self._push_cm_exit(cm, _exit)
    429         return result

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/utils.py in __enter__(self)
    645     "Wrapper for `contextlib.ExitStack` which enters a collection of context managers"
    646     def __init__(self, mgrs): self.default,self.stack = L(mgrs),ExitStack()
--> 647     def __enter__(self): self.default.map(self.stack.enter_context)
    648     def __exit__(self, *args, **kwargs): self.stack.__exit__(*args, **kwargs)
    649 

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in map(self, f, *args, **kwargs)
    270              else f.format if isinstance(f,str)
    271              else f.__getitem__)
--> 272         return self._new(map(g, self))
    273 
    274     def filter(self, f, negate=False, **kwargs):

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    216     @property
    217     def _xtra(self): return None
--> 218     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    219     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    220     def copy(self): return self._new(self.items.copy())

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
    197     def __call__(cls, x=None, *args, **kwargs):
    198         if not args and not kwargs and x is not None and isinstance(x,cls): return x
--> 199         return super().__call__(x, *args, **kwargs)
    200 
    201 # Cell

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    207         if items is None: items = []
    208         if (use_list is not None) or not _is_array(items):
--> 209             items = list(items) if use_list else _listify(items)
    210         if match is not None:
    211             if is_coll(match): match = len(match)

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in _listify(o)
    114     if isinstance(o, list): return o
    115     if isinstance(o, str) or _is_array(o): return [o]
--> 116     if is_iter(o): return list(o)
    117     return [o]
    118 

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in __call__(self, *args, **kwargs)
    177             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    178         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 179         return self.fn(*fargs, **kwargs)
    180 
    181 # Cell

~/miniconda3/envs/fastai2/lib/python3.7/contextlib.py in enter_context(self, cm)
    425         _cm_type = type(cm)
    426         _exit = _cm_type.__exit__
--> 427         result = _cm_type.__enter__(cm)
    428         self._push_cm_exit(cm, _exit)
    429         return result

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastai/learner.py in __enter__(self)
    208 
    209     def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None
--> 210     def __enter__(self): self(_before_epoch); return self
    211     def __exit__(self, exc_type, exc_value, tb): self(_after_epoch)
    212 

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastai/learner.py in __call__(self, event_name)
    131     def ordered_cbs(self, event): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, event)]
    132 
--> 133     def __call__(self, event_name): L(event_name).map(self._call_one)
    134 
    135     def _call_one(self, event_name):

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in map(self, f, *args, **kwargs)
    270              else f.format if isinstance(f,str)
    271              else f.__getitem__)
--> 272         return self._new(map(g, self))
    273 
    274     def filter(self, f, negate=False, **kwargs):

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    216     @property
    217     def _xtra(self): return None
--> 218     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    219     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    220     def copy(self): return self._new(self.items.copy())

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
    197     def __call__(cls, x=None, *args, **kwargs):
    198         if not args and not kwargs and x is not None and isinstance(x,cls): return x
--> 199         return super().__call__(x, *args, **kwargs)
    200 
    201 # Cell

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    207         if items is None: items = []
    208         if (use_list is not None) or not _is_array(items):
--> 209             items = list(items) if use_list else _listify(items)
    210         if match is not None:
    211             if is_coll(match): match = len(match)

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in _listify(o)
    114     if isinstance(o, list): return o
    115     if isinstance(o, str) or _is_array(o): return [o]
--> 116     if is_iter(o): return list(o)
    117     return [o]
    118 

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastcore/foundation.py in __call__(self, *args, **kwargs)
    177             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    178         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 179         return self.fn(*fargs, **kwargs)
    180 
    181 # Cell

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastai/learner.py in _call_one(self, event_name)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name), event_name
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastai/learner.py in <listcomp>(.0)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name), event_name
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastai/callback/core.py in __call__(self, event_name)
     42                (self.run_valid and not getattr(self, 'training', False)))
     43         res = None
---> 44         if self.run and _run: res = getattr(self, event_name, noop)()
     45         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     46         return res

~/miniconda3/envs/fastai2/lib/python3.7/site-packages/fastai/callback/progress.py in before_fit(self)
     75         self.run = not hasattr(self.learn, 'lr_finder') and not hasattr(self, "gather_preds")
     76         self.nb_batches = []
---> 77         assert hasattr(self.learn, 'progress')
     78 
     79     def after_train(self): self.nb_batches.append(self.train_iter)

AssertionError: 

Does anyone know whether this error is intended or a bug?

This is a bug — thanks for finding it. I’ve put in a fix, which will hopefully appear in the next release.

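In the meantime, @muellerzr came up with a workaround: don’t pass ShowGraphCallback when you construct the learner. Instead, pass it through the `cbs` argument of the learner’s `.fit` method, so it is only active during training (e.g. `learner.fit_one_cycle(1, cbs=ShowGraphCallback())`). `learner.predict` then works as usual.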

See this demo notebook for how to do that.

2 Likes