I am trying to figure out how to use an appropriate loss function in my full-MNIST implementation (all ten digit classes, not just the 3s and 7s from the chapter).
From my research I believe I have to change the original loss function to
nn.CrossEntropyLoss(), since this is now a multi-class classification problem.
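As I understand it, nn.CrossEntropyLoss expects raw logits of shape (batch, n_classes) and a 1-D tensor of integer class indices as targets. A minimal standalone sketch of that usage (made-up numbers, not from my notebook):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 10)            # 4 samples, 10 classes (digits 0-9)
targets = torch.tensor([3, 7, 0, 9])   # 1-D LongTensor of class indices
loss = criterion(logits, targets)
print(loss)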
This is my setup:
# Train dataset
train_x = torch.cat(train_stacked_nums).view(-1, 28*28)
train_y = torch.cat([tensor([i]*len(t)) for i,t in enumerate(train_stacked_nums)]).unsqueeze(1)
train_dset = list(zip(train_x, train_y))

# Validation dataset
valid_x = torch.cat(valid_stacked_nums).view(-1, 28*28)
valid_y = torch.cat([tensor([i]*len(t)) for i,t in enumerate(valid_stacked_nums)]).unsqueeze(1)
valid_dset = list(zip(valid_x, valid_y))

print(train_x.shape, train_y.shape, valid_x.shape, valid_y.shape)
The shape of my training and validation Xs and Ys is:
torch.Size([60000, 784]), torch.Size([60000, 1]), torch.Size([10000, 784]), torch.Size([10000, 1])
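A quick sanity check of the dtypes and a few of the labels, in case that is relevant (a sketch of what I would run on the tensors above, not verified output):

# Images should come out as floats and labels as integers;
# the first labels should be 0s and the last ones 9s
print(train_x.dtype, train_y.dtype)
print(train_y[:3].squeeze(), train_y[-3:].squeeze())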
I am using the same linear function:
# Create a function for matrix multiplication
def linear1(xb): return xb@weights + bias

preds = linear1(train_x)
preds.shape, preds[0]
The shape of the predictions, together with the first prediction, is:
(torch.Size([60000, 1]), tensor([16.6286], grad_fn=<SelectBackward0>))
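For completeness, weights and bias come from the chapter's init_params helper (they are re-initialised again in the test further down); I'm assuming the standard definition from the book:

import torch

def init_params(size, std=1.0):
    # Random initialisation with gradients enabled, as in the chapter
    return (torch.randn(size) * std).requires_grad_()

weights = init_params((28*28, 1))   # a single output column, same as the 3s-vs-7s model
bias = init_params(1)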
I thought of defining the loss function as:
# Loss function
def mnist_loss_cel(preds, targs):
    l = nn.CrossEntropyLoss()
    return l(preds, targs.squeeze())
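For reference, the shapes this loss will receive are the (batch, 1) predictions from linear1 and the (batch, 1) integer targets from train_y. A standalone sketch with dummy tensors of those shapes (the values are made up):

import torch
import torch.nn as nn

def mnist_loss_cel(preds, targs):
    l = nn.CrossEntropyLoss()
    return l(preds, targs.squeeze())

dummy_preds = torch.randn(4, 1, requires_grad=True)   # same shape as linear1(batch)
dummy_targs = torch.zeros(4, 1, dtype=torch.int64)    # same shape/dtype as train_y[:4]
print(mnist_loss_cel(dummy_preds, dummy_targs))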
But when I run a test, it doesn’t seem to be working:
# Initialise the weights
weights = init_params((28*28, 1))
bias = init_params(1)

## Creating the data loaders
# Training data loader
dl = DataLoader(train_dset, batch_size=256)
# Validation data loader
valid_dl = DataLoader(valid_dset, batch_size=256)

# Create a batch for testing
batch = train_x[:4]
# Predict the result
preds = linear1(batch)
The result of preds is:
tensor([[ 9.1690], [-1.9303], [ 8.9400], [ 2.8060]], grad_fn=<AddBackward0>)
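(As a side check, pulling one mini-batch out of the training DataLoader should give the batched version of the same shapes; a quick sketch:)

# Grab the first mini-batch and inspect its shapes
xb, yb = next(iter(dl))
print(xb.shape, yb.shape)   # expecting torch.Size([256, 784]) and torch.Size([256, 1])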
But the result of my loss is always zero:

loss = mnist_loss_cel(preds, train_y[:4])
loss

The result I keep getting is exactly 0, no matter which batch I use.
I ran a test with the original loss function:
# Original loss function
def mnist_loss(preds, targs):
    preds = preds.sigmoid()
    return torch.where(targs==1, 1-preds, preds).mean()

# Calculate the loss
loss = mnist_loss(preds, train_y[:4])
loss
and I got what I would think is an appropriate, non-zero result.
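(For context, mnist_loss is the binary loss from the 3s-vs-7s part of the chapter, where every target is 0 or 1; a minimal standalone reminder of how it behaves there, with made-up values:)

import torch

def mnist_loss(preds, targs):
    preds = preds.sigmoid()
    return torch.where(targs == 1, 1 - preds, preds).mean()

binary_preds = torch.tensor([2.0, -1.0, 0.5])   # raw model outputs
binary_targs = torch.tensor([1, 0, 1])          # binary labels
print(mnist_loss(binary_preds, binary_targs))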
What am I missing, and why isn’t the cross-entropy version working properly?