MNIST Loss Function From Chapter 4

Hi everyone, I am working through the Further Research section of chapter 4 of Deep Learning for Coders…

I am currently struggling to understand how to write the loss function for the full MNIST dataset, not just the threes and sevens.

When dealing with just 3’s and 7’s we are given this:
`def mnist_loss(predictions, targets):
    return torch.where(targets==1, 1-predictions, predictions).mean()`
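For anyone following along, here is that loss on a tiny made-up batch (the numbers are purely illustrative, not from the book):

```python
import torch

def mnist_loss(predictions, targets):
    # For targets of 1 we want predictions near 1, so the loss is 1 - prediction;
    # for targets of 0 we want predictions near 0, so the loss is the prediction itself.
    return torch.where(targets == 1, 1 - predictions, predictions).mean()

preds = torch.tensor([0.9, 0.4, 0.2])  # sigmoid-squashed predictions
targs = torch.tensor([1, 1, 0])        # 1 = "is a 3", 0 = "is a 7"
loss = mnist_loss(preds, targs)        # (0.1 + 0.6 + 0.2) / 3 = 0.3
```

The `torch.where` picks `1 - prediction` for the positive class and `prediction` for the negative class, so the loss only makes sense when there are exactly two classes encoded as 0 and 1.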

But I am trying to understand how I would write this for all the targets.

This is how I set up the dataset:
`train_x = torch.cat([stacked_ones, stacked_twos, stacked_threes, stacked_fours, stacked_fives, stacked_sixes, stacked_sevens, stacked_eights, stacked_nines]).view(-1, 28*28)

train_y = tensor([1]*len(ones) + [2]*len(twos) + [3]*len(threes) + [4]*len(fours) + [5]*len(fives) + [6]*len(sixes) + [7]*len(sevens) + [8]*len(eights) + [9]*len(nines)).unsqueeze(1)

dset = list(zip(train_x,train_y))
x,y = dset[0]
x.shape,y, len(dset)

valid_x = torch.cat([valid_1_tens, valid_2_tens, valid_3_tens, valid_4_tens, valid_5_tens, valid_6_tens, valid_7_tens, valid_8_tens, valid_9_tens]).view(-1, 28*28)
valid_y = tensor([1]*len(valid_1_tens) + [2]*len(valid_2_tens) + [3]*len(valid_3_tens) + [4]*len(valid_4_tens) + [5]*len(valid_5_tens) + [6]*len(valid_6_tens) + [7]*len(valid_7_tens) + [8]*len(valid_8_tens) + [9]*len(valid_9_tens)).unsqueeze(1)
valid_dset = list(zip(valid_x,valid_y))`
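One thing worth flagging about a target setup like this: PyTorch's multi-class losses (such as `CrossEntropyLoss`, mentioned below) expect a 1-D tensor of integer class indices, one per sample, rather than a column of floats. A minimal sketch with toy stand-ins for the per-digit stacks (the tensor names and sizes here are just illustrative, not your actual data):

```python
import torch

# Toy stand-ins for the per-digit stacks: 3 fake "images" per digit, 28x28 each
stacks = [torch.rand(3, 28, 28) for _ in range(10)]

train_x = torch.cat(stacks).view(-1, 28 * 28)

# Multi-class targets: a flat tensor of integer class indices (dtype long),
# one per sample -- note there is no .unsqueeze(1) here.
train_y = torch.cat([torch.full((len(s),), digit, dtype=torch.long)
                     for digit, s in enumerate(stacks)])

print(train_x.shape, train_y.shape)  # torch.Size([30, 784]) torch.Size([30])
```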

Given that I am not dealing with the target being just a 0 or a 1, does anyone have any hints for creating this loss function?

Thanks!

Hi @NathanSepulveda
I found it a bit tricky too. I got frustrated, then went to the fastai documentation and realized I needed to use a different kind of block for the targets when defining the DataBlock! When you use that block for the targets, the loss function also has to be updated. Once you identify the right type of block, it should not be hard to find which loss_func to use.

P.S. I wanted to give a response without giving too much away and not ruining your experience. :slight_smile:

Hi, thanks for the reply!

I'm actually more confused now. I don't think I am defining a DataBlock in the code above.

Looking into the docs, though, should I assume I need to use a CategoryBlock? I think I would benefit from you giving a bit more away!

Hi @NathanSepulveda
I may have jumped a little farther than I should have.
But yes, the first step is to use the CategoryBlock.
But if you have not manually defined a DataBlock so far, don’t worry about it!
The loss function for binary classification (3 vs. 7), as given in the book, is the `mnist_loss` you quoted above.
When you have multiple classes (0, 1, …, 9) there's a loss function called CrossEntropyLoss, which is explained in later chapters. :slight_smile:
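To make that hint a bit more concrete without spoiling the later chapters, here's a minimal sketch of `nn.CrossEntropyLoss` on random activations (the batch size and values are purely illustrative):

```python
import torch
import torch.nn as nn

loss_func = nn.CrossEntropyLoss()

# A batch of 4 samples with 10 raw activations each (one per digit).
# CrossEntropyLoss applies log-softmax internally, so these are raw
# model outputs -- no sigmoid or softmax beforehand.
activations = torch.randn(4, 10)
targets = torch.tensor([0, 3, 7, 9])  # integer class indices, one per sample

loss = loss_func(activations, targets)
print(loss)  # a scalar tensor
```

Unlike `mnist_loss`, the targets here are class indices (0–9) rather than 0/1 floats, which is why the dataset setup changes along with the loss function.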