# MNIST Loss Function From Chapter 4

Hi everyone, I am working through the Further Research section of chapter 4 of Deep Learning for Coders…

I am currently struggling to understand how to write the loss function for the full MNIST dataset, not just the threes and sevens.

When dealing with just 3s and 7s we are given this:

```python
def mnist_loss(predictions, targets):
    return torch.where(targets == 1, 1 - predictions, predictions).mean()
```
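To see what that loss does numerically, here is a small worked example (the prediction values are made up; it assumes `predictions` have already been passed through a sigmoid, so they lie between 0 and 1):

```python
import torch

def mnist_loss(predictions, targets):
    # For target 1 ("is a 3"), the penalty is 1 - prediction;
    # for target 0, the penalty is the prediction itself.
    return torch.where(targets == 1, 1 - predictions, predictions).mean()

preds = torch.tensor([0.9, 0.4, 0.2])   # made-up sigmoid outputs
targets = torch.tensor([1, 1, 0])       # binary labels

loss = mnist_loss(preds, targets)
# per-example penalties: 0.1, 0.6, 0.2, so the mean is 0.3
```

Notice the function only makes sense when each target is 0 or 1, which is exactly why it does not extend directly to ten digit classes.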

But I am trying to understand how I would write this for all the targets.

This is how I set up the dataset:
```python
train_x = torch.cat([stacked_ones, stacked_twos, stacked_threes, stacked_fours,
                     stacked_fives, stacked_sixes, stacked_sevens, stacked_eights,
                     stacked_nines]).view(-1, 28*28)

train_y = tensor([1]*len(ones) + [2]*len(twos) + [3]*len(threes) + [4]*len(fours)
                 + [5]*len(fives) + [6]*len(sixes) + [7]*len(sevens)
                 + [8]*len(eights) + [9]*len(nines)).unsqueeze(1)

dset = list(zip(train_x, train_y))
x, y = dset[0]
x.shape, y, len(dset)

valid_x = torch.cat([valid_1_tens, valid_2_tens, valid_3_tens, valid_4_tens,
                     valid_5_tens, valid_6_tens, valid_7_tens, valid_8_tens,
                     valid_9_tens]).view(-1, 28*28)
valid_y = tensor([1]*len(valid_1_tens) + [2]*len(valid_2_tens) + [3]*len(valid_3_tens)
                 + [4]*len(valid_4_tens) + [5]*len(valid_5_tens) + [6]*len(valid_6_tens)
                 + [7]*len(valid_7_tens) + [8]*len(valid_8_tens)
                 + [9]*len(valid_9_tens)).unsqueeze(1)
valid_dset = list(zip(valid_x, valid_y))
```

Since the targets are no longer just 0 or 1, does anyone have any hints for creating this loss function?

Thanks!

Hi @NathanSepulveda,
I found it a bit tricky too. I got frustrated, then went to the fastai documentation and realized I needed to use a different kind of block for the targets when defining the `DataBlock`! When you use that block for the targets, the loss function also has to be updated. Once you identify the right block, it should not be hard to find which `loss_func` to use.

P.S. I wanted to give a response without giving too much away, so as not to ruin your experience.

But yes, the first step is to use the `CategoryBlock` for the targets.
If you have not manually defined a `DataBlock` so far, don’t worry about it!
The `loss_func` for binary classification (3 vs. 7), as given in the book, is `mnist_loss`.
When you have multiple classes (0, 1, …, 9) there’s a loss function called `CrossEntropyLoss`, which is explained in a later chapter.
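To give a hedged sketch of what that looks like in plain PyTorch, assuming a model that outputs 10 activations per image (the tensor values below are made up for illustration):

```python
import torch
import torch.nn.functional as F

# Pretend batch of 4 images: the model emits 10 raw activations (logits)
# per image, one for each digit class.
preds = torch.randn(4, 10)
# Targets are integer class indices with shape (N,), not one-hot vectors.
targets = torch.tensor([3, 7, 0, 9])

loss = F.cross_entropy(preds, targets)

# cross_entropy is log_softmax followed by negative log-likelihood,
# so computing it in two steps gives the same number:
manual = F.nll_loss(F.log_softmax(preds, dim=1), targets)
assert torch.isclose(loss, manual)
```

One gotcha relative to the setup above: `F.cross_entropy` expects class indices of shape `(N,)`, so a `train_y` built with `.unsqueeze(1)` would need that extra dimension squeezed back out (or simply not added).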