Learner callbacks: features available to before_batch

I am trying to adapt a DataLoader to a model where each batch returns two inputs (a mask and the data) and, potentially, a target.

Looking at the description of the training process in course notebook 16, which describes the learner's callbacks, both xb and yb are supposed to be accessible in before_batch (note: not begin_batch as in the notebook, see PR), thanks to the call to the learner's _split method.

However, in my case I only have access to the yb attribute (if it is not empty, as in the training dataloader) in the after_pred step. I wonder what I am missing in my understanding and would be happy for hints!
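
As far as I understand it, the split done by _split boils down to slicing the batch tuple at n_inp. Here is a minimal standalone sketch of that logic. split_batch is a hypothetical helper of mine, not fastai code, and the default for n_inp reflects my reading of the library (all items but the last are inputs unless the DataLoaders define n_inp), so treat it as an assumption:

```python
def split_batch(batch, n_inp=None):
    """Slice a batch tuple into inputs (xb) and targets (yb).

    Hypothetical stand-in for the learner's split step; the default
    n_inp below is an assumption about how fastai picks it.
    """
    if n_inp is None:
        # Assumed default: a single item is an input, otherwise the last item is the target
        n_inp = 1 if len(batch) == 1 else len(batch) - 1
    return batch[:n_inp], batch[n_inp:]


# Three items: mask and data are inputs, the last item is the target
print(split_batch(("mask", "data", "target")))

# Two items with the assumed default: only the mask counts as input
print(split_batch(("mask", "data")))

# Two items with n_inp=2: both are inputs and yb is an empty tuple
print(split_batch(("mask", "data"), n_inp=2))
```

So whether yb is empty in before_batch would depend on how many items the batch holds and on the value of n_inp.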

from fastai.callback.core import Callback


class ModelAdapter(Callback):
    """The model's forward only expects one input matrix.
    Apply the mask from the dataloader to both pred and targets."""

    def before_batch(self):
        """Remove cont. values from batch (mask)"""
        mask, data = self.xb  # x_cat, x_cont
        self.learn._mask = mask != 1
        if self.training:
            self.learn.yb = (data[self.learn._mask],)
        self.learn.xb = (data,)

    def after_pred(self):
        M = self._mask.shape[-1]

        if not self.training:
            # here I have access to self.yb...
            if len(self.yb):
                self.learn.yb = (self.y[self.learn._mask],)
                self.val_targets.append(self.learn.yb[0])
                

        # Alternative: self.learn.pred = self.pred.view(-1, M)[self._mask]
        self.learn.pred = self.pred[self._mask]  # is this flat?
        if not self.training:
            self.val_preds.append(self.learn.pred)
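
Regarding the "is this flat?" comment: as far as I know, indexing a tensor with a boolean mask of the same shape returns a 1-D tensor holding the selected elements. A small check with made-up values (the shapes here are illustrative, not the ones from my model):

```python
import torch

# Made-up predictions of shape (2, 3)
pred = torch.arange(6.0).view(2, 3)

# Boolean mask with the same shape as pred
mask = torch.tensor([[True, False, True],
                     [False, True, False]])

# Boolean indexing keeps only the entries where mask is True
selected = pred[mask]
print(selected.shape)  # one dimension, one entry per True in the mask
```

If the mask only matched the leading dimensions of pred, the trailing dimensions would be kept, so the result would not be flat in that case.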