Understanding the difference between binary_cross_entropy loss and mnist_loss

Folks, I am trying to understand why this change is made in binary_cross_entropy loss vs mnist_loss.

Basically, in mnist_loss, the loss function uses torch.where as follows:

torch.where(targets==1, 1-predictions, predictions)

The intuition/explanation provided for the line of code above is that the function will measure how far/distant each prediction is from 1 if it should be 1, and how distant it is from 0 if it should be 0. This makes sense.
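
To make that concrete, here is a tiny sketch (the tensor values are made up, just for illustration) of what that torch.where call returns:

import torch

# made-up predictions already in [0, 1], with their targets
predictions = torch.tensor([0.9, 0.4, 0.2])
targets = torch.tensor([1, 1, 0])

# distance from 1 where the target is 1, distance from 0 where the target is 0
print(torch.where(targets==1, 1-predictions, predictions))  # tensor([0.1000, 0.6000, 0.2000])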

However, in the binary_cross_entropy loss function, the torch.where usage changes as follows:

torch.where(targets==1, predictions, 1-predictions)

I don’t understand why the torch.where call is not the same as in mnist_loss. Here as well, are we not trying to find how far/distant the prediction is from 1 if it should be 1, and how distant it is from 0 if it should be 0?

The full line of code is:

-torch.where(targets==1, predictions, 1-predictions).log()

Does taking the negative or the log have something to do with swapping arguments 2 and 3 compared to mnist_loss? Why should we not use the following within the binary_cross_entropy loss:

torch.where(targets==1, 1-predictions, predictions)

If anyone has the intuition behind this change, please do share.

UPDATED
It looks like both the code and the book (at least the copy I have) contain an error. Your original formula is correct:

def binary_cross_entropy(inputs, targets):
    inputs = inputs.sigmoid()
    return -torch.where(targets==1, inputs, 1-inputs).log().mean()

I think the intuition is this (someone please correct me if I’m wrong): if the target is 1, we want the input to be closer to 1, and if the target is 0, we want the input to be closer to 0. -log(x) approaches infinity as x approaches 0 and approaches 0 as x approaches 1. Therefore, in the above function, if the target is 1 and the input is high, -log(input) yields a low loss. Similarly, if the target is 1 and the input is low, -log(input) yields a high loss. That’s exactly what is expected, and vice versa for when the target is 0.
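
A quick numeric check of that intuition (a sketch only: the raw activation values below are made up, and I’m reusing the corrected definition from above):

import torch

def binary_cross_entropy(inputs, targets):
    inputs = inputs.sigmoid()
    return -torch.where(targets==1, inputs, 1-inputs).log().mean()

targets = torch.tensor([1., 0.])
confident_right = torch.tensor([4.0, -4.0])   # sigmoids ~0.98 and ~0.02, matching the targets
confident_wrong = torch.tensor([-4.0, 4.0])   # sigmoids ~0.02 and ~0.98, opposite to the targets

print(binary_cross_entropy(confident_right, targets))  # ≈ 0.018 -> low loss
print(binary_cross_entropy(confident_wrong, targets))  # ≈ 4.02  -> high loss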

Original answer
That snippet doesn’t look correct. From https://github.com/fastai/fastbook/blob/master/06_multicat.ipynb, the code of binary_cross_entropy is:

def binary_cross_entropy(inputs, targets):
    inputs = inputs.sigmoid()
    return -torch.where(targets==1, 1-inputs, inputs).log().mean()

The change, I think, is because of the logarithm. The value of the predictions in both losses is in [0, 1], so applying log gives values ranging over [-inf, 0], and the minus sign flips that to [0, inf]. So when you expect your target to be 1 and the prediction is close to 0, you get a higher loss value.
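
A tiny illustration of that, just -log applied to a prediction near 1 versus near 0 (for a target of 1, the per-item loss is -log(prediction)):

import torch

print(-torch.log(torch.tensor(0.99)))  # tensor(0.0101) -> prediction close to 1, low loss
print(-torch.log(torch.tensor(0.01)))  # tensor(4.6052) -> prediction close to 0, high loss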

Thanks a lot for checking and the great explanation. That makes sense to me.

My copy of the textbook, and my copy of the notebook from the fastbook repo, have the definition I wrote about, i.e.

def binary_cross_entropy(inputs, targets):
    inputs = inputs.sigmoid()
    return -torch.where(targets==1, inputs, 1-inputs).log().mean()

Thank you so much for this! I have spent a lot of time wondering how the (wrong) function could be considered a loss function since it obviously increases when predictions are closer to targets. It made me question my whole understanding of the book so far. Glad to have it resolved!
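
For anyone else reading along, here is a small sketch showing exactly that (made-up activations, and wrong_bce is just a name I’m using for the version with arguments 2 and 3 swapped):

import torch

def wrong_bce(inputs, targets):
    # the incorrect version: arguments 2 and 3 of torch.where are swapped
    inputs = inputs.sigmoid()
    return -torch.where(targets==1, 1-inputs, inputs).log().mean()

targets = torch.tensor([1., 0.])
print(wrong_bce(torch.tensor([4.0, -4.0]), targets))   # ≈ 4.02 even though both predictions are confidently correct
print(wrong_bce(torch.tensor([-4.0, 4.0]), targets))   # ≈ 0.018 even though both predictions are confidently wrong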

You’re welcome! Glad it helps.
