Working with optional metadata in fastai v2

Hi everyone,

First, I would like to thank everybody involved in this project.
The library is simply amazing!

To my problem:

This is my first time working with fastai v2,
so I may be missing something obvious.

I am working with images that come with metadata.
I would like to be able to include the metadata optionally.
The metadata should not be coupled too tightly to my other components, because it will probably be used for very different tasks (e.g. to influence sample weights, or as an extra input to the loss function).

I have subclassed TensorImage because the input requires extra initialization steps and a different show method (the input has more than 3 channels).

At first, I thought I could simply attach the additional metadata to my tensor during the type_tfm, but I later saw that it gets overwritten during the collation step (or rather, if several tensors carry an attribute with the same name, only the first tensor's value is kept).
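To make the collation problem concrete, here is a toy model in plain Python. It is not fastai's actual collate code; it simply assumes, as described above, that the collated batch inherits attributes from the first item only. `ToyImage`, `toy_collate_attrs` and `toy_collate_tuples` are hypothetical names. The second half shows a common alternative: returning the metadata as its own tuple element, so that it is batched per sample instead of being attached to the tensor.

```python
from dataclasses import dataclass, field


@dataclass
class ToyImage:
    """Hypothetical stand-in for a tensor subclass that carries extra attributes."""
    pixels: list
    meta: dict = field(default_factory=dict)


def toy_collate_attrs(items):
    """Batch the pixel data, but copy attributes from the first item only
    (a simplified model of the behaviour observed above)."""
    batch = ToyImage(pixels=[it.pixels for it in items])
    batch.meta = dict(items[0].meta)  # the metadata of items[1:] is lost here
    return batch


def toy_collate_tuples(samples):
    """Transpose a list of (x, metadata) tuples into (xs, metadatas),
    so every sample keeps its own metadata."""
    xs, metas = zip(*samples)
    return list(xs), list(metas)


items = [ToyImage([0, 1], {"site": "a"}), ToyImage([2, 3], {"site": "b"})]
print(toy_collate_attrs(items).meta)  # {'site': 'a'}

xs, metas = toy_collate_tuples([(it.pixels, it.meta) for it in items])
print(metas)  # [{'site': 'a'}, {'site': 'b'}]
```

The trade-off of the tuple variant is that the metadata then travels through the whole pipeline as an extra batch element, so the model, the loss and any callbacks must all expect it, which is exactly the tight coupling this post is trying to avoid.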

I tried to give a more detailed explanation in the following Colab notebook:

But the question boils down to:
How would you include optional metadata that will be used for different experiments, while still reusing most of the pipeline?

Sorry if this seems too abstract. I hope the Colab notebook helps to show what my issue is. :slight_smile:

Thanks!

I think I’ve found a satisfactory starting point:
It can be implemented as a callback, but it relies on some “hidden” internals (a private, name-mangled attribute of DataLoader).
I have only tested it for my current use case, but I will update my approach if I find any problems:

The goal is to access metadata during training, for example just to visualize some statistics about the current batch. I assume that the metadata is stored in a pandas.DataFrame and that the items are generated from this DataFrame, so each item's metadata lives in the same row as the item itself.
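The callback relies on the DataLoader splitting each epoch's (possibly shuffled) index list into consecutive chunks of `bs`, so batch number `i` corresponds to chunk `i`. The same lookup can be sketched in plain Python. Here `chunked` is a hypothetical stand-in for fastcore's function of the same name, and the list of dicts stands in for the DataFrame; both are assumptions for illustration only.

```python
def chunked(seq, size, drop_last=False):
    """Split seq into consecutive chunks of `size`; optionally drop a short tail."""
    chunks = [seq[i:i + size] for i in range(0, len(seq), size)]
    if drop_last and chunks and len(chunks[-1]) < size:
        chunks.pop()
    return chunks


def batch_metadata(table, epoch_idxs, bs, batch_no, keys, drop_last=False):
    """Return {key: [values, ...]} for the rows that make up batch `batch_no`."""
    batch_idxs = chunked(epoch_idxs, bs, drop_last)[batch_no]
    return {k: [table[i][k] for i in batch_idxs] for k in keys}


# One row per item; the metadata sits in the same row as the file name.
table = [{"fname": f"img{i}.tif", "site": s} for i, s in enumerate("aabbc")]
epoch_idxs = [3, 0, 4, 1, 2]  # e.g. the shuffled order drawn for one epoch
print(batch_metadata(table, epoch_idxs, bs=2, batch_no=1, keys=["site"]))
# {'site': ['c', 'a']}
```

Because the indices are positions, the DataFrame version has to use `.iloc` rather than `.loc`, which is what the callback does.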

```python
from fastai.vision.all import *  # provides Callback, L, listify, chunked

class DataFrameMetadataLogger(Callback):
    "Store the `metadata_keys` columns of the DataFrame rows in the current batch."
    def __init__(self, metadata_keys):
        self.metadata_keys = listify(metadata_keys)
        self.metadata = {}

    def before_batch(self):
        # get idxs of current epoch (private, name-mangled attribute of DataLoader)
        idxs_of_epoch = self.dl._DataLoader__idxs
        # in each epoch the batches are chunked in this style -> should be moved to only calc'ed once
        idxs_of_batches = L(chunked(idxs_of_epoch, self.dl.bs, self.dl.drop_last))
        # grab the indices of the current batch
        idxs_of_cur_batch = idxs_of_batches[self.iter]
        # the indices are row positions in dl.items, hence .iloc
        rows = self.dl.items.iloc[idxs_of_cur_batch]
        for key in self.metadata_keys:
            self.metadata[key] = rows[key].tolist()
```
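The comment in the callback notes that the chunking should only be computed once. One way to do that, sketched here in plain Python with a hypothetical `BatchIndexCache` helper, is to rebuild the chunks lazily on the first batch of each pass. Doing it in an earlier event such as before_epoch might be too early, assuming (as the `_DataLoader__idxs` attribute suggests) that the DataLoader only draws its new index order once iteration starts.

```python
class BatchIndexCache:
    """Hypothetical helper: rebuild per-batch index chunks once per pass.

    Assumption: batch_no == 0 marks the start of a new pass over a DataLoader,
    i.e. the point at which a fresh (possibly shuffled) index order exists.
    """

    def __init__(self):
        self.chunks = None
        self.n_computed = 0  # how often the chunks were rebuilt (for illustration)

    def batch_idxs(self, epoch_idxs, batch_no, bs, drop_last=False):
        if batch_no == 0 or self.chunks is None:
            self.chunks = [epoch_idxs[i:i + bs] for i in range(0, len(epoch_idxs), bs)]
            if drop_last and self.chunks and len(self.chunks[-1]) < bs:
                self.chunks.pop()
            self.n_computed += 1
        return self.chunks[batch_no]


cache = BatchIndexCache()
order = [3, 0, 4, 1, 2]
print([cache.batch_idxs(order, b, bs=2) for b in range(3)])  # [[3, 0], [4, 1], [2]]
print(cache.n_computed)  # 1
```

Inside before_batch, a call like `cache.batch_idxs(self.dl._DataLoader__idxs, self.iter, self.dl.bs, self.dl.drop_last)` would then replace the two chunking lines. Since the batch counter restarts at 0 for the validation DataLoader, the chunks would be rebuilt there too, and passing `bs` on every call keeps it correct if the two DataLoaders use different batch sizes.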