I’m trying to understand the implementation of focal loss, partly so I can port it to other areas. I’ve seen a few implementations online, but they all build it from scratch, whereas the version in the lesson 9 notebook strikes me as particularly efficient because it uses the weight parameter of the binary cross entropy function. I’d like to do the same for cross entropy (CE).
I want to confirm my understanding, though: for focal loss we’re multiplying the cross entropy loss, -log(pt), by alpha*(1-pt)^gamma, and we can do that by setting weight in F.binary_cross_entropy_with_logits to that precalculated value.
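def get_weight(x, t):
    # x: raw logits, t: 0/1 targets of the same shape
    alpha, gamma = 0.25, 1
    p = x.sigmoid()
    pt = p * t + (1-p) * (1-t)         # probability assigned to the true label
    w = alpha * t + (1-alpha) * (1-t)  # alpha for positives, 1-alpha for negatives
    return w * (1-pt).pow(gamma)       # alpha-weighted focal modulating factor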
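To check my reading, here is a minimal sketch comparing the weight-based version against a from-scratch focal loss. The names focal_weight, focal_bce and focal_bce_scratch are my own placeholders, not the notebook's, and I'm assuming logits and 0/1 float targets of the same shape. I detach the weight so it is treated as a constant. That makes the loss values match, but the gradients will differ from a version that differentiates through (1-pt)^gamma.

```python
import torch
import torch.nn.functional as F

def focal_weight(x, t, alpha=0.25, gamma=1.0):
    # Hypothetical helper mirroring the snippet above: alpha-term * (1 - pt)^gamma.
    p = x.sigmoid()
    pt = p * t + (1 - p) * (1 - t)
    w = alpha * t + (1 - alpha) * (1 - t)
    return w * (1 - pt).pow(gamma)

def focal_bce(x, t, alpha=0.25, gamma=1.0):
    # Assumption: the weight is detached, i.e. treated as a fixed per-element scale.
    w = focal_weight(x, t, alpha, gamma).detach()
    return F.binary_cross_entropy_with_logits(x, t, weight=w, reduction="sum")

def focal_bce_scratch(x, t, alpha=0.25, gamma=1.0):
    # Direct formula: -alpha-term * (1 - pt)^gamma * log(pt), summed over elements.
    p = x.sigmoid()
    pt = p * t + (1 - p) * (1 - t)
    w = alpha * t + (1 - alpha) * (1 - t)
    return (-w * (1 - pt).pow(gamma) * pt.log()).sum()

torch.manual_seed(0)
x = torch.randn(4, 3)                    # logits
t = torch.randint(0, 2, (4, 3)).float()  # 0/1 targets
same = torch.allclose(focal_bce(x, t), focal_bce_scratch(x, t), atol=1e-5)
```

If same comes out True, the forward values agree, which is what I'd expect if weight simply scales each element of the BCE loss.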
Similarly, I want to confirm that we’re using sigmoid here because we’re doing BCE. If I wanted to apply this to CE, would the appropriate function be .softmax(), with the rest of the function otherwise identical?
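For what it's worth, here is a rough sketch of what I think the CE version might look like. One caveat: as far as I can tell, the weight argument of F.cross_entropy is a per-class tensor of shape (C,), not a per-example one. So instead of passing the focal term as weight, this sketch takes the unreduced loss and scales it by hand. focal_ce is a name I made up, and using a single scalar alpha for all classes is an assumption on my part.

```python
import torch
import torch.nn.functional as F

def focal_ce(logits, target, alpha=0.25, gamma=1.0):
    # logits: (N, C) raw scores; target: (N,) int64 class indices.
    ce = F.cross_entropy(logits, target, reduction="none")  # -log(pt) per example
    pt = torch.exp(-ce)                                     # softmax prob of the true class
    # Assumption: scalar alpha shared by all classes.
    return (alpha * (1 - pt).pow(gamma) * ce).sum()

torch.manual_seed(0)
logits = torch.randn(5, 4)
target = torch.randint(0, 4, (5,))
loss = focal_ce(logits, target)
```

With alpha=1 and gamma=0 this should reduce to plain summed cross entropy, which is a handy sanity check.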