F.softmax output is NaN, resolved using Temperature

Hi there,
I’m trying to implement a NN for the complete MNIST set, as suggested at the end of chapter 4.

I’m almost done, but I have a problem with the last layer of the model, the F.softmax call.
Sometimes the output tensor from softmax contains NaN (not a number). While debugging, I saw that the input tensor to the softmax contains very large values; the exponential inside the softmax turns those values into infinity, and the final result is NaN.
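To illustrate the overflow described above, here is a minimal sketch, assuming float32 tensors (the PyTorch default). The helper names `naive_softmax` and `stable_softmax` and the input values are made up for illustration and are not from the original post. Subtracting the row maximum before exponentiating leaves the result mathematically unchanged but keeps `exp` in range.

```python
import torch

def naive_softmax(logits: torch.Tensor) -> torch.Tensor:
    # exp() of a large float32 value overflows to inf, and inf / inf is NaN
    ex = torch.exp(logits)
    return ex / ex.sum(dim=-1, keepdim=True)

def stable_softmax(logits: torch.Tensor) -> torch.Tensor:
    # Shift each row so its largest entry is 0; exp() then never exceeds 1
    shifted = logits - logits.max(dim=-1, keepdim=True).values
    ex = torch.exp(shifted)
    return ex / ex.sum(dim=-1, keepdim=True)

# Illustrative logits (assumption): one value large enough to overflow float32
logits = torch.tensor([[100.0, 50.0, 0.0]])

naive = naive_softmax(logits)    # contains NaN
stable = stable_softmax(logits)  # finite, each row sums to 1
print(naive)
print(stable)
```

As far as I know, `F.softmax` already applies this kind of shift internally, so NaNs coming out of it may point to NaN or infinite values already present in its input. That is worth checking separately.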

For example, in the first epoch, the first row of the input tensor to softmax is

tensor([[1.1537e-10, 3.7890e-33, 1.0000e+00, …, 6.8583e-22, 1.9325e-11,
2.3996e-06], etc.], grad_fn=)

the output is

tensor([[nan, nan, nan, …, nan, nan, nan], etc.])

I resolved it by writing my own softmax implementation:

def softmax(preds):
  temperature = 90
  ex = torch.exp(preds/temperature)
  return ex / torch.sum(ex, axis=0)

I think the key point is the temperature. I set it to 90 because the highest value I saw in preds is roughly 90. I think it acts like, I don’t know, something that smooths the input preds… :man_shrugging:
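The smoothing effect of the temperature can be seen in a small sketch. The function name `softmax_with_temperature` and the logits are illustrative assumptions, not the poster's code. Dividing the logits by a temperature above 1 shrinks the gaps between them, so the resulting probabilities are flatter.

```python
import torch
import torch.nn.functional as F

def softmax_with_temperature(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # dim=-1 normalizes over the last (class) dimension of each row
    return F.softmax(logits / temperature, dim=-1)

# Illustrative logits (assumption) with a maximum near 90, as in the post
logits = torch.tensor([[90.0, 45.0, 0.0]])

sharp = softmax_with_temperature(logits, temperature=1.0)    # nearly one-hot
smooth = softmax_with_temperature(logits, temperature=90.0)  # much flatter
print(sharp)
print(smooth)
```

Note that this sketch normalizes over the last dimension. The `axis=0` in the post's version would normalize down the batch dimension if `preds` has shape (batch, classes), which is an assumption about the shapes here.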

But why did I have to write my own softmax? Why didn’t F.softmax work for me?

The accuracy using F.softmax: [image not available]

The accuracy using my own softmax: [image not available]

The usual remedy for this kind of overflow is the log-sum-exp trick, and it is one of the reasons why we compute the cross-entropy and the activation at the same time (cross-entropy with logits).
Would you mind sharing the network you are using?
It is probably missing some normalization.
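For reference, the trick named in the reply above (commonly written log-sum-exp) can be sketched as follows. The helper name `log_softmax_lse` and the input values are assumptions for illustration. `torch.logsumexp` computes log(sum(exp(x))) in a numerically safe way, so the log-softmax can be formed without ever materializing a huge exponential.

```python
import torch
import torch.nn.functional as F

def log_softmax_lse(logits: torch.Tensor) -> torch.Tensor:
    # log softmax(x)_i = x_i - log(sum_j exp(x_j))
    return logits - torch.logsumexp(logits, dim=-1, keepdim=True)

# Illustrative extreme logits (assumption); a naive exp() would overflow here
logits = torch.tensor([[1000.0, 10.0, -1000.0]])

manual = log_softmax_lse(logits)
builtin = F.log_softmax(logits, dim=-1)
print(manual)
print(builtin)
```

Both results stay finite even for these extreme inputs, which is why losses are usually computed from log-probabilities rather than from probabilities.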

Thanks @tcapelle,
the model is very simple:

def model(x):
  res = x@w1 + b1
  res = res.max(tensor(0.0)) # ReLU
  res = res@w2 + b2
  res = softmax(res) # MY SOFTMAX MY SOFTMAX MY SOFTMAX
  return res

def softmax(preds):
  temperature = 90
  ex = torch.exp(preds/temperature)
  return ex / torch.sum(ex, axis=0)

And which loss function do you train this with?
You can remove the softmax from your function and use CrossEntropy directly (it takes the logits and applies the softmax itself).
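A minimal sketch of that suggestion, assuming a two-layer model like the one posted. The layer sizes, the random weights, the fake batch, and the name `model_logits` are all hypothetical. The model returns raw logits and `F.cross_entropy` handles the log-softmax internally.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: 784 inputs, 30 hidden units, 10 classes
w1 = torch.randn(784, 30) * 0.01
b1 = torch.zeros(30)
w2 = torch.randn(30, 10) * 0.01
b2 = torch.zeros(10)

def model_logits(x: torch.Tensor) -> torch.Tensor:
    res = x @ w1 + b1
    res = F.relu(res)
    return res @ w2 + b2  # raw logits, no softmax here

# Fake batch (assumption): 4 flattened images and their labels
x = torch.rand(4, 784)
y = torch.tensor([3, 0, 7, 1])

logits = model_logits(x)
loss = F.cross_entropy(logits, y)  # mean over the batch by default

# Same quantity written out by hand from log-probabilities
manual = -(F.one_hot(y, 10) * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
print(loss, manual)
```

If probabilities are needed at inference time, `F.softmax(logits, dim=1)` can be applied to the model output separately.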

To better understand the training process, I tried to write the cross-entropy loss function myself:

def loss_func(pred, y):
  global num_classes
  target = torch.nn.functional.one_hot(y, num_classes)
  return -torch.sum(target * torch.log(pred))
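One hazard of taking `torch.log` of probabilities directly, shown as a small sketch with made-up values: when softmax saturates and some probabilities are exactly zero, `log` returns `-inf`, and multiplying that by the zeros in the one-hot target gives NaN. Picking out only the true-class probability avoids the `0 * -inf` products.

```python
import torch
import torch.nn.functional as F

# Illustrative saturated prediction (assumption): one class has probability 1
pred = torch.tensor([[1.0, 0.0, 0.0]])
y = torch.tensor([0])
target = F.one_hot(y, 3)

# log(0) = -inf, and 0 * -inf = NaN, so the sum is NaN
naive_loss = -torch.sum(target * torch.log(pred))

# Select only the probability of the true class for each row
picked = pred[torch.arange(len(y)), y]
indexed_loss = -torch.log(picked).sum()

print(naive_loss)
print(indexed_loss)
```

Indexing still yields infinity if the true class itself has probability zero, so working from logits with `F.log_softmax` or `F.cross_entropy` is likely the more robust route.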

@tcapelle However, the NaN issue occurs even in the first epoch, before the first loss evaluation.