I make a unet learner:
learn = unet_learner(data, models.resnet18)
print(learn.loss_func)
and I get:
FlattenedLoss of CrossEntropyLoss()
Question 1: What does "Flattened" mean here, and why is it used at all? Let's say my input images have shape [bs, 3, 200, 200] and I have 5 classes in total, so I would expect the model to output a feature map of shape [bs, 5, 200, 200]. Please correct me if I am wrong.
Question 2: How do I change the default loss function here? Since I have an unbalanced dataset, I would like to try out Focal Loss.
The loss function can be changed by passing the loss_func parameter to unet_learner. You can see an example below; the following is for the case of image classification:
import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self, gamma=3., reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Per-element cross-entropy (reduction='none'), so each term can be reweighted
        CE_loss = nn.CrossEntropyLoss(reduction='none')(inputs, targets)
        pt = torch.exp(-CE_loss)  # model's probability for the true class
        F_loss = ((1 - pt) ** self.gamma) * CE_loss
        if self.reduction == 'sum':
            return F_loss.sum()
        elif self.reduction == 'mean':
            return F_loss.mean()
        return F_loss  # reduction='none'
Off the top of my head, I think the flattened loss means that the output tensor is flattened first and the loss is computed afterwards (no difference in the result from the direct method).
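A small sketch of that idea (assuming the flattening simply reshapes predictions to [N, C] and targets to [N] before calling the underlying loss; fastai's actual FlattenedLoss internals may differ in detail): for a segmentation output of shape [bs, C, H, W], the flattened cross-entropy equals applying CrossEntropyLoss to the 4-D tensor directly, since PyTorch's CrossEntropyLoss accepts [N, C, d1, d2] inputs natively.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bs, C, H, W = 2, 5, 4, 4
preds = torch.randn(bs, C, H, W)           # segmentation logits
targets = torch.randint(0, C, (bs, H, W))  # per-pixel class labels

# Direct: CrossEntropyLoss handles [N, C, d1, d2] inputs natively
direct = nn.CrossEntropyLoss()(preds, targets)

# "Flattened": move the class dim last, collapse all pixels into one batch axis
flat_preds = preds.permute(0, 2, 3, 1).reshape(-1, C)  # [bs*H*W, C]
flat_targets = targets.reshape(-1)                     # [bs*H*W]
flattened = nn.CrossEntropyLoss()(flat_preds, flat_targets)

assert torch.allclose(direct, flattened)  # same value either way
```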
Thanks a lot for replying.
Can I conclude that it's not necessary to use the flattened loss, and we can simply use CrossEntropyLoss as the loss function?
If so, I assume that your implementation of FocalLoss can also be used as the loss function in unet_learner?