Loss function of unet_learner (FlattenedLoss of CrossEntropyLoss)

I create a unet learner:

learn = unet_learner(data, models.resnet18)
print(learn.loss_func)

and I get:

FlattenedLoss of CrossEntropyLoss()

Question 1: What does "Flattened" mean here, and why is it used at all? Let's say my input images have shape [bs, 3, 200, 200] and I have 5 classes in total, so I would expect the model to output a feature map of shape [bs, 5, 200, 200]. Please correct me if I am wrong.

Question 2: How do I change the default loss function here? Since I have an unbalanced dataset, I would like to try out Focal Loss.


The loss function can be changed by passing the loss_func parameter to unet_learner. See the example below, which was written for image classification:

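import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self, gamma=3., reduction='mean'):
        super().__init__()
        self.gamma = gamma          # focusing parameter; larger values down-weight easy examples more
        self.reduction = reduction  # 'mean', 'sum', or 'none'

    def forward(self, inputs, targets):
        # Per-sample cross entropy, i.e. -log(pt) for the true class
        CE_loss = nn.CrossEntropyLoss(reduction='none')(inputs, targets)
        pt = torch.exp(-CE_loss)  # predicted probability of the true class
        F_loss = ((1 - pt)**self.gamma) * CE_loss
        if self.reduction == 'sum':
            return F_loss.sum()
        elif self.reduction == 'mean':
            return F_loss.mean()
        return F_loss  # reduction='none': return the unreduced losses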
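Since the class above was written for classification, here is a rough sketch of how it might be adapted for segmentation. The name FlattenedFocalLoss is made up for illustration (it is not part of fastai), and I am assuming the masks arrive as [bs, 1, H, W] or [bs, H, W] integer tensors. It reshapes predictions and masks itself so that every pixel is treated as one sample:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FlattenedFocalLoss(nn.Module):
    """Hypothetical helper (not a fastai class): focal loss over flattened pixels."""
    def __init__(self, gamma=2., reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # inputs: [bs, n_classes, H, W] -> [bs*H*W, n_classes]
        n_classes = inputs.shape[1]
        inputs = inputs.permute(0, 2, 3, 1).reshape(-1, n_classes)
        # targets: [bs, 1, H, W] or [bs, H, W] -> [bs*H*W] (assumed mask layout)
        targets = targets.reshape(-1).long()
        ce = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce)  # predicted probability of the true class per pixel
        loss = (1 - pt) ** self.gamma * ce
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        return loss

# Quick check on random data with 5 classes
logits = torch.randn(2, 5, 8, 8)
masks = torch.randint(0, 5, (2, 1, 8, 8))
loss = FlattenedFocalLoss()(logits, masks)
```

If this works for your data, it should be possible to pass it as loss_func=FlattenedFocalLoss() when building the learner. With gamma=0 it reduces to plain cross entropy, which is a handy sanity check.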

Off the top of my head, I think FlattenedLoss means that the output and target tensors are flattened before the loss is computed, so the resulting value should be no different from computing the loss directly.
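To make the flattening idea concrete, here is a small sketch in plain PyTorch (not the actual fastai code). It uses the shapes from the question, and it assumes the masks carry an extra channel dimension, i.e. [bs, 1, 200, 200], which as far as I know is how fastai v1 stores them. That extra dimension would be one reason to flatten at all, since nn.CrossEntropyLoss expects targets of shape [bs, 200, 200]:

```python
import torch
import torch.nn as nn

bs, n_classes, h, w = 2, 5, 200, 200
logits = torch.randn(bs, n_classes, h, w)           # model output: [bs, 5, 200, 200]
masks = torch.randint(0, n_classes, (bs, 1, h, w))  # assumed mask shape: [bs, 1, 200, 200]

ce = nn.CrossEntropyLoss()

# Flattened version: move the class axis last, then treat every pixel as one sample
flat_logits = logits.permute(0, 2, 3, 1).reshape(-1, n_classes)  # [bs*200*200, 5]
flat_masks = masks.reshape(-1)                                   # [bs*200*200]
flat_loss = ce(flat_logits, flat_masks)

# Direct version: only works once the channel dimension is removed from the masks
direct_loss = ce(logits, masks.squeeze(1))

print(flat_loss.item(), direct_loss.item())
```

Both values should agree up to floating point error, so the flattening changes the shapes the loss accepts, not the result. And yes, an output of shape [bs, 5, 200, 200] is what I would expect for 5 classes.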


Thanks a lot for replying.
Can I conclude that it's not necessary to use the flattened loss, and that we can simply use CrossEntropyLoss as the loss function?

If so, I assume that your implementation of FocalLoss can also be used as the loss function in unet_learner?