Recalculating Loss Function using PyTorch

I’m trying to recalculate the loss returned from learn.get_preds(with_loss=True) using torch.nn and am getting different results.

learn.loss_func

FlattenedLoss of CrossEntropyLoss()

valid_preds, valid_y, valid_losses = learn.get_preds(with_loss=True)
valid_losses.mean().item()

0.2517395835593471

loss = nn.CrossEntropyLoss(reduction='mean')

output = loss(valid_preds, valid_y)
output.item()

2.7582759857177734

I have confirmed that the FlattenedLoss has no impact on the data shape in this case. I'm not that familiar with PyTorch, so I'm guessing I'm doing something wrong there.

I figured it out… Answering my own question here in case someone else runs into the same issue.

torch.nn.CrossEntropyLoss expects raw logits and applies a (log-)softmax internally, whereas learn.get_preds returns predictions that have already been through softmax. Passing those probabilities into CrossEntropyLoss therefore applies softmax a second time, which is why the two numbers above disagree. Recomputing the loss from the raw output of learn.model instead gives matching results.

For example:

I first get the output of the model. In this case I'm just pulling the first batch; to compute the loss for the entire validation set I would need to loop through every batch (a sketch of that loop follows the model output below). For this example, one batch should suffice.

x, target = next(iter(learn.data.valid_dl))
preds = learn.model(x)
preds
tensor([[ 0.6294,  0.5137,  2.7246,  ..., -2.0160,  0.4596, -1.2756],
        [-0.0355, -2.4116, 14.5470,  ..., -0.5973, -2.3760, -1.9473],
        [-3.0890, -2.2336,  3.6774,  ..., -2.2486, -1.3904,  0.2988],
        ...,
        [-2.3504,  0.5370, -0.1560,  ...,  1.7313, -3.0057, -2.0379],
        [-2.2719, -0.1256, -3.6734,  ..., 12.2523, -0.1865,  0.5820],
        [-2.9343, -1.9531,  0.4142,  ...,  2.2272, -0.3494, -3.1263]],
       device='cuda:0', grad_fn=<AddmmBackward>)
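As mentioned above, one batch is enough for this comparison, but the same computation can be looped over every batch to reproduce the mean loss for the whole validation set. A minimal sketch, assuming the same fastai v1 objects used in this post (learn.model, learn.data.valid_dl); the helper names are mine:

import torch
import torch.nn as nn

# accumulate a per-batch *sum* so a smaller final batch doesn't skew the mean
loss_fn = nn.CrossEntropyLoss(reduction='sum')

total_loss, n_items = 0.0, 0
learn.model.eval()
with torch.no_grad():
    for xb, yb in learn.data.valid_dl:
        out = learn.model(xb)                  # raw logits, not probabilities
        total_loss += loss_fn(out, yb).item()
        n_items += yb.size(0)

total_loss / n_items  # should come out close to valid_losses.mean().item()

Back to the single batch.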

Now calculate the torch CrossEntropyLoss:

loss = nn.CrossEntropyLoss(reduction='mean')

output = loss(preds, target)
output.item()
0.3030180037021637
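Just to make the internal softmax explicit: CrossEntropyLoss is equivalent to NLL loss applied to log-softmaxed logits. A quick check with the same preds and target from above (a sketch using the functional API):

import torch.nn.functional as F

# CrossEntropyLoss(logits) == NLLLoss(log_softmax(logits))
F.nll_loss(F.log_softmax(preds, dim=1), target).item()
# should give the same 0.3030... value as output.item() above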

Recomputing using scikit-learn and SciPy. The softmax function turns the raw numeric output of the network into probabilities, and log_loss then computes the cross-entropy against the true labels.

from scipy.special import softmax
from sklearn.metrics import log_loss

log_loss(target.cpu().numpy(),
         softmax(preds.cpu().detach().numpy(), axis=1),  # softmax per row, i.e. per sample
         labels=list(range(0, 37)))
0.3030180556320232
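To close the loop on the original question: since get_preds returns post-softmax probabilities rather than logits, its loss can also be recovered directly, without rerunning the model, by taking the log of those probabilities and applying NLL loss. A minimal sketch, assuming the valid_preds and valid_y tensors from the top of the post:

import torch
import torch.nn.functional as F

# log(softmax(logits)) == log_softmax(logits), so NLL on log-probabilities
# reproduces the cross-entropy that get_preds reported
F.nll_loss(torch.log(valid_preds), valid_y).item()
# should match valid_losses.mean().item() (~0.2517) up to floating-point noise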