Extending Focal Loss to Multiclass Detection

Hi guys!

I wanted to see if my intuition behind focal loss is correct.

Basically, Focal Loss is only concerned with answering the question “does this anchor box contain an object, or is it background?”. This means it does not tell you which class the object is. Hence, we apply a Sigmoid activation and use Binary Cross Entropy to make the prediction for each anchor box.
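A minimal PyTorch sketch of that binary case, assuming one objectness logit per anchor and a 0/1 target (1 = object, 0 = background). The function name is hypothetical, and the optional `alpha` class-balancing weight from the RetinaNet paper is left out for brevity. It recovers p_t (the probability assigned to the true label) as exp(-BCE) and scales the BCE by (1 - p_t)^gamma:

```python
import torch
import torch.nn.functional as F

def binary_focal_loss(logits: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    """Focal loss for one objectness logit per anchor (hypothetical helper)."""
    targets = targets.float()
    # Per-anchor BCE on raw logits; the sigmoid is applied inside, which is numerically stable
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    # BCE = -log(p_t), so p_t = exp(-BCE) is the probability of the correct 0/1 label
    p_t = torch.exp(-bce)
    # Easy anchors (p_t close to 1) are down-weighted by (1 - p_t)^gamma
    return ((1 - p_t) ** gamma * bce).mean()

logits = torch.tensor([3.0, -2.0, 0.5])  # three anchors
targets = torch.tensor([1, 0, 1])        # object, background, object
print(binary_focal_loss(logits, targets))
```

With `gamma=0` the modulating factor is 1 and this reduces to plain BCE.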

To know which class the anchor box contains, we need Categorical Cross Entropy: apply Softmax to the anchor box's class scores and compare the result against a one-hot encoded target vector.
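For comparison, a small sketch of that softmax classification loss, assuming logits of shape (num_anchors, num_classes) and one integer label per anchor. `F.cross_entropy` takes raw logits and class indices and applies log-softmax internally; the explicit one-hot version gives the same value:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_anchors, num_classes = 4, 3
logits = torch.randn(num_anchors, num_classes)  # raw class scores, one row per anchor
labels = torch.tensor([0, 2, 1, 2])             # true class index for each anchor

# Built-in: log-softmax + negative log-likelihood, averaged over anchors
ce = F.cross_entropy(logits, labels)

# Same loss written out with an explicit one-hot target
one_hot = F.one_hot(labels, num_classes).float()
manual = -(one_hot * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

print(torch.allclose(ce, manual))  # True
```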

Does this mean to extend Focal Loss we simply have to switch Sigmoid with Softmax and Binary Cross Entropy with Categorical Cross Entropy?
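Under the softmax view, the extension is roughly what the question suggests: compute the per-anchor categorical CE, recover p_t = exp(-CE) as the softmax probability of the true class, and multiply by (1 - p_t)^gamma. A sketch assuming integer class labels and no `alpha` weighting (the function name is hypothetical). Note that the RetinaNet paper itself uses the sigmoid form instead, with one binary focal loss per class, so both variants are in use:

```python
import torch
import torch.nn.functional as F

def softmax_focal_loss(logits: torch.Tensor, labels: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    """Multiclass focal loss: logits (N, C), labels (N,) integer class indices."""
    # Per-anchor categorical CE = -log(p_t), p_t = softmax probability of the true class
    ce = F.cross_entropy(logits, labels, reduction="none")
    p_t = torch.exp(-ce)
    # Modulating factor shrinks the loss of well-classified anchors
    return ((1 - p_t) ** gamma * ce).mean()

torch.manual_seed(0)
logits = torch.randn(5, 4)               # 5 anchors, 4 classes
labels = torch.tensor([0, 3, 1, 2, 0])
print(softmax_focal_loss(logits, labels))
```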

Hi. I have a multiclass focal loss implementation; see this kornia discussion.

I had this need as well. Here’s what I came up with:
Edit: This is bad advice. It seems to work, but it isn’t really working properly. Getting it to work correctly requires more changes, namely to the underlying FocalLoss.

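import torch
from fastai.losses import FocalLoss
from fastcore.basics import store_attr

class MCFocalLoss(FocalLoss):
    def __init__(self, thresh=0.5, **kwargs):
        store_attr()                # saves self.thresh
        super().__init__(**kwargs)  # gamma, weight, reduction go to FocalLoss
    
    # Caveat: the inherited FocalLoss.forward is softmax/cross-entropy based,
    # so this sigmoid does not match the loss actually being optimized
    def activation(self, x): return torch.sigmoid(x)
    
    # Multi-label style decoding: every class above thresh is predicted
    def decodes(self, x): return x>self.thresh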

I haven’t tested it that extensively yet, but so far it seems to be working properly
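Following the edit note, one way to make the pieces agree is to compute the loss itself from per-class sigmoids, so `forward`, `activation` and `decodes` all use the same probabilities. A plain-PyTorch sketch, assuming multi-hot targets of shape (N, C); the class name is hypothetical, and it only mimics the `activation`/`decodes` method names rather than subclassing fastai's loss classes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SigmoidFocalLoss(nn.Module):
    """Focal loss on independent per-class sigmoids (hypothetical, multi-label targets)."""
    def __init__(self, gamma=2.0, thresh=0.5, reduction="mean"):
        super().__init__()
        self.gamma, self.thresh, self.reduction = gamma, thresh, reduction

    def forward(self, logits, targets):
        targets = targets.float()
        # Elementwise BCE over every (anchor, class) pair
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p_t = torch.exp(-bce)  # probability assigned to the correct 0/1 label
        loss = (1 - p_t) ** self.gamma * bce
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

    # Same sigmoid the loss uses, so predictions match what was optimized
    def activation(self, x): return torch.sigmoid(x)

    # A class is predicted whenever its probability exceeds thresh
    def decodes(self, x): return x > self.thresh

loss_fn = SigmoidFocalLoss(gamma=2.0)
logits = torch.tensor([[2.0, -1.0, 0.0]])  # one anchor, three classes
targets = torch.tensor([[1, 0, 1]])        # multi-hot target
print(loss_fn(logits, targets))
print(loss_fn.decodes(loss_fn.activation(logits)))
```

Here only the first class is predicted, since sigmoid(0) = 0.5 is not above the 0.5 threshold. With `reduction="none"` the per-element values should match `torchvision.ops.sigmoid_focal_loss` when its `alpha` is set to a negative value, which disables alpha weighting.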