Is there any way to call learn.get_preds() without triggering any of the callbacks?

I know I can do something like this …

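import torch
learn.model.eval()  # switch off dropout etc. for inference
with torch.no_grad():
    # assumes each batch is (inputs, ...); keep every batch's output, not just the last one
    preds = [learn.model(xb) for xb, *_ in dl]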

… but would like to use learn.get_preds() if there is way to do so.

Why do I want to do this?

Because I need to get the predictions on the validation set inside a Callback.after_validate method to compute a key value that several metrics use. Right now, calling learn.get_preds() there results in a recursive loop.

If you want to call get_preds inside a Callback, you need to use the FetchPreds callback, which is designed for that (and allows you to disable some callbacks). I’m not too sure what your use case is, though.
Note that there is also a learn.removed_cbs() context manager that can be useful.

The problem is that FetchPreds runs in after_validate.

I’m trying to get the optimal threshold for a multi-label model: a value to be used in several metrics associated with my learner. I need to do this in the begin_validate method. This is what I tried in my Callback:

def begin_validate(self, **kwargs):
    with self.learn.removed_cbs(self) as learn:
        self.preds_targs = learn.get_preds(inner=True)
    probs, targs = self.preds_targs
    ...

… and here’s the fun stack trace:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    191                         self._do_epoch_train()
--> 192                         self._do_epoch_validate()
    193                     except CancelEpochException:   self('after_cancel_epoch')

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in _do_epoch_validate(self, ds_idx, dl)
    172             dl,old,has = change_attrs(dl, names, [False,False])
--> 173             self.dl = dl;                                    self('begin_validate')
    174             with torch.no_grad(): self.all_batches()

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in __call__(self, event_name)
    122 
--> 123     def __call__(self, event_name): L(event_name).map(self._call_one)
    124     def _call_one(self, event_name):

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in map(self, f, *args, **kwargs)
    361              else f.__getitem__)
--> 362         return self._new(map(g, self))
    363 

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    314     def _xtra(self): return None
--> 315     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    316     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
     40 
---> 41         res = super().__call__(*((x,) + args), **kwargs)
     42         res._newchk = 0

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    305         if (use_list is not None) or not _is_array(items):
--> 306             items = list(items) if use_list else _listify(items)
    307         if match is not None:

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in _listify(o)
    241     if isinstance(o, str) or _is_array(o): return [o]
--> 242     if is_iter(o): return list(o)
    243     return [o]

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __call__(self, *args, **kwargs)
    207         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 208         return self.fn(*fargs, **kwargs)
    209 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in _call_one(self, event_name)
    125         assert hasattr(event, event_name)
--> 126         [cb(event_name) for cb in sort_by_run(self.cbs)]
    127 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in <listcomp>(.0)
    125         assert hasattr(event, event_name)
--> 126         [cb(event_name) for cb in sort_by_run(self.cbs)]
    127 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/callback/core.py in __call__(self, event_name)
     22                (self.run_valid and not getattr(self, 'training', False)))
---> 23         if self.run and _run: getattr(self, event_name, noop)()
     24         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit

<ipython-input-238-86f2413c40ea> in begin_validate(self, **kwargs)
     46         #self.probs, self.targs = torch.cat(self.probs, dim=0), torch.cat(self.targs, dim=0)
---> 47         self.loss_func.thresh = self.opt_th()
     48         self.opt_thresh = self.loss_func.thresh

<ipython-input-238-86f2413c40ea> in opt_th(self)
     62         pdb.set_trace()
---> 63         with self.learn.removed_cbs(self) as learn:
     64             self.preds_targs = learn.get_preds(inner=True)

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, **kwargs)
    217             self(event.begin_epoch if inner else _before_epoch)
--> 218             self._do_epoch_validate(dl=dl)
    219             self(event.after_epoch if inner else _after_epoch)

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in _do_epoch_validate(self, ds_idx, dl)
    176         finally:
--> 177             dl,*_ = change_attrs(dl, names, old, has);       self('after_validate')
    178 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in __call__(self, event_name)
    122 
--> 123     def __call__(self, event_name): L(event_name).map(self._call_one)
    124     def _call_one(self, event_name):

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in map(self, f, *args, **kwargs)
    361              else f.__getitem__)
--> 362         return self._new(map(g, self))
    363 

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    314     def _xtra(self): return None
--> 315     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    316     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
     40 
---> 41         res = super().__call__(*((x,) + args), **kwargs)
     42         res._newchk = 0

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    305         if (use_list is not None) or not _is_array(items):
--> 306             items = list(items) if use_list else _listify(items)
    307         if match is not None:

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in _listify(o)
    241     if isinstance(o, str) or _is_array(o): return [o]
--> 242     if is_iter(o): return list(o)
    243     return [o]

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __call__(self, *args, **kwargs)
    207         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 208         return self.fn(*fargs, **kwargs)
    209 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in _call_one(self, event_name)
    125         assert hasattr(event, event_name)
--> 126         [cb(event_name) for cb in sort_by_run(self.cbs)]
    127 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in <listcomp>(.0)
    125         assert hasattr(event, event_name)
--> 126         [cb(event_name) for cb in sort_by_run(self.cbs)]
    127 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/callback/core.py in __call__(self, event_name)
     22                (self.run_valid and not getattr(self, 'training', False)))
---> 23         if self.run and _run: getattr(self, event_name, noop)()
     24         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in after_validate(self)
    470             self.preds = learn.get_preds(ds_idx=self.ds_idx, dl=self.dl,
--> 471                 with_input=self.with_input, with_decoded=self.with_decoded, inner=True)
    472 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, **kwargs)
    218             self._do_epoch_validate(dl=dl)
--> 219             self(event.after_epoch if inner else _after_epoch)
    220             if act is None: act = getattr(self.loss_func, 'activation', noop)

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in __call__(self, event_name)
    122 
--> 123     def __call__(self, event_name): L(event_name).map(self._call_one)
    124     def _call_one(self, event_name):

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in map(self, f, *args, **kwargs)
    361              else f.__getitem__)
--> 362         return self._new(map(g, self))
    363 

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    314     def _xtra(self): return None
--> 315     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    316     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
     40 
---> 41         res = super().__call__(*((x,) + args), **kwargs)
     42         res._newchk = 0

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    305         if (use_list is not None) or not _is_array(items):
--> 306             items = list(items) if use_list else _listify(items)
    307         if match is not None:

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in _listify(o)
    241     if isinstance(o, str) or _is_array(o): return [o]
--> 242     if is_iter(o): return list(o)
    243     return [o]

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __call__(self, *args, **kwargs)
    207         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 208         return self.fn(*fargs, **kwargs)
    209 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in _call_one(self, event_name)
    125         assert hasattr(event, event_name)
--> 126         [cb(event_name) for cb in sort_by_run(self.cbs)]
    127 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in <listcomp>(.0)
    125         assert hasattr(event, event_name)
--> 126         [cb(event_name) for cb in sort_by_run(self.cbs)]
    127 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/callback/core.py in __call__(self, event_name)
     22                (self.run_valid and not getattr(self, 'training', False)))
---> 23         if self.run and _run: getattr(self, event_name, noop)()
     24         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/callback/tracker.py in after_epoch(self)
     86         else: #every improvement
---> 87             super().after_epoch()
     88             if self.new_best: self._save(f'{self.fname}')

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/callback/tracker.py in after_epoch(self)
     46         "Compare the last value to the best up to know"
---> 47         val = self.recorder.values[-1][self.idx]
     48         if self.comp(val - self.min_delta, self.best): self.best,self.new_best = val,True

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __getattr__(self, k)
    221             attr = getattr(self,self._default,None)
--> 222             if attr is not None: return getattr(attr,k)
    223         raise AttributeError(k)

AttributeError: 'TextLearner' object has no attribute 'recorder'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    193                     except CancelEpochException:   self('after_cancel_epoch')
--> 194                     finally:                       self('after_epoch')
    195 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in __call__(self, event_name)
    122 
--> 123     def __call__(self, event_name): L(event_name).map(self._call_one)
    124     def _call_one(self, event_name):

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in map(self, f, *args, **kwargs)
    361              else f.__getitem__)
--> 362         return self._new(map(g, self))
    363 

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    314     def _xtra(self): return None
--> 315     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    316     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
     40 
---> 41         res = super().__call__(*((x,) + args), **kwargs)
     42         res._newchk = 0

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    305         if (use_list is not None) or not _is_array(items):
--> 306             items = list(items) if use_list else _listify(items)
    307         if match is not None:

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in _listify(o)
    241     if isinstance(o, str) or _is_array(o): return [o]
--> 242     if is_iter(o): return list(o)
    243     return [o]

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __call__(self, *args, **kwargs)
    207         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 208         return self.fn(*fargs, **kwargs)
    209 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in _call_one(self, event_name)
    125         assert hasattr(event, event_name)
--> 126         [cb(event_name) for cb in sort_by_run(self.cbs)]
    127 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in <listcomp>(.0)
    125         assert hasattr(event, event_name)
--> 126         [cb(event_name) for cb in sort_by_run(self.cbs)]
    127 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/callback/core.py in __call__(self, event_name)
     22                (self.run_valid and not getattr(self, 'training', False)))
---> 23         if self.run and _run: getattr(self, event_name, noop)()
     24         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/callback/tracker.py in after_epoch(self)
     86         else: #every improvement
---> 87             super().after_epoch()
     88             if self.new_best: self._save(f'{self.fname}')

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/callback/tracker.py in after_epoch(self)
     46         "Compare the last value to the best up to know"
---> 47         val = self.recorder.values[-1][self.idx]
     48         if self.comp(val - self.min_delta, self.best): self.best,self.new_best = val,True

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __getattr__(self, k)
    221             attr = getattr(self,self._default,None)
--> 222             if attr is not None: return getattr(attr,k)
    223         raise AttributeError(k)

AttributeError: 'TextLearner' object has no attribute 'recorder'

During handling of the above exception, another exception occurred:

FileNotFoundError                         Traceback (most recent call last)
<ipython-input-248-e44fcd197338> in <module>
----> 1 learn.fit_one_cycle(1, lr_max=lr)

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt)
    110     scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
    111               'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 112     self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
    113 
    114 # Cell

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    195 
    196             except CancelFitException:             self('after_cancel_fit')
--> 197             finally:                               self('after_fit')
    198 
    199     def validate(self, ds_idx=1, dl=None, cbs=None):

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in __call__(self, event_name)
    121     def ordered_cbs(self, cb_func): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, cb_func)]
    122 
--> 123     def __call__(self, event_name): L(event_name).map(self._call_one)
    124     def _call_one(self, event_name):
    125         assert hasattr(event, event_name)

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in map(self, f, *args, **kwargs)
    360              else f.format if isinstance(f,str)
    361              else f.__getitem__)
--> 362         return self._new(map(g, self))
    363 
    364     def filter(self, f, negate=False, **kwargs):

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    313     @property
    314     def _xtra(self): return None
--> 315     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    316     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    317     def copy(self): return self._new(self.items.copy())

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
     39             return x
     40 
---> 41         res = super().__call__(*((x,) + args), **kwargs)
     42         res._newchk = 0
     43         return res

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    304         if items is None: items = []
    305         if (use_list is not None) or not _is_array(items):
--> 306             items = list(items) if use_list else _listify(items)
    307         if match is not None:
    308             if is_coll(match): match = len(match)

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in _listify(o)
    240     if isinstance(o, list): return o
    241     if isinstance(o, str) or _is_array(o): return [o]
--> 242     if is_iter(o): return list(o)
    243     return [o]
    244 

~/development/_training/ml/nlp-playground/_libs/fastcore/fastcore/foundation.py in __call__(self, *args, **kwargs)
    206             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    207         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 208         return self.fn(*fargs, **kwargs)
    209 
    210 # Cell

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in _call_one(self, event_name)
    124     def _call_one(self, event_name):
    125         assert hasattr(event, event_name)
--> 126         [cb(event_name) for cb in sort_by_run(self.cbs)]
    127 
    128     def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in <listcomp>(.0)
    124     def _call_one(self, event_name):
    125         assert hasattr(event, event_name)
--> 126         [cb(event_name) for cb in sort_by_run(self.cbs)]
    127 
    128     def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/callback/core.py in __call__(self, event_name)
     21         _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
     22                (self.run_valid and not getattr(self, 'training', False)))
---> 23         if self.run and _run: getattr(self, event_name, noop)()
     24         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     25 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/callback/tracker.py in after_fit(self, **kwargs)
     90     def after_fit(self, **kwargs):
     91         "Load the best model."
---> 92         if not self.every_epoch: self.learn.load(f'{self.fname}')
     93 
     94 # Cell

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in load(self, file, with_opt, device, strict)
    269         distrib_barrier()
    270         file = join_path_file(file, self.path/self.model_dir, ext='.pth')
--> 271         load_model(file, self.model, self.opt, with_opt=with_opt, device=device, strict=strict)
    272         return self
    273 

~/development/_training/ml/nlp-playground/_libs/fastai2/fastai2/learner.py in load_model(file, model, opt, with_opt, device, strict)
     48     if isinstance(device, int): device = torch.device('cuda', device)
     49     elif device is None: device = 'cpu'
---> 50     state = torch.load(file, map_location=device)
     51     hasopt = set(state)=={'model', 'opt'}
     52     model_state = state['model'] if hasopt else state

~/anaconda3/envs/playground-nlp/lib/python3.7/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    523         pickle_load_args['encoding'] = 'utf-8'
    524 
--> 525     with _open_file_like(f, 'rb') as opened_file:
    526         if _is_zipfile(opened_file):
    527             with _open_zipfile_reader(f) as opened_zipfile:

~/anaconda3/envs/playground-nlp/lib/python3.7/site-packages/torch/serialization.py in _open_file_like(name_or_buffer, mode)
    210 def _open_file_like(name_or_buffer, mode):
    211     if _is_path(name_or_buffer):
--> 212         return _open_file(name_or_buffer, mode)
    213     else:
    214         if 'w' in mode:

~/anaconda3/envs/playground-nlp/lib/python3.7/site-packages/torch/serialization.py in __init__(self, name, mode)
    191 class _open_file(_opener):
    192     def __init__(self, name, mode):
--> 193         super(_open_file, self).__init__(open(name, mode))
    194 
    195     def __exit__(self, *args):

I should mention that this Callback is set to run before the Recorder so I can update the threshold in the metrics that have a thresh attribute.
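The thread never shows the body of opt_th(). As an illustration of what "optimal threshold for a multi-label model" can mean, here is a minimal pure-Python sketch that grid-searches the threshold maximizing micro-averaged F-beta over validation predictions. The helper names, the 0.01 to 0.99 grid, beta=2 and micro-averaging are all assumptions for illustration, not the poster's code; in a real callback the inputs would be the sigmoid probabilities and the targets from the validation set.

```python
def fbeta_micro(probs, targs, thresh, beta=2.0):
    """Micro-averaged F-beta for multi-label data (hypothetical helper).

    probs: rows of probabilities; targs: rows of 0/1 labels.
    """
    tp = fp = fn = 0
    for p_row, t_row in zip(probs, targs):
        for p, t in zip(p_row, t_row):
            pred = p >= thresh
            if pred and t:
                tp += 1
            elif pred and not t:
                fp += 1
            elif t:
                fn += 1
    b2 = beta * beta
    denom = (1 + b2) * tp + b2 * fn + fp
    return (1 + b2) * tp / denom if denom else 0.0


def optimal_threshold(probs, targs, beta=2.0, steps=99):
    """Try thresholds 0.01 .. 0.99 and return (best_thresh, best_score)."""
    best_t, best_s = 0.5, -1.0
    for i in range(1, steps + 1):
        t = i / (steps + 1)
        s = fbeta_micro(probs, targs, t, beta)
        if s > best_s:  # strict '>' keeps the lowest threshold among ties
            best_t, best_s = t, s
    return best_t, best_s


# Toy data: any threshold in (0.4, 0.6] separates it perfectly.
probs = [[0.9, 0.2, 0.6], [0.1, 0.8, 0.4]]
targs = [[1, 0, 1], [0, 1, 0]]
best_t, best_s = optimal_threshold(probs, targs)
print(best_s)  # prints 1.0
```

The chosen value could then be written to the thresh attribute of each metric that has one, which is what the poster describes doing.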

I’ve also noticed some behavior that seems weird (at least to me). If I change the code to:

def begin_validate(self, **kwargs):
    with self.learn.removed_cbs(self.learn.cbs) as learn:
        self.preds_targs = learn.get_preds(inner=True)
    probs, targs = self.preds_targs

so that it removes all of the learner’s callbacks, here is what I get before the with block executes:

self.learn.cbs
# returns [TrainEvalCallback,Recorder,ProgressCallback,SaveModelCallback,FetchPreds,OptimizeFBetaThreshCallback,ModelReseter,RNNRegularizer,ParamScheduler]

and then, inside the with block, I expect learn to have no callbacks … BUT it does:

learn.cbs
# returns [Recorder,SaveModelCallback,OptimizeFBetaThreshCallback,RNNRegularizer]

I’m obviously missing something.
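One likely explanation, inferred from the output rather than confirmed in the thread: the code passes self.learn.cbs itself, so if removed_cbs loops over its argument while removing each callback from learn.cbs (my reading of the fastai2 code of that era), it is mutating the very list it iterates over. Removing items from a Python list while looping over it skips every other element, and that matches the output exactly: the callbacks at positions 0, 2, 4, 6 and 8 are removed and the rest survive. The plain-Python sketch below reproduces this with strings standing in for the callbacks; remove_each is a hypothetical helper.

```python
def remove_each(items, to_remove):
    """Remove every element of to_remove from items, one at a time."""
    for x in to_remove:   # if to_remove IS items, this loop skips elements
        items.remove(x)
    return items


cbs = ["TrainEvalCallback", "Recorder", "ProgressCallback",
       "SaveModelCallback", "FetchPreds", "OptimizeFBetaThreshCallback",
       "ModelReseter", "RNNRegularizer", "ParamScheduler"]

buggy = list(cbs)
remove_each(buggy, buggy)          # iterate over the list being mutated
print(buggy)
# prints ['Recorder', 'SaveModelCallback', 'OptimizeFBetaThreshCallback', 'RNNRegularizer']

fixed = list(cbs)
remove_each(fixed, list(fixed))    # iterate over a copy instead
print(fixed)
# prints []
```

Under that assumption, passing a copy, e.g. removed_cbs(list(self.learn.cbs)), should remove everything.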

SOLVED:

Well, this was a fun one.

The problem was that I was using SaveModelCallback which derives from TrackerCallback … and TrackerCallback attempts to access self.recorder in both of the lifecycle events it includes.

So, long story short: if you’re going to remove the Recorder callback, you probably want to remove all TrackerCallback instances as well.
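A minimal sketch of that advice: collect the Recorder and every TrackerCallback instance into a new list, then remove only those. To keep it runnable without fastai, the classes are passed in as arguments and the stand-in classes below are empty placeholders; with fastai you would pass the real Recorder and TrackerCallback (the base class of SaveModelCallback). The helper name recorder_dependents is hypothetical.

```python
def recorder_dependents(cbs, recorder_cls, tracker_cls):
    """Return a NEW list holding the Recorder and all tracker callbacks in cbs.

    Building a separate list avoids mutating the learner's callback list
    while something else is iterating over it.
    """
    return [cb for cb in cbs if isinstance(cb, (recorder_cls, tracker_cls))]


# Empty stand-ins so the sketch runs without fastai installed.
class Recorder: pass
class TrackerCallback: pass
class SaveModelCallback(TrackerCallback): pass
class ProgressCallback: pass


cbs = [Recorder(), ProgressCallback(), SaveModelCallback()]
to_remove = recorder_dependents(cbs, Recorder, TrackerCallback)
print([type(cb).__name__ for cb in to_remove])
# prints ['Recorder', 'SaveModelCallback']
```

Inside the callback, the result (plus the callback itself, as in the original attempt) would then be handed to self.learn.removed_cbs(...); I have not checked this against a specific fastai version.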


Thanks for sharing. I’ll probably need to do the same in FetchPreds for wandb.

If you don’t mind, I’ll also submit some PRs today or tomorrow with a new Metric subclass, an improved Callback class, and my optimized threshold callback. I think they are framework-worthy.

Hi @wgpubs,
Could you post the working code that solved it? I’m also struggling with a custom callback that gets predictions during training.


If your callback is only used for training, don’t add it to the Learner; include it in your call(s) to fit or fit_one_cycle instead. Since the callback is then no longer attached to your Learner, it won’t interfere with your calls to get_preds().
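To see why this helps, here is a toy model of the lifecycle in plain Python (not fastai’s code): callbacks passed to fit are attached only for the duration of that call and detached afterwards, so a later get_preds never sees them. As far as I know, fastai v2’s Learner.fit does something similar via an added_cbs context manager, but the ToyLearner class and its method bodies here are simplified, hypothetical stand-ins.

```python
from contextlib import contextmanager


class ToyLearner:
    """Simplified stand-in for a Learner; only records which callbacks fire."""

    def __init__(self, cbs=()):
        self.cbs = list(cbs)
        self.log = []

    @contextmanager
    def added_cbs(self, cbs):
        # attach callbacks only for the duration of the with block
        self.cbs.extend(cbs)
        try:
            yield self
        finally:
            for cb in cbs:  # cbs is a separate list, so removal is safe here
                self.cbs.remove(cb)

    def _event(self, name):
        for cb in self.cbs:
            self.log.append(f"{cb}:{name}")

    def fit(self, cbs=()):
        with self.added_cbs(list(cbs)):
            self._event("after_validate")

    def get_preds(self):
        self._event("after_validate")


learn = ToyLearner(cbs=["Recorder"])
learn.fit(cbs=["ThreshCallback"])   # the fit-scoped callback fires here...
learn.log.clear()
learn.get_preds()                   # ...but not here
print(learn.log)
# prints ['Recorder:after_validate']
```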

In terms of examples, you can check out one of the custom callbacks I include in my blurr library (see here). I’m not sure what you’re doing specifically, but I think this example callback gives a taste of what you can do without getting in the way of the fastai bits.

Lmk how it goes.


Now it’s working! But I didn’t use get_preds(). To make it work, I had to go with:
with torch.no_grad(): preds = learn.model(xb)

The training was OK, but now I have another problem, haha: I can’t export the model. It keeps giving me:

TypeError: cannot serialize '_io.TextIOWrapper' object

Maybe it’s because of CSVLogger?
And that’s not all: when I try to make a prediction after training, my custom callback kicks in! I’ve never seen that happen before.
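The CSVLogger guess seems plausible. Exporting a Learner pickles it, and Python cannot pickle file objects, so any callback still attached to the Learner that keeps a file handle as an attribute would trigger this kind of TypeError (the exact wording of the message depends on the Python version). The sketch below reproduces the error in plain Python with a hypothetical FileLoggingCallback; whether CSVLogger holds its handle in exactly this way depends on the fastai version.

```python
import pickle
import tempfile


class FileLoggingCallback:
    """Hypothetical callback that keeps an open file handle as an attribute."""

    def __init__(self, path):
        self.file = open(path, "w")


def can_pickle(obj):
    """Return True if pickle.dumps(obj) succeeds, False on TypeError."""
    try:
        pickle.dumps(obj)
        return True
    except TypeError:
        return False


with tempfile.TemporaryDirectory() as tmp:
    cb = FileLoggingCallback(f"{tmp}/log.csv")
    before = can_pickle(cb)   # False: the file object can't be pickled
    cb.file.close()
    del cb.file               # dropping the attribute makes cb picklable
    after = can_pickle(cb)    # True

print(before, after)
# prints False True
```

This is also consistent with the fix that worked later in the thread: callbacks passed only to fit are no longer attached to the Learner when it is exported.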

Are you passing your callback to the Learner or to fit? If you don’t want it attached permanently, pass the callback only while fitting.


Oh, let me try that!
Should I pass all of my callbacks to fit?

Edit: it seems to be working.
There are no errors during export, and my custom callback stays silent. I can also load the model back for inference. Thank you @muellerzr!


What is the “xb” in your code? Would you mind posting a link to the code that finally worked?

xb = x batch. When working in callbacks, you have both self.x and self.xb. You really want to modify self.xb; self.x is there for convenience.
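To make the naming concrete: as far as I recall, fastai v2 splits each batch into a tuple of inputs (xb) and a tuple of targets (yb) based on the number of inputs (n_inp), calls the model as model(*xb), and exposes x as the single input when there is only one. The plain-Python sketch below mimics that split; split_batch and fake_model are hypothetical, and the details are an approximation of fastai’s internals rather than its code. It also shows why learn.model(xb) only works when xb is a single tensor rather than a tuple.

```python
def split_batch(batch, n_inp=1):
    """Split a batch into (xb, yb) tuples; hypothetical helper."""
    batch = tuple(batch)
    return batch[:n_inp], batch[n_inp:]


def fake_model(*inputs):
    """Stand-in model that just reports how many inputs it received."""
    return len(inputs)


xb, yb = split_batch(("tokens", "labels"))
x = xb[0] if len(xb) == 1 else xb   # single-input convenience view of xb
print(xb, yb, x)
# prints ('tokens',) ('labels',) tokens
print(fake_model(*xb))              # unpack xb when calling the model
# prints 1
```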
