Small error in chapter 4 of Fastbook

I was skimming the book and found a teeny tiny error in the accuracy calculation in 04_mnist_basics.ipynb.

The error is in the accuracy calculation for the baseline 3-or-7 classifier. The code below calculates the accuracy by averaging the accuracy on the 3s and the accuracy on the 7s (adding them and dividing by 2). However, the validation set does not contain the same number of 3s and 7s.

accuracy_3s =      is_3(valid_3_tens).float() .mean()
accuracy_7s = (1 - is_3(valid_7_tens).float()).mean()

accuracy_3s, accuracy_7s, (accuracy_3s + accuracy_7s)/2

#Output: (tensor(0.9168), tensor(0.9854), tensor(0.9511))

In another cell, I ran:

valid_3_tens.shape[0] == valid_7_tens.shape[0]
#Output: False

To calculate the accuracy over all validation images at once, I used the following code:

# Create vectors for the predictions and actuals for the 3s
valid_3s_preds = is_3(valid_3_tens).float()
valid_3s_actual = torch.ones_like(valid_3s_preds)

# Create vectors for the predictions and actuals for the 7s
valid_7s_preds = is_3(valid_7_tens).float()
valid_7s_actual = torch.zeros_like(valid_7s_preds)

# Concatenate everything together
preds = torch.cat([valid_3s_preds, valid_7s_preds]).float()
actuals = torch.cat([valid_3s_actual, valid_7s_actual]).float()

# Calculate accuracy
(preds == actuals).float().mean()
# Output: tensor(0.9514)
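Concatenating the predictions and taking a single mean is equivalent to weighting each per-class accuracy by that class's share of the validation set, whereas the book's cell weights both classes equally. A minimal stdlib sketch of the difference; the helper names and the counts are made up for illustration and are not the real MNIST sample sizes:

```python
def pooled_accuracy(correct_per_class, total_per_class):
    """Fraction of all predictions that are correct (what concatenating does)."""
    return sum(correct_per_class) / sum(total_per_class)


def weighted_class_accuracy(correct_per_class, total_per_class):
    """Per-class accuracies weighted by class size; equals pooled_accuracy."""
    total = sum(total_per_class)
    return sum((n / total) * (c / n)
               for c, n in zip(correct_per_class, total_per_class))


def mean_of_class_accuracies(correct_per_class, total_per_class):
    """Unweighted average of per-class accuracies (the book's calculation)."""
    accs = [c / n for c, n in zip(correct_per_class, total_per_class)]
    return sum(accs) / len(accs)


# Hypothetical counts: 100 threes (90 correct) and 300 sevens (294 correct).
correct = [90, 294]
totals = [100, 300]

print(pooled_accuracy(correct, totals))           # ≈ 0.96
print(weighted_class_accuracy(correct, totals))   # ≈ 0.96, same as pooled
print(mean_of_class_accuracies(correct, totals))  # ≈ 0.94
```

The two approaches agree whenever the classes have the same size, which is why the gap in the notebook is small: the 3s and 7s are only slightly imbalanced there.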

It's only a small difference in accuracy, but I'm worried this calculation may mislead readers into thinking that overall accuracy can be calculated by taking the mean of the per-class accuracies.


Bear with me, this is going to be a lengthy post:

I don't think there is an error. Let me try to explain my viewpoint.

Let's start with the 3s.

accuracy_3s = is_3(valid_3_tens).float().mean()

Here, we use the function is_3 to calculate the accuracy, so let's see how it was defined:

def is_3(x): 
    return mnist_distance(x,mean3) < mnist_distance(x,mean7)

We check whether mnist_distance(x, mean3) is less than mnist_distance(x, mean7), i.e. whether an image is closer to the mean 3 than to the mean 7.

So how is mnist_distance calculated?

def mnist_distance(a,b): 
    return (a-b).abs().mean((-1,-2))

It uses broadcasting to compute, for each image passed in as a, the mean absolute pixel difference (an L1 distance averaged over the pixels) from b, which here is mean3 or mean7.
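What mnist_distance returns can be sketched without PyTorch: for each image in a batch, average the absolute pixel differences to a single template image, giving one number per image. The helper name mean_abs_distance and the tiny 2×2 "images" are hypothetical; in the notebook, broadcasting plus .mean((-1,-2)) does the same work on whole tensors.

```python
def mean_abs_distance(batch, template):
    """Return one mean absolute pixel difference per image in `batch`.

    `batch` is a list of 2-D images (lists of rows); `template` is a single
    2-D image of the same height and width. This mimics what
    (a - b).abs().mean((-1, -2)) computes when `a` has shape
    (n_images, height, width) and `b` has shape (height, width).
    """
    distances = []
    for image in batch:
        diffs = [
            abs(p - t)
            for img_row, tmpl_row in zip(image, template)
            for p, t in zip(img_row, tmpl_row)
        ]
        distances.append(sum(diffs) / len(diffs))
    return distances


# Made-up 2x2 template and images, purely for illustration.
template = [[0.0, 1.0],
            [1.0, 0.0]]
batch = [
    [[0.0, 1.0], [1.0, 0.0]],  # identical to the template -> 0.0
    [[1.0, 1.0], [1.0, 1.0]],  # two of four pixels differ by 1.0 -> 0.5
    [[1.0, 0.0], [0.0, 1.0]],  # every pixel differs by 1.0 -> 1.0
]

print(mean_abs_distance(batch, template))  # [0.0, 0.5, 1.0]
```

The key property for this discussion is that the output has exactly one entry per image in the batch it was given, regardless of how many images any other class has.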

So, it does not matter that valid_3_tens and valid_7_tens contain different numbers of images.

Let's go step by step. When we call:

is_3(valid_3_tens)

what it does internally is check whether

mnist_distance(valid_3_tens,mean3) < mnist_distance(valid_3_tens,mean7)

Note: here is the most important point.
For every image in the validation set of 3s, it checks whether that image is closer to the mean 3 or to the mean 7. So the whole process works only on valid_3_tens, and it doesn't matter that the number of validation images of 3s is not equal to the number of validation images of 7s.

This returns a boolean tensor such as [True, True, …, False, True], with one entry per validation image of a 3.
We then convert it to floats and take its mean. So what does this mean tell us?
How accurately we classified the 3s.
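The per-class step can be checked on toy data: an is_3-style function yields one boolean per image in whichever set it is given, and the mean of those booleans (as floats) is the fraction of that class classified correctly. This is a stdlib sketch under assumptions: flat 4-pixel lists stand in for 28×28 tensors, and the templates and images are made up.

```python
def mean_abs_distance(image, template):
    """Mean absolute pixel difference between two equal-length flat images."""
    return sum(abs(p - t) for p, t in zip(image, template)) / len(image)


def is_3(images, mean3, mean7):
    """One boolean per image: True if it is closer to mean3 than to mean7."""
    return [mean_abs_distance(img, mean3) < mean_abs_distance(img, mean7)
            for img in images]


def class_accuracy(flags):
    """Mean of booleans converted to floats, like .float().mean() in PyTorch."""
    return sum(float(f) for f in flags) / len(flags)


# Made-up 4-pixel templates and images.
mean3 = [1.0, 1.0, 0.0, 0.0]
mean7 = [0.0, 0.0, 1.0, 1.0]
valid_3 = [[0.9, 0.8, 0.1, 0.0],   # closer to mean3 -> True (correct)
           [1.0, 0.7, 0.2, 0.1],   # closer to mean3 -> True (correct)
           [0.1, 0.2, 0.9, 1.0]]   # closer to mean7 -> False (a missed 3)
valid_7 = [[0.0, 0.1, 0.9, 0.8],   # closer to mean7 -> correctly "not a 3"
           [0.2, 0.0, 1.0, 0.9]]   # closer to mean7 -> correctly "not a 3"

acc_3s = class_accuracy(is_3(valid_3, mean3, mean7))
# For the 7s a correct prediction is "not a 3", mirroring (1 - is_3(...)).
acc_7s = class_accuracy([not f for f in is_3(valid_7, mean3, mean7)])
print(acc_3s, acc_7s)  # 0.6666666666666666 1.0
```

Each accuracy is computed only from its own class, which is the point being made here; whether averaging the two is the right summary is a separate question.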

When we repeat the whole process with valid_7_tens (using 1 - is_3(...), since a correct prediction for a 7 is "not a 3"), we work only on the validation images of 7s and get how accurately we classified the 7s.

So, to get the overall accuracy of the model, whose original goal was to distinguish 3s from 7s, it is safe to average these two accuracies.

Therefore, the book's calculation, which averages the two per-class accuracies, is correct.

Tell me what you think :smiley:

To illustrate my point, I’d like to exaggerate the problem.

Let's say we have a validation dataset comprising 9 images of 3s and 1 image of a 7. Let's also say our baseline classifier gets 6 of the 3s correct (accuracy ≈ 0.67 for 3s) and the single 7 correct (accuracy = 1.0 for 7s).

Accuracy for the overall classifier should be given by the fraction of inferences we got correct, calculated below,

acc = (6 + 1) / 10 = 0.7,

and not the average of the per-class accuracies in the validation set, calculated below:

acc = (6/9 + 1/1) / 2 = 5/6 ≈ 0.83.
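Both calculations in this exaggerated example can be reproduced exactly with fractions; the counts are the hypothetical ones from the example (9 threes with 6 correct, 1 seven with 1 correct).

```python
from fractions import Fraction

# Hypothetical counts from the exaggerated example.
threes_total, threes_correct = 9, 6
sevens_total, sevens_correct = 1, 1

# Overall accuracy: fraction of all predictions that are correct.
overall = Fraction(threes_correct + sevens_correct, threes_total + sevens_total)

# Mean of the per-class accuracies (the book's style of calculation).
acc_3s = Fraction(threes_correct, threes_total)
acc_7s = Fraction(sevens_correct, sevens_total)
mean_of_classes = (acc_3s + acc_7s) / 2

print(overall, float(overall))                  # 7/10 0.7
print(mean_of_classes, float(mean_of_classes))  # 5/6 0.8333333333333334
```

The gap between 0.7 and about 0.83 comes from the class imbalance; with equally sized classes the two numbers would coincide.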

The classes in the validation set (3s and 7s) are not balanced, so we can't simply take the mean of the per-class accuracies to represent the accuracy of the classifier overall.

I have no issue with how accuracy is calculated for each class. My only point is that, for an imbalanced dataset, the mean of the per-class accuracies does not represent the accuracy of the classifier overall.


I see it now. I think you are right.