Folks, I am trying to understand why this change was made in the binary_cross_entropy loss compared with mnist_loss.
Basically, in mnist_loss, the loss function uses torch.where as follows:
torch.where(targets==1, 1-predictions, predictions)
The intuition given for the line of code above is that it measures how far each prediction is from 1 if it should be 1, and how far it is from 0 if it should be 0. This makes sense.
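To see the "distance" interpretation with actual numbers, here is a minimal sketch using three hand-picked probabilities. It assumes the predictions have already been passed through a sigmoid. mnist_loss_per_item is a hypothetical helper that stops before averaging, so each element can be inspected.

```python
import torch

def mnist_loss_per_item(predictions, targets):
    # Assumes predictions are already probabilities in [0, 1]
    # (sigmoid already applied; skipped here for readability).
    # Target 1 -> distance from 1; target 0 -> distance from 0.
    return torch.where(targets == 1, 1 - predictions, predictions)

preds = torch.tensor([0.9, 0.4, 0.2])
targets = torch.tensor([1, 0, 1])
per_item = mnist_loss_per_item(preds, targets)
print(per_item)         # approximately 0.1, 0.4, 0.8
print(per_item.mean())  # averaging gives a single scalar loss
```

A confident correct prediction (0.9 for a 1) gives a small value, and a confident wrong one (0.2 for a 1) gives a large value.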
However, in the binary_cross_entropy loss function, the torch.where usage changes as follows:
torch.where(targets==1, predictions, 1-predictions)
I don’t understand why the torch.where call is not the same as in mnist_loss. Aren’t we also trying to measure how far the prediction is from 1 if it should be 1, and how far it is from 0 if it should be 0?
The full line of code is:
-torch.where(targets==1, predictions, 1-predictions).log()
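For experimenting, here is a minimal sketch of that line applied element by element to the same three probabilities, again assuming the sigmoid has already been applied. bce_per_item is a hypothetical helper name. The sketch also cross-checks the result against PyTorch's built-in torch.nn.functional.binary_cross_entropy with reduction="none".

```python
import torch
import torch.nn.functional as F

def bce_per_item(predictions, targets):
    # Assumes predictions are probabilities in [0, 1] (sigmoid already applied).
    # Target 1 -> keep p; target 0 -> keep 1 - p,
    # i.e. the probability the model assigned to the correct label.
    prob_correct = torch.where(targets == 1, predictions, 1 - predictions)
    # -log maps a probability near 1 to a loss near 0,
    # and a probability near 0 to a large loss.
    return -prob_correct.log()

preds = torch.tensor([0.9, 0.4, 0.2])
targets = torch.tensor([1, 0, 1])
per_item = bce_per_item(preds, targets)
print(per_item)  # approximately 0.105, 0.511, 1.609

# Cross-check against PyTorch's built-in (targets must be float)
builtin = F.binary_cross_entropy(preds, targets.float(), reduction="none")
print(torch.allclose(per_item, builtin))  # True
```

The ordering of the per-item losses matches the mnist_loss sketch (smallest for the confident correct prediction, largest for the confident wrong one), even though the values fed into the log are the "probability of being right" rather than a distance.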
Does taking the negative or the log have something to do with swapping arguments 2 and 3 compared with mnist_loss? Why shouldn’t binary_cross_entropy use the following instead:
torch.where(targets==1, 1-predictions, predictions)
If anyone has the intuition behind this change, please do share.