Display number of images in train/valid dataset

Hi guys,

I use ImageDataLoaders.from_folder to load the train and validation data. Question: what is a more efficient way to display the number of images in each category of the train/valid dataset, primarily to identify class imbalances?

This is how I currently calculate and plot the number of images per category, but it's a bit slow:

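import plotly.express as px
import pandas as pd
from fastai.vision.all import TensorCategory

# dls is the DataLoaders returned by ImageDataLoaders.from_folder(...)
labels = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]
train_cls_distribution = {}
train_cls_distribution["labels"] = labels
train_cls_distribution["amount"] = []

# For each class index, count the training samples whose label matches it.
# This walks the whole training dataset once per class.
for idx, label in enumerate(labels):
    train_cls_distribution["amount"].append(len(
        list(
            filter(
                lambda ds: bool(ds[1] == TensorCategory(idx)),
                dls.train.dataset
            )
        )
    ))

# Plot the class distribution as a bar chart
df = pd.DataFrame(train_cls_distribution)
fig = px.bar(df, x='labels', y='amount')
fig.show()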
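The loop above scans the whole dataset once per label, so with nine classes every sample is visited nine times. A single pass with collections.Counter counts all classes at once. This is only a sketch, not a fastai API: `count_labels` is a hypothetical helper, and a plain list of (item, label_index) tuples stands in for dls.train.dataset. The helper calls `int(y)` on each label, on the assumption that a TensorCategory converts to a plain int the same way a 0-d tensor does.

```python
from collections import Counter

def count_labels(samples, labels):
    """Count samples per class in a single pass.

    samples: iterable of (item, label_index) pairs (a stand-in for
             dls.train.dataset in this sketch)
    labels:  class names, where labels[i] is the name of class index i
    """
    counts = Counter(int(y) for _, y in samples)
    # Report every class, including ones with zero samples
    return {name: counts.get(i, 0) for i, name in enumerate(labels)}

# Toy stand-in for a dataset: (item, label_index) tuples
samples = [("img0", 0), ("img1", 2), ("img2", 0), ("img3", 1), ("img4", 0)]
print(count_labels(samples, ["a", "b", "c", "d"]))
# {'a': 3, 'b': 1, 'c': 1, 'd': 0}
```

The resulting dict can go into the same pandas/plotly code as before, e.g. `pd.DataFrame({"labels": list(d), "amount": list(d.values())})`. It still touches every sample once, so it removes the repeated passes but not the cost of a single pass.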

I guess this is more of a Python question, but maybe fastai has something built in to do this more efficiently.

Thanks in advance for any help.

Hello Tobias,

Looping through dls.train.dataset to build the train_cls_distribution mapping is slow for me as well. dls.train.dataset behaves like a list of tuples of the form (image, label), or more precisely here (PIL Image, TensorCategory). For some reason I don’t clearly understand, when I loop over it, Google Colab shows a load_image() function being called at the bottom of the cell; I suspect this is the cause of the slowness, since it suggests each image is read from disk every time it is accessed, and your loop goes over the whole dataset once per class.
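If the slowness really comes from loading images, one way around it is to count labels from the file paths alone: with one folder per class, the label is just the parent folder's name, so no image has to be opened. This is a standard-library sketch; `count_by_parent_folder` is a hypothetical helper, and getting the training file paths out of your DataLoaders (for example via something like `dls.train_ds.items`) is an assumption I have not checked against every fastai version.

```python
from collections import Counter
from pathlib import Path

def count_by_parent_folder(paths):
    """Count file paths per class, using the parent folder name as the label.

    Only the path strings are inspected; no image file is opened.
    """
    return Counter(Path(p).parent.name for p in paths)

paths = [
    "data/train/cat/1.jpg",
    "data/train/cat/2.jpg",
    "data/train/dog/3.jpg",
]
print(count_by_parent_folder(paths))
# Counter({'cat': 2, 'dog': 1})
```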

If you’re using ImageDataLoaders.from_folder(), I assume your data is in the ImageNet style, i.e. one folder per class. In that case you can count the files with Python’s os module instead, something along the lines of:

import os

DIR = './data/imagenette2-160/train'
folder_list = os.listdir(DIR)  # one subfolder per class
for folder in folder_list:
    # count the files inside each class folder
    print(folder, len(os.listdir(os.path.join(DIR, folder))))
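Since the original question asks about both the train and valid splits, the same idea can be wrapped in a small function that walks one split folder and counts only files with an image extension, so stray files such as .DS_Store don't inflate the numbers. This is a sketch: `count_per_class`, the extension set, and the `train`/`val` folder names in the usage comment are assumptions, so check them against your own layout.

```python
import os

# Assumed set of image extensions; adjust to your data
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}

def count_per_class(split_dir):
    """Return {class_name: number_of_image_files} for one split folder."""
    counts = {}
    with os.scandir(split_dir) as entries:
        for entry in entries:
            if not entry.is_dir():
                continue  # skip loose files at the split level
            with os.scandir(entry.path) as files:
                counts[entry.name] = sum(
                    1 for f in files
                    if f.is_file()
                    and os.path.splitext(f.name)[1].lower() in IMAGE_EXTS
                )
    # Sort by class name so the output order is stable
    return dict(sorted(counts.items()))

# Hypothetical usage for an imagenette-style layout:
# root = './data/imagenette2-160'
# for split in ("train", "val"):
#     print(split, count_per_class(os.path.join(root, split)))
```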

I hope this helps.

Hi Pankaj,

thanks for the hint. It’s indeed much faster your way; I was overcomplicating things. I used glob to do the job:

import glob

DIR = './data/imagenette2-160/train'
labels = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]
for label in labels:
    # count the files in the folder named after this label
    print(label, len(glob.glob(f"{DIR}/{label}/*")))
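To get back to the original goal of spotting class imbalance, the per-class counts (from glob, os, or anything else) can be summarized as each class's share of the data plus the ratio between the largest and smallest class. A minimal sketch; `imbalance_report` is a hypothetical helper, and deciding which ratio counts as "imbalanced" is a judgment call this does not make for you.

```python
def imbalance_report(counts):
    """Summarize per-class counts: share of the total and max/min ratio.

    counts: dict mapping class name -> number of images
    """
    total = sum(counts.values())
    if total == 0:
        raise ValueError("no images counted")
    shares = {name: n / total for name, n in counts.items()}
    smallest, largest = min(counts.values()), max(counts.values())
    # The ratio is infinite when some class has no images at all
    ratio = largest / smallest if smallest > 0 else float("inf")
    return shares, ratio

shares, ratio = imbalance_report({"a": 300, "b": 100, "c": 100})
print(shares)  # {'a': 0.6, 'b': 0.2, 'c': 0.2}
print(ratio)   # 3.0
```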