Plotting Label Count Distribution of Data - Assessing Balance

Hey guys,

This is my first post in the forum.

As I was running through lesson one and attempting to load my custom dataset, I was wondering how balanced my data is. By balanced I mean how many instances of each class/label I have in the train/validation sets.

I searched through the code base but wasn’t able to find a way to plot/print this information, so I coded up a few Jupyter cells, which I added to my lesson one notebook. I thought it might be beneficial to others when building their dataset, especially if the dataset is scraped automatically.


The following code shows how to plot the distribution of the classes/labels in the train set.

# extracting the distribution
import collections
import matplotlib.pyplot as plt

items = data.label_list.train.y.items      # class index of each training item
classes = data.label_list.train.y.classes
occurance_count = collections.Counter(items)
# order the counts by class index so they line up with `classes`
# (Counter keeps first-seen order, which need not match the class order)
occurance_count = [occurance_count[i] for i in range(len(classes))]

# plotting
index = list(range(len(classes)))
plt.bar(index, occurance_count)
plt.xticks(index, classes, rotation=45)
plt.ylabel('Label Count')
plt.title('Count of Labels')

Hope you find this useful!


Thanks. I had to ravel the items before I could use them in collections.Counter.

occurance_count = collections.Counter(items.ravel())
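
In case it helps anyone else, here is a minimal sketch of why the ravel can be needed. It assumes `items` is a 2-D array of shape (n, 1); the array below is a made-up stand-in, not real fastai output. Iterating a 2-D array yields row arrays, which are unhashable, so `Counter` raises a `TypeError` until the array is flattened.

```python
import collections
import numpy as np

# Hypothetical stand-in for data.label_list.train.y.items, shape (n, 1).
items = np.array([[0], [2], [2], [1], [2]])

try:
    collections.Counter(items)  # each element is a 1-D array, which is unhashable
    raised = False
except TypeError:
    raised = True

# Flattening first gives scalar class indices that Counter can hash.
occurance_count = collections.Counter(items.ravel())

# Read the counts back in class-index order (3 classes assumed here).
counts_by_class = [occurance_count[i] for i in range(3)]
print(raised, counts_by_class)
```

If `items` is already 1-D, `ravel()` is harmless, so it is safe to leave in either way.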

For fastai v2, I wanted to do something similar (compare the train vs. validation label distribution). This is what I came up with:

# dls.xxx_ds yields (x, y) tuples; zip(*...) reassembles them into x/y vectors
x,y = zip(*dls.train_ds)
xv,yv = zip(*dls.valid_ds)

import pandas as pd

# this creates our labels lists: convert each fastai tensor to a plain int
y_labels = [a.item() for a in y]
yv_labels = [a.item() for a in yv]

# Create a dataframe of categorical counts, one row per class index
# (one way to build it; a class missing from a split gets a count of 0)
df = pd.concat([pd.Series(y_labels).value_counts(),
                pd.Series(yv_labels).value_counts()], axis=1).sort_index().fillna(0)
df.columns = ["train","valid"]
# Add percentages
df["train_pct"] = df["train"]/df["train"].sum()
df["valid_pct"] = df["valid"]/df["valid"].sum()
df["labels"] = pd.Series(dls.vocab)
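
To try the train vs. valid table without loading a fastai `DataLoaders`, here is a small self-contained sketch. The label lists and `vocab` below are made up purely for illustration (they stand in for `y_labels`, `yv_labels` and `dls.vocab`), and the extra `pct_diff` column is my own addition for spotting classes whose share differs most between the splits.

```python
import pandas as pd

# Hypothetical stand-ins for y_labels, yv_labels and dls.vocab.
y_labels = [0, 0, 0, 1, 1, 2]
yv_labels = [0, 1, 1, 1]
vocab = ["cat", "dog", "bird"]

all_classes = range(len(vocab))

# reindex() makes sure every class has a row, even if a split never saw it.
df = pd.DataFrame({
    "train": pd.Series(y_labels).value_counts().reindex(all_classes, fill_value=0),
    "valid": pd.Series(yv_labels).value_counts().reindex(all_classes, fill_value=0),
})

# Share of each class within its own split.
df["train_pct"] = df["train"] / df["train"].sum()
df["valid_pct"] = df["valid"] / df["valid"].sum()
df["labels"] = vocab

# Absolute difference in class share between the two splits.
df["pct_diff"] = (df["train_pct"] - df["valid_pct"]).abs()
print(df)
```

A large `pct_diff` for a class suggests the split is not stratified for that label, which may be worth checking before trusting validation metrics.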