How to keep track of which data (images) are fed into the model?


I’ve spent a lot of time reading last year’s papers on classification with noisy labels, and I’m trying to implement this one with fastai: [2007.00151] Early-Learning Regularization Prevents Memorization of Noisy Labels. The PyTorch code for the paper is available here: GitHub - shengliu66/ELR: Official Implementation of Early-Learning Regularization Prevents Memorization of Noisy Labels.

The general idea is to keep track of which samples are passed into the model and to down-weight a sample’s contribution to the loss when the model tends to disagree with its label.

To be able to do that, we need to know which samples are fed into the model in each batch.
I’m thinking about using a custom Callback with a before_batch method to get the indices of the samples passed into the model. But I’m struggling with it because I’m not yet very familiar with how to build callbacks in fastai.

Does any more advanced user have an idea of how I could do this, or can anyone suggest a reference for becoming more familiar with custom hooks?

Then, I’d need to implement a custom loss function, but I’m more familiar with that.

Thank you for your help!

You can build a DataLoader with multiple inputs, where the first input is the data and the second input is the index.
Then you can either change the forward function of the model to accept two inputs or create a Callback that stores the indices.
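A minimal sketch of the "index as second input" idea, assuming a map-style dataset that returns (data, label) pairs. The name IndexedDataset and the toy data are made up for illustration. A PyTorch-style map dataset only needs __len__ and __getitem__, so the wrapper is written in plain Python here:

```python
class IndexedDataset:
    """Wrap a map-style dataset so each item also carries its own index."""

    def __init__(self, base):
        # `base` is any object with __len__ and __getitem__ returning (x, y)
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        x, y = self.base[i]
        # Inputs first (data, index), target last
        return x, i, y


# Toy stand-in for an image dataset: (sample, label) pairs
toy = [("img_a", 0), ("img_b", 1), ("img_c", 0)]
indexed = IndexedDataset(toy)
print(indexed[1])  # ('img_b', 1, 1)
```

As far as I know, fastai treats every element of a batch except the last one as an input by default, so a 3-tuple like this should end up with xb = (data, index). I have not checked that for every way of building the DataLoaders, so it is worth inspecting dls.n_inp.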

I did not try it out, but I believe the callback could look like this:

from fastai.basics import *  # provides Callback, Learner and detuplify

class StoreIndexCallback(Callback):
    def before_batch(self):
        data, index = self.learn.xb  # the input batch arrives as a tuple (data, index)
        self.index = index           # store the indices on the callback
        self.learn.xb = (data,)      # pass only the data on to the model, as a tuple

    def after_pred(self):
        # after_pred fires before the loss is computed; after_batch would be too late
        pred = detuplify(self.learn.pred)  # unwrap preds if they come as a 1-tuple
        self.learn.pred = (pred, self.index)

For the loss function, the following could work:

def custom_loss(preds, targs): 
    pred, index = preds # preds were passed as a tuple, targets as tensor
    ... # your loss implementation with data and indices

learn = Learner(dls, model, loss_func=custom_loss, cbs=StoreIndexCallback())
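For the loss body itself, here is a rough sketch of an ELR-style loss that takes the (pred, index) tuple, based on my reading of the paper and not checked against the official repository. The class name ELRLoss, the argument names, and the default values of beta and lam are illustrative assumptions. It keeps a running average of the predicted probabilities per sample and adds a log(1 - <target, prob>) regularizer to the cross-entropy:

```python
import torch
import torch.nn.functional as F


class ELRLoss(torch.nn.Module):
    """Cross-entropy plus an early-learning regularization term (sketch)."""

    def __init__(self, n_samples, n_classes, beta=0.7, lam=3.0):
        super().__init__()
        self.beta, self.lam = beta, lam
        # One running-average probability vector per training sample
        self.register_buffer("target", torch.zeros(n_samples, n_classes))

    def forward(self, preds, targs):
        logits, index = preds  # tuple produced by the callback
        probs = F.softmax(logits, dim=1).clamp(1e-4, 1 - 1e-4)
        with torch.no_grad():
            # Renormalize after clamping, then update the running average
            p = probs / probs.sum(dim=1, keepdim=True)
            self.target[index] = self.beta * self.target[index] + (1 - self.beta) * p
        ce = F.cross_entropy(logits, targs)
        # Rewards agreement between current probs and the running targets
        reg = torch.log(1 - (self.target[index] * probs).sum(dim=1)).mean()
        return ce + self.lam * reg


# Tiny smoke run with random data
torch.manual_seed(0)
loss_fn = ELRLoss(n_samples=10, n_classes=3)
logits = torch.randn(4, 3, requires_grad=True)
index = torch.tensor([0, 3, 5, 7])
targs = torch.tensor([0, 1, 2, 1])
loss = loss_fn((logits, index), targs)
loss.backward()
```

The target buffer has to live on the same device as the predictions, and the indices must be the positions in the training set (not positions within the batch). Validation samples would need their own handling, which this sketch ignores.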

Hope that helps a bit :slight_smile:


Thanks! I’ll give it a try

I attempted to implement ELR for a past Kaggle competition here. It might be a good starting point. It’s not thoroughly tested, so use it with caution :slight_smile:


Thanks for sharing