# Recalculating Loss Function using PyTorch

I’m trying to recalculate the loss returned from `learn.get_preds(with_loss=True)` using `torch.nn` and am getting different results.

```python
learn.loss_func
```

FlattenedLoss of CrossEntropyLoss()

```python
valid_preds, valid_y, valid_losses = learn.get_preds(with_loss=True)
valid_losses.mean().item()
```

0.2517395835593471

```python
import torch.nn as nn

loss = nn.CrossEntropyLoss(reduction='mean')

output = loss(valid_preds, valid_y)
output.item()
```

2.7582759857177734

I have confirmed that the `FlattenedLoss` wrapper has no impact on the data shape in this case. I’m not that familiar with PyTorch, so I’m guessing I’m doing something wrong there.

I figured it out… Answering my own question here in case someone else runs into the same issue.

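`torch.nn.CrossEntropyLoss` expects the raw, unnormalized output (logits) of `learn.model` and applies a (log-)softmax to it internally. By default, `learn.get_preds` has already applied the softmax activation to `valid_preds`, so passing those probabilities to `CrossEntropyLoss` applies softmax a second time, which is why the loss came out so much larger.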
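A framework-free sketch of that mechanism, using only the Python standard library. `cross_entropy` is a hypothetical helper that computes the mean of `-log_softmax(row)[target]`, which is what `nn.CrossEntropyLoss(reduction='mean')` computes on raw logits; the logits are made-up numbers. Feeding it probabilities that were already softmaxed reproduces the "softmax twice" mistake.

```python
import math
from typing import List, Sequence


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of one row of raw scores (logits)."""
    m = max(row)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_sum_exp for v in row]


def softmax(row: Sequence[float]) -> List[float]:
    """Turn one row of logits into probabilities that sum to 1."""
    return [math.exp(v) for v in log_softmax(row)]


def cross_entropy(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean of -log_softmax(row)[target], like CrossEntropyLoss(reduction='mean') on logits."""
    losses = [-log_softmax(row)[t] for row, t in zip(logits, targets)]
    return sum(losses) / len(losses)


# Hypothetical logits for 2 samples and 3 classes.
logits = [[2.0, 0.5, -1.0], [0.1, 3.0, 0.2]]
targets = [0, 1]

correct = cross_entropy(logits, targets)         # loss computed on raw logits
probs = [softmax(row) for row in logits]         # already-softmaxed predictions
double_softmax = cross_entropy(probs, targets)   # softmax applied a second time

print(f"on logits:        {correct:.4f}")
print(f"on probabilities: {double_softmax:.4f}")
```

The second number is larger because a second softmax over values that already lie in [0, 1] squeezes every row toward uniform. With the 37 classes used here, even a perfectly confident prediction (probability 1 on the target) ends up with at most e / (e + 36) ≈ 0.070 on the target after the second softmax, i.e. a loss of at least ln(1 + 36/e) ≈ 2.66, which is consistent with the 2.758 in the question.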

For example:

I first get the raw output of the model. In this case I’m just pulling the first batch of the validation set. To compute the loss for the entire validation set I would need to loop through all of its batches; for this example one batch should suffice.

```python
x, target = next(iter(learn.data.valid_dl))
learn.model.eval()      # make sure dropout/batchnorm layers are in inference mode
preds = learn.model(x)  # raw logits: no softmax has been applied yet
preds
```

```
tensor([[ 0.6294,  0.5137,  2.7246,  ..., -2.0160,  0.4596, -1.2756],
        [-0.0355, -2.4116, 14.5470,  ..., -0.5973, -2.3760, -1.9473],
        [-3.0890, -2.2336,  3.6774,  ..., -2.2486, -1.3904,  0.2988],
        ...,
        [-2.3504,  0.5370, -0.1560,  ...,  1.7313, -3.0057, -2.0379],
        [-2.2719, -0.1256, -3.6734,  ..., 12.2523, -0.1865,  0.5820],
        [-2.9343, -1.9531,  0.4142,  ...,  2.2272, -0.3494, -3.1263]],
```

Now calculate the loss with PyTorch’s `CrossEntropyLoss` on these raw outputs:

```python
loss = nn.CrossEntropyLoss(reduction='mean')

output = loss(preds, target)
output.item()
```

0.3030180037021637
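Extending this from one batch to the whole validation set means averaging per-sample losses rather than batch means, since the last batch is usually smaller. As far as I can tell, `valid_losses` from `get_preds(with_loss=True)` already holds one loss per sample, which is why `valid_losses.mean()` is the dataset average. The sketch below is framework-free and uses hypothetical helper names (`per_sample_losses`, `dataset_mean_loss`) and made-up batches; in PyTorch the same pattern would sum `nn.CrossEntropyLoss(reduction='sum')` over the batches and divide by the number of samples at the end.

```python
import math
from typing import Iterable, List, Sequence, Tuple

# One batch = (rows of logits, target class indices); a stand-in for a DataLoader batch.
Batch = Tuple[Sequence[Sequence[float]], Sequence[int]]


def per_sample_losses(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> List[float]:
    """Cross-entropy of each sample: logsumexp(row) - row[target]."""
    losses = []
    for row, t in zip(logits, targets):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_sum_exp - row[t])
    return losses


def dataset_mean_loss(batches: Iterable[Batch]) -> float:
    """Sum the per-sample losses over every batch, then divide by the sample count."""
    total, count = 0.0, 0
    for logits, targets in batches:
        losses = per_sample_losses(logits, targets)
        total += sum(losses)
        count += len(losses)
    return total / count


# Two hypothetical batches of unequal size (3 samples, then 1).
batches = [
    ([[2.0, 0.5, -1.0], [0.1, 3.0, 0.2], [1.0, 1.0, 1.0]], [0, 1, 2]),
    ([[-1.0, 0.0, 4.0]], [0]),
]

full = dataset_mean_loss(batches)
# Averaging the batch means instead over-weights the small last batch.
mean_of_means = sum(
    sum(per_sample_losses(lg, ts)) / len(ts) for lg, ts in batches
) / len(batches)
print(full, mean_of_means)
```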

As a check, I recomputed the loss with scikit-learn and SciPy. The softmax function turns the raw numeric output of the neural network into probabilities, which is what `log_loss` expects.

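```python
from scipy.special import softmax
from sklearn.metrics import log_loss

log_loss(target.cpu().numpy(),
         # axis=1: softmax over the classes of each sample; without it scipy
         # normalizes over the whole array at once
         softmax(preds.cpu().detach().numpy(), axis=1),
         labels=list(range(37)))
```

0.3030180556320232

The two values agree to about seven significant digits; the remaining difference is most likely float32 versus float64 rounding.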
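A note on the `axis` argument: `scipy.special.softmax` normalizes over the entire array when `axis` is not given, so the rows only sum to 1 with `axis=1`. The pure-Python sketch below uses hypothetical helpers `softmax_all` and `softmax_rows` as stand-ins for `axis=None` and `axis=1`, applied to a few values copied from the `preds` output above (truncated to three classes for illustration). It also shows that rescaling each row of the whole-array result recovers the per-row softmax. So an axis-less call can still produce the right number if `log_loss` rescales each row to sum to 1, which I believe older scikit-learn versions did; relying on that is fragile, and `axis=1` states the intent directly.

```python
import math
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def softmax_all(x: Matrix) -> List[List[float]]:
    """Softmax over every entry at once (the axis=None behaviour)."""
    m = max(v for row in x for v in row)
    exps = [[math.exp(v - m) for v in row] for row in x]
    total = sum(v for row in exps for v in row)
    return [[v / total for v in row] for row in exps]


def softmax_rows(x: Matrix) -> List[List[float]]:
    """Softmax over the classes of each row separately (the axis=1 behaviour)."""
    out = []
    for row in x:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        s = sum(exps)
        out.append([v / s for v in exps])
    return out


def renormalize_rows(p: Matrix) -> List[List[float]]:
    """Rescale each row so that it sums to 1."""
    out = []
    for row in p:
        s = sum(row)
        out.append([v / s for v in row])
    return out


# First three columns of the first two rows of `preds` shown above.
logits = [[0.6294, 0.5137, 2.7246], [-0.0355, -2.4116, 14.5470]]

whole = softmax_all(logits)
per_row = softmax_rows(logits)
print([round(sum(r), 6) for r in whole])    # rows do not sum to 1
print([round(sum(r), 6) for r in per_row])  # each row sums to 1
```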