Problem creating custom loss function

I am trying to create and use a custom loss function. When my initial attempts failed, I decided to take a step back and implement (through cut and paste) the standard loss function used with a unet Learner in my own notebook. I thought this would be a good way to check my understanding of the input tensor sizes and to see where the inputs differed between the standard loss function and the ones I first created.

To my disappointment, my “cut and paste” loss function also fails: an exception is thrown during lr_find.

/opt/anaconda3/lib/python3.7/site-packages/torch/nn/ in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1786     if input.size(0) != target.size(0):
   1787         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 1788                          .format(input.size(0), target.size(0)))
   1789     if dim == 2:
   1790         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

ValueError: Expected input batch_size (65536) to match target batch_size (8192).

I would appreciate some insight into what I am doing wrong.

Initial standard fastai code which does work:

learn = unet_learner(data, models.resnet34, metrics=[], wd=wd)
print('Loss func ', learn.loss_func)

Loss func FlattenedLoss of CrossEntropyLoss()

Here is the code I’ve pasted in (and renamed) that fails.

from fastai.vision import *  # fastai v1: provides nn, Tensor, Rank0Tensor, unet_learner, models

class MyFlattenedLoss():
    "Same as `func`, but flattens input and target."
    def __init__(self, func, *args, axis:int=-1, floatify:bool=False, is_2d:bool=True, **kwargs):
        self.func,self.axis,self.floatify,self.is_2d = func(*args,**kwargs),axis,floatify,is_2d

    def __repr__(self): return f"My FlattenedLoss of {self.func}"
    @property
    def reduction(self): return self.func.reduction
    @reduction.setter
    def reduction(self, v): self.func.reduction = v

    def __call__(self, input:Tensor, target:Tensor, **kwargs)->Rank0Tensor:
        print('input shape ', input.shape)
        print('target shape ', target.shape)
        # move dimension `axis` to the last position
        input = input.transpose(self.axis,-1).contiguous()
        target = target.transpose(self.axis,-1).contiguous()
        print('input shape ', input.shape)
        print('target shape ', target.shape)
        if self.floatify: target = target.float()
        # flatten to (n_rows, size of the last dimension), or to 1D if not is_2d
        input = input.view(-1,input.shape[-1]) if self.is_2d else input.view(-1)
        print('input shape ', input.shape)
        print('target shape ', target.shape)
        print('floatify', self.floatify, ' 2d ', self.is_2d)
        print('kwargs ', kwargs)
        print('Func ', self.func)
        print('target view ', target.view(-1).shape)
        return self.func.__call__(input, target.view(-1), **kwargs)

def MyCrossEntropyFlat(*args, axis:int=-1, **kwargs):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    return MyFlattenedLoss(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)

learn = unet_learner(data, models.resnet34, metrics=[], wd=wd)
learn.loss_func = MyCrossEntropyFlat()
print('Loss func ', learn.loss_func)

Loss func My FlattenedLoss of CrossEntropyLoss()

The exception occurs when calling lr_find.


Note that the learner is set up to use a batch size of 8, there are 256 classes, and the images are resized to [32,32].

The following output is captured before the exception:

input shape  torch.Size([8, 256, 32, 32])
target shape  torch.Size([8, 1, 32, 32])
input shape  torch.Size([8, 256, 32, 32])
target shape  torch.Size([8, 1, 32, 32])
input shape  torch.Size([65536, 32])
target shape  torch.Size([8, 1, 32, 32])
floatify False  2d  True
kwargs  {}
Func  CrossEntropyLoss()
target view  torch.Size([8192])

Try learn.loss_func = MyCrossEntropyFlat(axis=1): axis 1 is the channel axis, the one that indicates the labels (in the input it holds the 256 class scores).
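A quick way to see why axis=1 fixes it is to replay the reshaping with plain PyTorch. The helper below is a hypothetical stand-alone copy of the flattening in MyFlattenedLoss.__call__ (is_2d=True, floatify=False), run on random tensors with the shapes from the post (batch 8, 256 classes, 32x32 masks). With axis=-1 the transpose is a no-op, so view(-1, 32) treats the last image dimension as the class axis and yields 8*256*32 = 65536 rows, while the target flattens to 8*32*32 = 8192 values. With axis=1 the 256 class scores are moved to the end first, giving 8192 rows of 256.

```python
import torch

def flatten_like_flattened_loss(input, target, axis):
    # Hypothetical copy of the reshaping done in MyFlattenedLoss.__call__
    input = input.transpose(axis, -1).contiguous()
    target = target.transpose(axis, -1).contiguous()
    return input.view(-1, input.shape[-1]), target.view(-1)

# Shapes from the post: batch of 8, 256 classes, 32x32 images
logits = torch.randn(8, 256, 32, 32)
mask = torch.randint(0, 256, (8, 1, 32, 32))

# Default axis=-1: rows of 32 "scores", mismatched with the flattened target
inp, tgt = flatten_like_flattened_loss(logits, mask, axis=-1)
print(inp.shape, tgt.shape)

# axis=1: the class dimension becomes the last one, so the row counts match
inp, tgt = flatten_like_flattened_loss(logits, mask, axis=1)
print(inp.shape, tgt.shape)
loss = torch.nn.CrossEntropyLoss()(inp, tgt)
print(loss)
```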


Thank you! Specifying the axis index solved the issue.


I’m having a hard time understanding how unet_learner specifies the default loss_func.
For example, I found the following line of code in unet_learner(), but it doesn’t seem to specify a loss_func:

learn = Learner(data, model, **learn_kwargs)

The default loss function is specified by the type of item list you use. For instance, with SegmentationLabelList it is CrossEntropyFlat(axis=1), as you can see below:

class SegmentationLabelList(ImageList):
    "`ItemList` for segmentation masks."
    def __init__(self, items:Iterator, classes:Collection=None, **kwargs):
        super().__init__(items, **kwargs)
        self.classes,self.loss_func = classes,CrossEntropyFlat(axis=1)

    def open(self, fn): return open_mask(fn)
    def analyze_pred(self, pred, thresh:float=0.5): return pred.argmax(dim=0)[None]
    def reconstruct(self, t:Tensor): return ImageSegment(t)

If you are not using (or inheriting from) an item list that has a loss function specified, you need to either set one yourself or specify the one you want when calling unet_learner.
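The lookup order described above can be mimicked with a small, hypothetical mock. The class and function names are made up, and the final fallback to F.nll_loss is an assumption for illustration, not something stated in this thread: an explicitly passed loss wins, otherwise the loss attached to the label list is used.

```python
import torch.nn as nn
import torch.nn.functional as F

class SegmentationLabelsMock:
    """Hypothetical stand-in for a label list that carries a default loss."""
    def __init__(self):
        self.loss_func = nn.CrossEntropyLoss()

class PlainLabelsMock:
    """Hypothetical label list without a loss_func attribute."""

def resolve_loss(labels, loss_func=None):
    # 1. A loss passed explicitly by the user always wins.
    if loss_func is not None:
        return loss_func
    # 2. Otherwise use the loss attached to the labels, if there is one;
    # 3. else fall back to a generic default (F.nll_loss here is an assumption).
    return getattr(labels, "loss_func", F.nll_loss)

print(resolve_loss(SegmentationLabelsMock()))
print(resolve_loss(PlainLabelsMock()))
print(resolve_loss(PlainLabelsMock(), loss_func=nn.MSELoss()))
```

In fastai v1 terms, the explicit case corresponds to passing loss_func through unet_learner's keyword arguments, which it forwards to Learner, e.g. unet_learner(data, models.resnet34, loss_func=CrossEntropyFlat(axis=1)).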



I hope I found the best place to ask my question:

I created a custom loss function and it seems to work.
However, now lr_find aborts prematurely and no plot can be shown.

Do you know what criteria a loss function needs to fulfill to be able to run lr_find successfully on it?


What error does it show? To me, a loss just needs to extend nn.Module and implement forward. For instance, I made a focal loss like this:

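import torch
from torch import nn

def focal_loss(input, target, reduction='mean', beta=0.5, gamma=2., eps=1e-7, **kwargs):
    n = input.size(0)
    # per-sample flattened probabilities, clamped so log() stays finite
    iflat = torch.sigmoid(input).view(n, -1).clamp(eps, 1-eps)
    tflat = target.view(n, -1).float()
    # first term as posted; the negative-class term was cut off and is
    # reconstructed here as the usual binary focal-loss counterpart (assumption)
    focal = -(beta*tflat*(1-iflat).pow(gamma)*iflat.log()+
              (1-beta)*(1-tflat)*iflat.pow(gamma)*(1-iflat).log())
    if torch.isnan(focal.mean()) or torch.isinf(focal.mean()):
        # the original body of this check was lost; raising is an assumption
        raise ValueError('focal loss is NaN or inf')
    if reduction == 'mean':
        return focal.mean()
    elif reduction == 'sum':
        return focal.sum()
    else:
        return focal

class FocalLoss(nn.Module):
    def __init__(self, beta=0.5, gamma=2., reduction='mean'):
        super().__init__()  # initialise nn.Module internals before setting attributes
        self.beta = beta
        self.gamma = gamma
        self.reduction = reduction
    def forward(self, input, target, **kwargs):
        return focal_loss(input, target, beta=self.beta, gamma=self.gamma, reduction=self.reduction, **kwargs)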

Hi @florobax,

It actually shows no error.
Instead, lr_find terminates almost immediately with the message ‘LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.’, and the plot is empty.

As there is no error, I am a bit confused where to look for the problem…

What is the source code for your loss? What does your_loss(next(iter( output?
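One way to answer that kind of question is to call the loss on a single batch outside of lr_find and look at the output's shape, finiteness and sign. A minimal sketch with plain PyTorch, where the toy data, model and batch size are hypothetical stand-ins for the real DataBunch and model:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data and model standing in for the real ones
xs = torch.randn(64, 10)
ys = torch.randint(0, 3, (64,))
dl = DataLoader(TensorDataset(xs, ys), batch_size=8, shuffle=True)
model = nn.Linear(10, 3)
loss_func = nn.CrossEntropyLoss()  # swap in the custom loss under test

xb, yb = next(iter(dl))  # grab a single batch
with torch.no_grad():
    out = model(xb)
    loss = loss_func(out, yb)
print(out.shape, yb.shape, loss)

# Basic sanity checks before running lr_find
assert loss.dim() == 0, "loss should be a scalar"
assert torch.isfinite(loss), "loss is NaN or inf"
if loss.item() < 0:
    print("warning: negative loss")
```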

Hey @florobax,

I just found the issue…
My loss function returns a logarithm, which can become negative (the literature suggests that this is beneficial for training).

In the LRFinder source code, however, the stopping criterion is a simple comparison between the smooth loss (not sure what exactly that is) and the current best loss, smooth_loss > 4*self.best_loss, which seems to implicitly assume a minimum loss of 0.

As soon as best_loss is negative, 4*best_loss is even more negative, so almost any smoothed loss exceeds it and the whole thing aborts almost immediately.

I fixed the issue by adding a large enough offset to my loss function although this seems to be a very hacky solution.
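The failure mode can be reproduced with a toy version of the stopping rule quoted above. This is a simplified sketch, not fastai's LRFinder code: it uses the raw loss instead of the smoothed one, and the loss sequences are made up. With positive losses the run continues; with negative losses the very first iteration already satisfies loss > 4 * best_loss. A constant offset restores the expected behaviour.

```python
def iterations_before_stop(losses, factor=4.0):
    """Count iterations until `loss > factor * best_loss` fires (no smoothing)."""
    best = None
    for i, loss in enumerate(losses):
        if best is None or loss < best:
            best = loss          # track the lowest loss seen so far
        if loss > factor * best:
            return i + 1         # a finder using this rule would stop here
    return len(losses)           # ran to completion

positive = [1.0, 0.9, 0.8, 0.7, 0.6]
negative = [-1.0, -1.1, -1.2, -1.3, -1.4]
offset = [v + 10.0 for v in negative]  # the "large enough offset" fix

print(iterations_before_stop(positive))  # 5
print(iterations_before_stop(negative))  # 1
print(iterations_before_stop(offset))    # 5
```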

Thanks for the help though!

Oh indeed, a negative loss is pretty uncommon. The offset is a hacky solution, but it should not impact anything (in particular the gradients), so it’s probably a good one. It’s not obvious to me why a negative loss would be beneficial for training, though. If that isn’t important, you could also change your log(x) into log(1+x), which has the same shape shifted by one and is always positive for x > 0. But that’s not really important; sometimes a hacky solution is the best solution.
Just out of curiosity, what is this log function you are using?
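Both suggestions can be checked numerically. The sketch below uses a made-up mean(log(x)) loss as a stand-in for the poster's loss: adding a constant leaves the gradient unchanged, whereas log1p (i.e. log(1 + x)) stays positive for x > 0 but is a different function, so unlike the offset it does change the gradients.

```python
import torch

x = torch.tensor([0.2, 0.5, 0.9], requires_grad=True)

def base_loss(x):
    # Hypothetical stand-in for a log-based loss; negative when x < 1
    return torch.log(x).mean()

# Gradient with and without a constant offset
g_plain, = torch.autograd.grad(base_loss(x), x)
g_offset, = torch.autograd.grad(base_loss(x) + 10.0, x)
print(torch.allclose(g_plain, g_offset))  # the offset does not change gradients

# log(1 + x) vs log(x): log1p stays positive for x > 0
print(torch.log(x).detach())
print(torch.log1p(x).detach())
```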

PS: The smooth loss is just an exponentially weighted moving average of the loss, so it avoids large variations and gives you a nice curve.
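A minimal sketch of such smoothing, assuming an exponential moving average with bias correction (beta=0.98 is an assumed value and the class name is hypothetical). The division by 1 - beta**n keeps the first few values from being dragged towards the zero the average starts from.

```python
class SmoothedLoss:
    """Debiased exponential moving average of a loss (a sketch)."""
    def __init__(self, beta=0.98):
        self.beta, self.n, self.mov_avg = beta, 0, 0.0

    def add(self, value):
        self.n += 1
        # Standard EMA update
        self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * value
        # Bias correction: mov_avg starts at 0, so early values are rescaled
        self.smooth = self.mov_avg / (1 - self.beta ** self.n)
        return self.smooth

sm = SmoothedLoss()
for raw in [2.0, 10.0, 2.0, 2.0]:  # a single spike at the second step
    print(raw, round(sm.add(raw), 3))
```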

I tried to implement the squared weighted kappa loss function from the paper

Weighted kappa loss function for multi-class classification of ordinal data in deep learning
by Torre et al., which you can find here: //
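For readers wondering how such a loss ends up negative, here is a rough sketch, not the paper's reference implementation. It assumes the loss is log(1 - κ) with a soft, quadratic-weighted κ computed from predicted probabilities; since κ = 1 - num/den, the loss is log(num/den), which is 0 for a chance-level predictor and negative whenever κ > 0. The function name and the eps value are hypothetical.

```python
import torch
import torch.nn.functional as F

def weighted_kappa_loss(probs, target, num_classes, eps=1e-10):
    """Sketch of log(1 - kappa) with a soft quadratic-weighted kappa.

    probs:  (N, C) predicted class probabilities (e.g. softmax output)
    target: (N,)   integer labels in [0, C)
    """
    onehot = F.one_hot(target, num_classes).float()
    idx = torch.arange(num_classes, dtype=torch.float32)
    # Quadratic disagreement weights w_ij = (i - j)^2 / (C - 1)^2
    w = (idx[:, None] - idx[None, :]) ** 2 / (num_classes - 1) ** 2
    observed = onehot.t() @ probs  # soft confusion matrix O
    # Expected matrix under independence: E_ij = hist_true_i * hist_pred_j / N
    expected = onehot.sum(0)[:, None] * probs.sum(0)[None, :] / probs.shape[0]
    num = (w * observed).sum()
    den = (w * expected).sum()
    # kappa = 1 - num / den, hence log(1 - kappa) = log(num / den)
    return torch.log((num + eps) / (den + eps))

labels = torch.tensor([0, 1, 2, 0, 1, 2])
probs = torch.softmax(torch.randn(6, 3), dim=1)
print(weighted_kappa_loss(probs, labels, num_classes=3))
```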

But I agree: since it is rather obscure, I am not sure whether it is worth opening a GitHub issue / feature request for it.

Thanks! I hate it when you have to pay for an article… Yeah, I’m not sure this is a real issue, since almost nobody uses negative losses.

I see… Thank you 🙂


Thanks for the example! How does the focal_loss function relate to your FocalDiceLoss class?

I should have named the class FocalLoss; I just edited it. focal_loss is just the functional version of FocalLoss. For instance, if you do:

loss_func = FocalLoss()
loss = loss_func(y_pred, y_true)

The second line actually calls the forward method of the FocalLoss class (through nn.Module.__call__), which calls focal_loss. Having both a class and a functional version is not necessary, as you can use either alone, but I find a class handy for storing hyperparameters, and the function keeps the code cleaner.
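The same pattern in a minimal, self-contained form, with a hypothetical scaled_mse loss standing in for focal_loss: the function holds the maths, the nn.Module subclass stores the hyperparameter, and calling the module dispatches to forward through nn.Module.__call__.

```python
import torch
from torch import nn

def scaled_mse(input, target, scale=1.0):
    """Functional version: hyperparameters passed explicitly (hypothetical loss)."""
    return scale * ((input - target) ** 2).mean()

class ScaledMSE(nn.Module):
    """Class version: stores the hyperparameter, delegates to the function."""
    def __init__(self, scale=1.0):
        super().__init__()
        self.scale = scale

    def forward(self, input, target):
        return scaled_mse(input, target, scale=self.scale)

y_pred, y_true = torch.randn(4, 3), torch.randn(4, 3)
loss_func = ScaledMSE(scale=2.0)
# loss_func(...) goes through nn.Module.__call__, which runs forward()
print(torch.allclose(loss_func(y_pred, y_true), scaled_mse(y_pred, y_true, scale=2.0)))
```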


Got it, thanks! Makes sense.

Do you find Focal Loss helpful for binary segmentation with class imbalance? If you have any nice examples to share, I would love to check them out.

In my particular problem it didn’t help much, but that doesn’t mean it will never work; you’d need to test it. Options you can try when you have class imbalance include weighted cross-entropy, different versions of dice loss (be careful, it can be very unstable in some cases), Lovász loss, etc. I have no working example with focal loss saved; I tried it at some point, but the results were inconclusive, so I didn’t keep the data.
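Two of those options, sketched in plain PyTorch for binary segmentation: weighted cross-entropy through the weight argument of nn.CrossEntropyLoss (the 1:10 weights are illustrative, not tuned), and one common soft Dice variant. The smoothing constant is an assumed value; it avoids 0/0 on empty masks, which is one of the ways Dice losses become unstable.

```python
import torch
from torch import nn

# Weighted cross-entropy: up-weight the rare foreground class (illustrative weights)
class_weights = torch.tensor([1.0, 10.0])  # [background, foreground]
weighted_ce = nn.CrossEntropyLoss(weight=class_weights)

def soft_dice_loss(logits, target, smooth=1.0):
    """Soft Dice loss for binary segmentation.

    logits: (N, 1, H, W) raw scores; target: (N, 1, H, W) with values in {0, 1}.
    """
    probs = torch.sigmoid(logits)
    dims = (1, 2, 3)  # sum over channel and spatial dims, per image
    inter = (probs * target).sum(dims)
    union = probs.sum(dims) + target.sum(dims)
    dice = (2 * inter + smooth) / (union + smooth)
    return 1 - dice.mean()

mask = torch.randint(0, 2, (2, 8, 8))  # integer mask, shape (N, H, W)
logits2 = torch.randn(2, 2, 8, 8)      # two-class logits for weighted CE
print(weighted_ce(logits2, mask))
logits1 = torch.randn(2, 1, 8, 8)      # single-channel logits for Dice
print(soft_dice_loss(logits1, mask[:, None].float()))
```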
