The general idea is to keep track of which samples are passed into the model and to discount a sample's contribution to the loss function when the model tends to disagree with its label.
To do that, we need to know which samples are fed into the model in each batch.
I’m thinking about using a custom Callback with `before_batch` to get the indices of the data passed into the model, but I’m struggling with it because I’m not yet very familiar with how to build callbacks in fastai.
Does any more advanced user have an idea of how I could do this? Or could you suggest a reference to become more familiar with custom callbacks?
Then, I’d need to implement a custom loss function, but I’m more familiar with that.
You can build a DataLoader with multiple inputs, where the first input is the data and the second input is the index.
Then you can either change the forward function of the model to accept two inputs or create a Callback that stores the indices.
I did not try it out, but I believe the callback could look like this:
```python
class StoreIndexCallback(Callback):
    def before_batch(self):
        data, index = self.learn.xb  # xb holds the input batch as a tuple (data, index)
        self.index = index           # store the indices on the callback
        self.learn.xb = (data,)      # pass only the data to the model, still as a tuple

    def after_pred(self):  # after_pred runs before the loss is computed (after_batch runs too late)
        pred = detuplify(self.learn.pred)      # unwrap the predictions
        self.learn.pred = (pred, self.index)   # hand the indices on to the loss function
```
For the loss function, something like the following could work (note the callback has to be attached to the `Learner`):

```python
def custom_loss(preds, targs):
    pred, index = preds  # preds arrive as a (pred, index) tuple, targets as a tensor
    ...                  # your loss implementation using pred, targs, and index

learn = Learner(dls, model, loss_func=custom_loss, cbs=StoreIndexCallback())
```
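To make the discounting step concrete, here is a minimal dependency-free sketch of the weighting logic; in a real setup `per_sample_losses` would come from a reduction-free loss such as `F.cross_entropy(pred, targs, reduction='none')`, and `weights` would be lowered over training for samples the model keeps disagreeing with. All names here are illustrative assumptions, not part of fastai:

```python
def discounted_loss(per_sample_losses, indices, weights):
    "Average per-sample losses, scaling each by a trust factor looked up by sample index."
    # weights: mapping sample index -> factor in [0, 1]; unseen indices get full weight
    total = sum(weights.get(i, 1.0) * l
                for i, l in zip(indices, per_sample_losses))
    return total / len(per_sample_losses)
```

For example, with the second sample down-weighted to 0.5, a batch with losses `[2.0, 4.0]` averages to `2.0` instead of `3.0`.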