Hi everyone, I am working through the Further Research section of chapter 4 of Deep Learning for Coders…

I am currently struggling to understand how to write the loss function for the full MNIST dataset, not just the threes and sevens.

When dealing with just the 3s and 7s, we are given this:

`def mnist_loss(predictions, targets): return torch.where(targets==1, 1-predictions, predictions).mean()`
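Here is the book's function applied to a few made-up values, which shows why it doesn't carry over directly. It assumes each image gets a single prediction between 0 and 1 and each target is either 0 or 1. With nine or ten possible labels, one number per image can no longer say which digit the model thinks it is. The values below are illustrative only.

```python
import torch

def mnist_loss(predictions, targets):
    # Where the target is 1, the loss is how far the prediction falls short of 1;
    # where the target is 0, the loss is the prediction itself (its distance from 0).
    return torch.where(targets == 1, 1 - predictions, predictions).mean()

# Made-up values: one prediction per image, already in [0, 1]
preds = torch.tensor([0.9, 0.4, 0.2])
targs = torch.tensor([1, 0, 1])
loss = mnist_loss(preds, targs)
print(loss)  # per-item losses are 0.1, 0.4 and 0.8, so the mean is about 0.4333
```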

But I am trying to work out how to write this when there are more than two possible targets.

This is how I set up the dataset:

```python
from fastai.vision.all import *  # provides torch and tensor, as in the book

# Training set: digits 1 to 9 (note: no zeros are included)
train_x = torch.cat([stacked_ones, stacked_twos, stacked_threes, stacked_fours, stacked_fives,
                     stacked_sixes, stacked_sevens, stacked_eights, stacked_nines]).view(-1, 28*28)
train_y = tensor([1]*len(ones) + [2]*len(twos) + [3]*len(threes) + [4]*len(fours) + [5]*len(fives)
                 + [6]*len(sixes) + [7]*len(sevens) + [8]*len(eights) + [9]*len(nines)).unsqueeze(1)
dset = list(zip(train_x, train_y))

# Sanity check on the first item
x, y = dset[0]
x.shape, y, len(dset)

# Validation set, built the same way
valid_x = torch.cat([valid_1_tens, valid_2_tens, valid_3_tens, valid_4_tens, valid_5_tens,
                     valid_6_tens, valid_7_tens, valid_8_tens, valid_9_tens]).view(-1, 28*28)
valid_y = tensor([1]*len(valid_1_tens) + [2]*len(valid_2_tens) + [3]*len(valid_3_tens)
                 + [4]*len(valid_4_tens) + [5]*len(valid_5_tens) + [6]*len(valid_6_tens)
                 + [7]*len(valid_7_tens) + [8]*len(valid_8_tens) + [9]*len(valid_9_tens)).unsqueeze(1)
valid_dset = list(zip(valid_x, valid_y))
```
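The labels can also be built in a loop. This avoids typing each digit twice and makes it easy to include the zeros that the setup code leaves out. Here is a minimal sketch, assuming the per-digit stacks are collected in one list ordered by digit. The name `stacks` and the random stand-in data are hypothetical and only exist so the snippet runs on its own. The sketch uses labels 0 to 9 as a 1-D `int64` tensor without `.unsqueeze(1)`. That is the target format PyTorch's `F.cross_entropy` expects for class indices. The `(n, 1)` column shape only matched the one-output-per-image predictions used with the binary `mnist_loss`.

```python
import torch

# Hypothetical stand-in for stacked_zeros ... stacked_nines:
# a few random 28x28 "images" per digit, just so the sketch is self-contained.
stacks = [torch.rand(n, 28, 28) for n in (3, 2, 4, 1, 2, 3, 2, 1, 2, 3)]

# Flatten every image into a 784-long row, as in the original code
train_x = torch.cat(stacks).view(-1, 28 * 28)

# One label per image: the digit's own index (0..9), as a 1-D int64 tensor
train_y = torch.cat([torch.full((len(s),), digit, dtype=torch.long)
                     for digit, s in enumerate(stacks)])

dset = list(zip(train_x, train_y))
```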

Since the targets are no longer just 0 or 1, does anyone have any hints for writing this loss function?
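One possible way to extend `mnist_loss` works in three steps:

1. Give the model one output per class (for example, 10 instead of 1).
2. Turn each row of outputs into probabilities with softmax, the multi-class counterpart of sigmoid.
3. Penalize how far the probability assigned to the correct digit falls short of 1.

The sketch below assumes the targets are 0-based class indices in a 1-D `int64` tensor. `mnist_multi_loss` is a hypothetical name. The more common choice takes the negative log of that same probability instead. The book introduces it as cross-entropy loss in chapter 5, and PyTorch provides it as `torch.nn.functional.cross_entropy`.

```python
import torch
import torch.nn.functional as F

def mnist_multi_loss(predictions, targets):
    # predictions: (batch, n_classes) raw activations from the model
    # targets: (batch,) int64 class indices in the range 0 .. n_classes-1
    probs = predictions.softmax(dim=1)      # each row now sums to 1
    idx = torch.arange(len(targets))        # row index for every item in the batch
    correct = probs[idx, targets]           # probability given to the true class
    return (1 - correct).mean()             # same idea as mnist_loss: distance from 1

# Made-up activations for a batch of 2 images and 10 classes
preds = torch.tensor([[2.0, 0.5] + [0.0] * 8,
                      [0.1] * 9 + [3.0]])
targs = torch.tensor([0, 9])

loss_simple = mnist_multi_loss(preds, targs)
# Usual choice: cross-entropy = softmax + negative log of the true-class probability
loss_ce = F.cross_entropy(preds, targs)
```

The log makes confident wrong answers much more costly than `1 - p` does. It also keeps the gradient from shrinking toward zero when the correct-class probability is close to 0. That is why cross-entropy is normally preferred over the simpler version.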

Thanks!