# Understanding the difference between binary_cross_entropy loss and mnist_loss

Folks, I am trying to understand why this change is made in binary_cross_entropy loss vs mnist_loss.

Basically, in mnist_loss, the loss function uses torch.where as follows:

```python
torch.where(targets==1, 1-predictions, predictions)
```

The intuition/explanation provided for the line of code above is that the function will measure how far/distant each prediction is from 1 if it should be 1, and how distant it is from 0 if it should be 0. This makes sense.
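As a sanity check of that intuition, here is a hypothetical scalar sketch of what that `torch.where` selects per element (plain Python rather than PyTorch, and the function name is mine, not the book's):

```python
# Illustrative scalar version of the mnist_loss selection:
# the "distance from the right answer" for a single prediction.
def mnist_loss_value(prediction, target):
    # if the target is 1, measure how far the prediction is from 1;
    # if the target is 0, the prediction itself is the distance from 0
    return 1 - prediction if target == 1 else prediction

print(mnist_loss_value(0.9, 1))  # small loss: prediction close to 1
print(mnist_loss_value(0.9, 0))  # large loss: prediction far from 0
```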

However, in the binary_cross_entropy loss function, the torch.where usage changes as follows:

```python
torch.where(targets==1, predictions, 1-predictions)
```

I don’t understand why the `torch.where` call is not the same as in `mnist_loss`. Here as well, aren’t we trying to find how far each prediction is from 1 if it should be 1, and how far it is from 0 if it should be 0?

The full line of code is:

```python
-torch.where(targets==1, predictions, 1-predictions).log()
```

Does taking the negative or the log have something to do with swapping arguments 2 and 3 compared to `mnist_loss`? Why shouldn’t `binary_cross_entropy` use the following instead:

```python
torch.where(targets==1, 1-predictions, predictions)
```

If anyone has the intuition behind this change, please do share.

**UPDATE:** It looks like both the code and the book (at least the copy I have) contain an error. Your original formula is correct:

```python
def binary_cross_entropy(inputs, targets):
    inputs = inputs.sigmoid()
    return -torch.where(targets==1, inputs, 1-inputs).log().mean()
```

I think the intuition is this (someone please correct me if I’m wrong): if `target` is `1`, we want `input` to be closer to `1`, and if `target` is `0`, we want `input` to be closer to `0`. `-log(x)` approaches `Infinity` as `x` approaches `0`, and approaches `0` as `x` approaches `1`. Therefore, in the above function, if `target` is `1` and `input` is high, `-log(input)` yields a low loss. Similarly, if `target` is `1` and `input` is low, `-log(input)` yields a high loss. That’s exactly what is expected. Vice versa for when `target` is `0`.
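To put numbers on that intuition (plain Python, with `math.log` standing in for `.log()`; the example values are my own):

```python
import math

# Suppose the target is 1, so the loss is -log(prediction).
print(-math.log(0.9))  # prediction near 1 -> low loss (about 0.105)
print(-math.log(0.1))  # prediction near 0 -> high loss (about 2.303)
```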

That snippet doesn’t look correct. From https://github.com/fastai/fastbook/blob/master/06_multicat.ipynb, the code of `binary_cross_entropy` is:

```python
def binary_cross_entropy(inputs, targets):
    inputs = inputs.sigmoid()
    return -torch.where(targets==1, 1-inputs, inputs).log().mean()
```

The change, I think, is because of the logarithm. The value of the predictions in both losses lies in [0, 1], so applying log gives values in [-inf, 0]. When you expect the target to be 1 and the prediction is close to 0, `-log(prediction)` gives a high loss value.
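One way to see why the arguments must be swapped once the log is applied: feeding the `mnist_loss`-style selection into the log would make the loss move in the wrong direction. A quick scalar check (plain Python; the prediction values are my own illustrative picks):

```python
import math

good, bad = 0.95, 0.05  # two predictions for a target of 1

# binary cross-entropy form: -log(prediction) when target == 1
# better prediction -> lower loss, as desired
assert -math.log(good) < -math.log(bad)

# mnist_loss-style selection fed into log: -log(1 - prediction)
# better prediction -> HIGHER loss, so it cannot serve as a loss here
assert -math.log(1 - good) > -math.log(1 - bad)
```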

Thanks a lot for checking and the great explanation. That makes sense to me.

My copy of the textbook, and my copy of the notebook from the fastbook repo, have the definition I wrote about, i.e.

```python
def binary_cross_entropy(inputs, targets):
    inputs = inputs.sigmoid()
    return -torch.where(targets==1, inputs, 1-inputs).log().mean()
```

Thank you so much for this! I have spent a lot of time wondering how the (wrong) function could be considered a loss function since it obviously increases when predictions are closer to targets. It made me question my whole understanding of the book so far. Glad to have it resolved!