Applying FocalLoss in fastai

I’m trying to apply FocalLoss in fastai as a custom loss function to train a model on a dense multi-label classification problem.
I found this PyTorch code to calculate FocalLoss: https://github.com/clcarwin/focal_loss_pytorch

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        # a scalar alpha becomes the class weights [alpha, 1-alpha] (binary case)
        if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
        target = target.view(-1, 1).long()  # one class index per row; gather needs int64

        logpt = F.log_softmax(input, dim=1)  # explicit dim avoids the implicit-dim warning
        logpt = logpt.gather(1, target)      # log-probability of the true class
        logpt = logpt.view(-1)
        pt = logpt.exp()                     # kept in the graph so (1-pt)**gamma also gets gradients

        if self.alpha is not None:
            self.alpha = self.alpha.to(input)  # match the dtype and device (CPU/GPU) of input
            at = self.alpha.gather(0, target.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average: return loss.mean()
        else: return loss.sum()
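For reference, the per-example quantity this class computes is FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), where p_t is the softmax probability of the single true class (so the class assumes one label per sample or pixel). The short pure-Python sketch below uses a hypothetical helper, focal_term, that is not part of the linked repo; it shows how gamma down-weights well-classified examples compared with plain cross-entropy (gamma = 0).

```python
import math

def focal_term(pt, gamma=2.0, alpha_t=1.0):
    """Focal loss for one example, given the probability pt of its true class."""
    return -alpha_t * (1.0 - pt) ** gamma * math.log(pt)

for pt in (0.9, 0.5, 0.1):
    ce = focal_term(pt, gamma=0.0)  # gamma = 0 reduces to ordinary cross-entropy
    fl = focal_term(pt, gamma=2.0)
    print(f"pt={pt}: CE={ce:.4f}  FL={fl:.4f}  ratio={fl / ce:.2f}")
```

The ratio is exactly (1 - p_t)^gamma: with gamma = 2, an easy example with p_t = 0.9 contributes 100 times less than under cross-entropy, while a hard one with p_t = 0.1 keeps 81% of its loss.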

When training (here via learn.lr_find(), which calls learn.fit), the following error is raised:
> ---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
in
----> 1 learn.lr_find()

/opt/conda/envs/fastai/lib/python3.6/site-packages/fastai/train.py in lr_find(learn, start_lr, end_lr, num_it, stop_div, **kwargs)
     28     cb = LRFinder(learn, start_lr, end_lr, num_it, stop_div)
     29     a = int(np.ceil(num_it/len(learn.data.train_dl)))
---> 30     learn.fit(a, start_lr, callbacks=[cb], **kwargs)
     31 
     32 def to_fp16(learn:Learner, loss_scale:float=512., flat_master:bool=False)->Learner:

/opt/conda/envs/fastai/lib/python3.6/site-packages/fastai/basic_train.py in fit(self, epochs, lr, wd, callbacks)
    160         callbacks = [cb(self) for cb in self.callback_fns] + listify(callbacks)
    161         fit(epochs, self.model, self.loss_func, opt=self.opt, data=self.data, metrics=self.metrics,
--> 162             callbacks=self.callbacks+callbacks)
    163 
    164     def create_opt(self, lr:Floats, wd:Floats=0.)->None:

/opt/conda/envs/fastai/lib/python3.6/site-packages/fastai/basic_train.py in fit(epochs, model, loss_func, opt, data, callbacks, metrics)
     92     except Exception as e:
     93         exception = e
---> 94         raise e
     95     finally: cb_handler.on_train_end(exception)
     96 

/opt/conda/envs/fastai/lib/python3.6/site-packages/fastai/basic_train.py in fit(epochs, model, loss_func, opt, data, callbacks, metrics)
     82             for xb,yb in progress_bar(data.train_dl, parent=pbar):
     83                 xb, yb = cb_handler.on_batch_begin(xb, yb)
---> 84                 loss = loss_batch(model, xb, yb, loss_func, opt, cb_handler)
     85                 if cb_handler.on_batch_end(loss): break
     86 

/opt/conda/envs/fastai/lib/python3.6/site-packages/fastai/basic_train.py in loss_batch(model, xb, yb, loss_func, opt, cb_handler)
     23 
     24     if opt is not None:
---> 25         loss = cb_handler.on_backward_begin(loss)
     26         loss.backward()
     27         cb_handler.on_backward_end()

/opt/conda/envs/fastai/lib/python3.6/site-packages/fastai/callback.py in on_backward_begin(self, loss)
    226     def on_backward_begin(self, loss:Tensor)->None:
    227         "Handle gradient calculation on `loss`."
--> 228         self.smoothener.add_value(loss.detach().cpu())
    229         self.state_dict['last_loss'], self.state_dict['smooth_loss'] = loss, self.smoothener.smooth
    230         for cb in self.callbacks:

/opt/conda/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
    523                 return modules[name]
    524         raise AttributeError("'{}' object has no attribute '{}'".format(
--> 525             type(self).__name__, name))
    526 
    527     def __setattr__(self, name, value):

AttributeError: 'FocalLoss' object has no attribute 'detach'

Any ideas on how to define a custom loss function in fastai so that it gets the detach attribute?
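The traceback is also consistent with another cause worth ruling out: Module.__getattr__ raising for 'detach' means the object returned by the loss function was a FocalLoss module rather than a tensor (a CPU tensor has .detach() and .cpu() too). That happens if the class itself, rather than an instance, is passed as the loss function, because calling the class with (output, target) builds a new module with gamma=output, alpha=target. The minimal sketch below reproduces this in plain PyTorch; TinyFocal is a hypothetical stand-in for the class above. How the loss was wired into the Learner isn't shown in this thread, so this is an assumption; in fastai v1 the fix would be along the lines of learn.loss_func = FocalLoss(gamma=2).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyFocal(nn.Module):
    """Hypothetical minimal focal loss, only to show class-vs-instance usage."""
    def __init__(self, gamma=2.0, alpha=None):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha  # unused here; only mirrors the signature of FocalLoss

    def forward(self, input, target):
        ce = F.cross_entropy(input, target, reduction="none")  # -log(pt) per sample
        pt = torch.exp(-ce)
        return ((1 - pt) ** self.gamma * ce).mean()

out = torch.randn(4, 3)           # logits: 4 samples, 3 classes
yb = torch.tensor([0, 2, 1, 2])   # class indices

wrong = TinyFocal                 # the class itself
loss = wrong(out, yb)             # runs __init__(gamma=out, alpha=yb): returns a module
print(type(loss).__name__, hasattr(loss, "detach"))

right = TinyFocal(gamma=2.0)      # an instance
loss = right(out, yb)             # runs forward: returns a scalar tensor
print(loss.detach().cpu())        # works on a CPU-only machine as well
```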

Maybe this will help.

This is in fastai v0.7, and I believe he is referencing his own custom function. This is exactly what I’m trying to do in 1.0.29, but my FocalLoss is generating the error above.

He provides his FocalLoss() in the post here:

I am also using focal loss, but just in PyTorch. The error is in “loss.detach().cpu()”: .detach().cpu() is used to take a .cuda() tensor off the GPU and move it to the CPU, so your variable is not on the GPU. Try returning loss.mean().cuda() and loss.sum().cuda(); that may fix it. Or try making all the variables that are passed in .cuda() as well.
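Hard-coding .cuda() raises an error on a machine without a CUDA GPU (such as a CPU-only VM), so a device-agnostic pattern may be safer. The sketch below is plain PyTorch; the tiny linear model and random batch are placeholders, and cross-entropy stands in for whatever loss you use. It picks the device once, moves the model and batch to it, and shows that .detach().cpu() works on the resulting loss whichever device was chosen.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(8, 3).to(device)              # placeholder model
xb = torch.randn(16, 8, device=device)          # placeholder batch
yb = torch.randint(0, 3, (16,), device=device)  # placeholder class labels

loss = F.cross_entropy(model(xb), yb)
loss.backward()
logged = loss.detach().cpu()  # same call fastai makes before smoothing the loss
print(device, logged.item())
```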

That is really cool; I wanted to know how you can use PyTorch functions. Do you think there is any downside to using PyTorch functions?

I couldn’t find it as a prebuilt function in v3, but I re-watched lesson 9 and it seems (I haven’t tried it) that all you would have to do is define it and use it as the loss function. Here is the part of the video where he talks about focal loss.

The lecture 9 notes from timlee also talk about focal loss.

Champion. I was messing around on a cheaper virtual machine on Paperspace that does not have a GPU.
Lesson learned: I ran the code on another machine (with a GPU) and it resolved the issue. In my troubleshooting I focused on the .detach() and not the .cpu().

We’ve defined lots of custom loss functions in the lessons, so yes, absolutely you should try this! Most loss functions are simple PyTorch functions; you don’t often need to define a class for them.
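As a sketch of that point, single-label focal loss can be written as a plain function on top of F.cross_entropy, using the identity p_t = exp(-CE). The name focal_loss and the default gamma=2.0 are illustrative choices, not taken from the lessons; functools.partial fixes gamma so the result can be used wherever a loss_func(input, target) callable is expected.

```python
from functools import partial

import torch
import torch.nn.functional as F

def focal_loss(input, target, gamma=2.0):
    """Single-label focal loss: input is (N, C) logits, target is (N,) class indices."""
    ce = F.cross_entropy(input, target, reduction="none")  # -log(pt) for each sample
    pt = torch.exp(-ce)                                     # probability of the true class
    return ((1 - pt) ** gamma * ce).mean()

loss_func = partial(focal_loss, gamma=2.0)

logits = torch.randn(5, 4, requires_grad=True)
targets = torch.tensor([0, 3, 1, 1, 2])
loss = loss_func(logits, targets)
loss.backward()
```

Assuming fastai v1's Learner signature, this could be passed as Learner(data, model, loss_func=loss_func) or assigned to learn.loss_func. Setting gamma=0 recovers plain cross-entropy, which is a handy sanity check.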

Is there such a thing as a multi-label focal loss? I’m trying to determine the best way to deal with imbalanced data in the multi-label case.
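One common approach (a sketch, not an answer given in this thread) is to apply focal loss to each label independently with a sigmoid instead of a softmax, which is the form the RetinaNet paper used. The function below assumes multi-hot float targets of shape (N, C). Its name is illustrative, and the defaults alpha=0.25, gamma=2.0 are the RetinaNet settings, not values tuned for any particular dataset. Recent torchvision releases also provide torchvision.ops.sigmoid_focal_loss, which applies the same per-element formula.

```python
import torch
import torch.nn.functional as F

def multilabel_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Sigmoid focal loss; logits and targets are (N, C), targets multi-hot floats in {0, 1}."""
    p = torch.sigmoid(logits)
    # per-label binary cross-entropy, i.e. -log(pt) for each (sample, label) pair
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    pt = p * targets + (1 - p) * (1 - targets)  # probability of the correct outcome
    loss = ce * (1 - pt) ** gamma               # down-weight easy labels
    if alpha is not None:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # weight positives vs negatives
        loss = alpha_t * loss
    return loss.mean()

logits = torch.randn(2, 5, requires_grad=True)
targets = torch.tensor([[1., 0., 0., 1., 0.],
                        [0., 0., 1., 0., 0.]])
loss = multilabel_focal_loss(logits, targets)
loss.backward()
```

With gamma=0 and alpha=None this reduces to ordinary binary cross-entropy with logits. alpha trades off positive against negative labels and is one lever for label imbalance; it is a hyperparameter to tune.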