Custom average of loss function in validate

I have the following situation: I’m training a model whose output is a variable-length sequence. As a result, my output for a batch has size N × L_max × C, where N is the batch size, L_max is the padded (maximum) sequence length in the batch, and C is the number of categories. Some of my data is padded, and the padded positions are ignored when computing the loss (using ignore_index). When I calculate the global loss, I want to take this variable length into account: the weight of a batch should not be L_max but the sum of the lengths of its sequences.
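To make the target quantity concrete: the desired global loss is the sum of all per-token losses over the non-padded positions, divided by the total number of non-padded tokens. Below is a minimal pure-Python sketch, assuming you already have per-token losses (for example from a loss computed with `reduction='none'`), represented here as nested lists so the arithmetic is easy to check. `batch_loss_stats` and `token_weighted_mean` are hypothetical helpers, and `-100` is assumed as the padding label because it is PyTorch's default `ignore_index`.

```python
from typing import Iterable, Sequence, Tuple

IGNORE_INDEX = -100  # assumption: PyTorch's default ignore_index


def batch_loss_stats(token_losses: Sequence[Sequence[float]],
                     targets: Sequence[Sequence[int]],
                     ignore_index: int = IGNORE_INDEX) -> Tuple[float, int]:
    """Return (sum of per-token losses, number of non-padded tokens) for one N x L_max batch."""
    total, count = 0.0, 0
    for loss_row, target_row in zip(token_losses, targets):
        for loss, target in zip(loss_row, target_row):
            if target != ignore_index:  # padded positions are skipped
                total += loss
                count += 1
    return total, count


def token_weighted_mean(batches: Iterable[Tuple[Sequence[Sequence[float]], Sequence[Sequence[int]]]]) -> float:
    """Global loss = total loss over all real tokens / number of real tokens."""
    grand_total, grand_count = 0.0, 0
    for token_losses, targets in batches:
        total, count = batch_loss_stats(token_losses, targets)
        grand_total += total
        grand_count += count
    return grand_total / grand_count


# Two batches with different L_max; 9.0 marks losses at padded positions.
batches = [
    ([[1.0, 2.0, 9.0], [3.0, 9.0, 9.0]], [[1, 2, -100], [0, -100, -100]]),  # 3 real tokens
    ([[0.5, 1.5]], [[4, 5]]),                                              # 2 real tokens
]
print(token_weighted_mean(batches))  # 1.6
```

Because every real token contributes equally, a batch's weight is automatically its number of non-padded tokens, whatever its L_max.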

Here is the source for validate (from :

def validate(model:nn.Module, dl:DataLoader, loss_func:OptLossFunc=None, cb_handler:Optional[CallbackHandler]=None,
             pbar:Optional[PBar]=None, average=True, n_batch:Optional[int]=None)->Iterator[Tuple[Union[Tensor,int],...]]:
    "Calculate `loss_func` of `model` on `dl` in evaluation mode."
    model.eval()
    with torch.no_grad():
        val_losses,nums = [],[]
        if cb_handler: cb_handler.set_dl(dl)
        for xb,yb in progress_bar(dl, parent=pbar, leave=(pbar is not None)):
            if cb_handler: xb, yb = cb_handler.on_batch_begin(xb, yb, train=False)
            val_loss = loss_batch(model, xb, yb, loss_func, cb_handler=cb_handler)
            val_losses.append(val_loss)
            if not is_listy(yb): yb = [yb]
            # weight of this batch in the average: first dimension of the (first) target
            nums.append(yb[0].shape[0])
            if cb_handler and cb_handler.on_batch_end(val_losses[-1]): break
            if n_batch and (len(nums)>=n_batch): break
        nums = np.array(nums, dtype=np.float32)
        # weighted mean of the per-batch losses, using nums as weights
        if average: return (to_np(torch.stack(val_losses)) * nums).sum() / nums.sum()
        else:       return val_losses
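The final `if average:` line is a weighted mean of the per-batch losses. Here is a pure-Python equivalent, useful for seeing what the choice of weights does. `weighted_average` is a hypothetical helper, not part of fastai. It assumes each batch loss is already a mean over that batch's non-padded tokens, which is what a cross-entropy loss with `ignore_index` and the default mean reduction returns.

```python
from typing import Sequence


def weighted_average(batch_losses: Sequence[float], weights: Sequence[float]) -> float:
    """Pure-Python version of (losses * nums).sum() / nums.sum()."""
    if len(batch_losses) != len(weights):
        raise ValueError("need exactly one weight per batch")
    return sum(loss * w for loss, w in zip(batch_losses, weights)) / sum(weights)


# Batch 1: mean loss 2.0 over 3 real tokens, first target dimension N = 2.
# Batch 2: mean loss 1.0 over 2 real tokens, first target dimension N = 1.
batch_means = [2.0, 1.0]

by_shape = weighted_average(batch_means, [2, 1])   # weights taken from the target shape
by_tokens = weighted_average(batch_means, [3, 2])  # weights = non-padded token counts

print(by_shape)   # 5/3, about 1.667
print(by_tokens)  # 1.6
```

Only the token-count weights recover the true per-token mean (8.0 total loss over 5 real tokens), so the fix amounts to changing what gets appended to `nums`.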

The line of interest here is:

if not is_listy(yb): yb = [yb]

which means that each batch is always weighted by the padded shape of its targets (L_max in my case) rather than by the true lengths of its sequences.

I could override fastai's validate function for my task, but my override would not pick up changes in future versions of the library. I have searched for a way to do this within the library but have not found one yet. Do you have any ideas on whether this is possible?
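One possible direction that avoids patching validate is to compute the length-weighted loss separately, as an extra quantity accumulated over the validation batches and reset at each pass. The sketch below is framework-agnostic and entirely hypothetical: `TokenWeightedLoss` and its `reset`/`update`/`value` methods only loosely mirror the begin/batch-end/epoch-end hooks of a fastai v1 metric callback, and whether it can be registered as a fastai metric is an assumption to check against the fastai docs. As with the other sketches, it assumes each batch loss is already averaged over non-padded tokens and that padding uses the label `-100`.

```python
from typing import Sequence


class TokenWeightedLoss:
    """Hypothetical accumulator: batch losses weighted by their real (non-padded) token count."""

    def __init__(self, ignore_index: int = -100) -> None:
        self.ignore_index = ignore_index  # assumption: padding label
        self.reset()

    def reset(self) -> None:
        # Call at the start of every validation pass.
        self.weighted_sum = 0.0
        self.n_tokens = 0

    def update(self, batch_mean_loss: float, targets: Sequence[Sequence[int]]) -> None:
        # batch_mean_loss: loss already averaged over this batch's non-padded tokens
        # targets: N x L_max target ids, padded with ignore_index
        n = sum(1 for row in targets for t in row if t != self.ignore_index)
        self.weighted_sum += batch_mean_loss * n
        self.n_tokens += n

    def value(self) -> float:
        # Sum over batches of (mean loss * real tokens), divided by all real tokens.
        return self.weighted_sum / self.n_tokens


metric = TokenWeightedLoss()
metric.update(2.0, [[1, 2, -100], [0, -100, -100]])  # 3 real tokens
metric.update(1.0, [[4, 5]])                          # 2 real tokens
print(metric.value())  # 1.6
```

Counting tokens from the targets keeps the loss function itself unchanged; the trade-off is that this value is reported alongside, not instead of, the loss that validate returns.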