Callbacks for prediction step?

I am trying to use a pretrained model to analyze activations for some data; in other words, I am not training anything, only predicting from an existing model.

I create the data, then the learner, and then call learn.get_preds(data.train_ds), which correctly returns predictions.

What I need is a callback to collect the activations of a given layer. I looked through the docs, but my impression is that callbacks are only invoked during training with learn.fit(). Could you point me in the right direction?

Update 1: If callbacks are only available during training, could I set the learning rate to zero and “train” with callbacks to collect activations? In effect there would be no training, because the weights would never be updated.
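Something like this is what I have in mind (just a sketch, assuming fastai v1; my_activation_callback stands in for a hypothetical callback that records activations):

# Sketch of the lr=0 idea: with a learning rate of zero the optimizer step is
# a no-op, so the weights stay fixed while callbacks still fire on every batch.
# Caveat: BatchNorm running statistics would still update in training mode.
learn.fit(1, lr=0, callbacks=[my_activation_callback])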

Update 2: I dug into the source code: the learner’s get_preds() calls a module-level get_preds(), which in turn calls validate(), and validate() runs the on_batch_begin and on_batch_end callback steps. I think I am onto something.

def get_preds(self, ds_type:DatasetType=DatasetType.Valid, with_loss:bool=False, n_batch:Optional[int]=None,
              pbar:Optional[PBar]=None) -> List[Tensor]:
    "Return predictions and targets on `ds_type` dataset."
    lf = self.loss_func if with_loss else None
    return get_preds(self.model, self.dl(ds_type), cb_handler=CallbackHandler(self.callbacks),
                     activ=_loss_func2activ(self.loss_func), loss_func=lf, n_batch=n_batch, pbar=pbar)

def get_preds(model:nn.Module, dl:DataLoader, pbar:Optional[PBar]=None, cb_handler:Optional[CallbackHandler]=None,
              activ:nn.Module=None, loss_func:OptLossFunc=None, n_batch:Optional[int]=None) -> List[Tensor]:
    "Tuple of predictions and targets, and optional losses (if `loss_func`) using `dl`, max batches `n_batch`."
    res = [torch.cat(o).cpu() for o in
           zip(*validate(model, dl, cb_handler=cb_handler, pbar=pbar, average=False, n_batch=n_batch))]
    if loss_func is not None:
        with NoneReduceOnCPU(loss_func) as lf: res.append(lf(res[0], res[1]))
    if activ is not None: res[0] = activ(res[0])
    return res

def validate(model:nn.Module, dl:DataLoader, loss_func:OptLossFunc=None, cb_handler:Optional[CallbackHandler]=None,
             pbar:Optional[PBar]=None, average=True, n_batch:Optional[int]=None)->Iterator[Tuple[Union[Tensor,int],...]]:
    "Calculate `loss_func` of `model` on `dl` in evaluation mode."
    model.eval()
    with torch.no_grad():
        val_losses,nums = [],[]
        if cb_handler: cb_handler.set_dl(dl)
        for xb,yb in progress_bar(dl, parent=pbar, leave=(pbar is not None)):
            if cb_handler: xb, yb = cb_handler.on_batch_begin(xb, yb, train=False)
            val_loss = loss_batch(model, xb, yb, loss_func, cb_handler=cb_handler)
            val_losses.append(val_loss)
            if not is_listy(yb): yb = [yb]
            nums.append(first_el(yb).shape[0])
            if cb_handler and cb_handler.on_batch_end(val_losses[-1]): break
            if n_batch and (len(nums)>=n_batch): break
        nums = np.array(nums, dtype=np.float32)
        if average: return (to_np(torch.stack(val_losses)) * nums).sum() / nums.sum()
        else:       return val_losses
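If I read this correctly, anything in learn.callbacks is picked up by CallbackHandler(self.callbacks), so it should also fire during get_preds(). A minimal sketch of what I mean (assuming fastai v1’s Callback API, where callback methods receive the batch as last_input):

from fastai.callback import Callback

class BatchLogger(Callback):
    "Record the size of every batch seen; should run in fit() and get_preds()."
    def __init__(self): self.sizes = []
    def on_batch_begin(self, last_input, **kwargs):
        # Assumes a single-tensor input batch.
        self.sizes.append(last_input.shape[0])

logger = BatchLogger()
learn.callbacks.append(logger)   # picked up by CallbackHandler(self.callbacks)
preds, targs = learn.get_preds(DatasetType.Train)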


Hi!

I think what you need is called a Hook. Basically, it is a callback you can use to grab values out of your model. I did something similar to what you need with the following code:

m = model.eval()          # put the model in evaluation mode
x, y = data.one_batch()   # grab a single batch of data

def hooked(batch, ix):
    # Register a forward hook on the ix-th submodule of m.features;
    # hook_output stores that layer's output during the forward pass.
    with hook_output(m.features[ix]) as hook_a:
        preds = m(batch)
    return hook_a

hook_a = hooked(x.cuda(), 0)   # activations of the first feature layer
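After the forward pass, the captured activations are kept on the hook, so you can read them with:

acts = hook_a.stored   # the recorded output of m.features[ix]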

Could you point me to the relevant part of the docs? I am struggling to understand the code you provided and need to dig deeper into hooks and callbacks.

Sure, it uses the hook_output function. It grabs the output of whichever module you ask for. Hope that helps :slight_smile:
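If you ever need several layers at once, there is also hook_outputs, which works the same way (a sketch; the layer indices are just an example):

from fastai.callbacks.hooks import hook_outputs

layers = [m.features[0], m.features[4]]   # example layer choices
with hook_outputs(layers) as hooks:
    preds = m(x.cuda())
acts = hooks.stored                       # one stored tensor per hooked layer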


Thank you!