Custom Loss Function - Greater Penalty For Wrong Sign

Hi everyone,

I’m working on a tabular learning model to predict a float value. The target values include both positive and negative numbers. I’d like to implement a custom loss function that applies a greater loss when the predicted value is a different sign from the target. Here’s what I have so far, but this does not work:

import fastai.tabular.all as fta

def Sign(x):
    if x < 0:
        return -1
    elif x > 0:
        return 1
    return 0

def CustomLoss(y, yhat):
    y, yhat = [float(x[0]) for x in list(y)], [float(x[0]) for x in list(yhat)]
    loss = 0
    for i in range(len(y)):
        a, b = y[i], yhat[i]
        if Sign(a) != Sign(b):
            loss += 5*(a - b)**2
            loss += (a - b)**2
    return loss

learn = fta.tabular_learner(dls, metrics=fta.mae, loss_func = CustomLoss)

Hi there,

from torch import nn
import torch

class SignLoss(nn.Module):
    def __init__(self):
        super(FunLoss, self).__init__()

    def forward(self, y, yhat):
        """
        :param y: inputs
        :param yhat: targets
        """
        y = y.view(-1)
        yhat = yhat.view(-1)
        residuals = yhat - y
        result = torch.where(torch.sign(y)!=torch.sign(yhat), torch.multiply(torch.pow(residuals, 2), 5), torch.pow(residuals, 2))
        return torch.sum(result)

This should do what you need.
Call it like this:

learn = fta.tabular_learner(dls, metrics=fta.mae, loss_func = SignLoss())

Ok, now I have more time, so some explanation:

  • It is written as a PyTorch loss module (but the behavior should be rather similar to your function). Its forward method accepts self, the input values (here: y) and the target values (here: yhat), so it might even be that I swapped y and yhat; please be aware of that!
  • .view(-1) is a reshape, so both tensors are flattened to the same one-dimensional shape.
  • The residuals are your a - b, but computed for the whole vector at once (the sign is flipped, which makes no difference once they are squared).
  • Now, to simplify the next part, it may help to unwrap it a little:
    torch.multiply(torch.pow(residuals, 2), 5), 
    torch.pow(residuals, 2)
  • torch.where takes, at every position where the condition (first input) is true, the value from the second input, and otherwise the value from the third input, and merges them into one tensor.
  • Meaning: it calculates the two signum values and compares them; if they differ, the second part is taken, otherwise the third part.
    • second part: inner square, then *5
    • third part: just square it
  • and, as your code did, return the sum
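To make the torch.where step concrete, here is a minimal sketch on a handful of made-up values. The tensors and the intermediate names (squared, mismatch, per_element) are only for illustration and are not part of the class above.

```python
import torch

# Made-up illustration values: the second and third pairs disagree in sign.
y = torch.tensor([1.0, -2.0, 3.0, -1.0])
yhat = torch.tensor([2.0, 1.0, -1.0, -3.0])

residuals = yhat - y                            # values: 1, 3, -4, -2
squared = torch.pow(residuals, 2)               # values: 1, 9, 16, 4
mismatch = torch.sign(y) != torch.sign(yhat)    # False, True, True, False

# Where the signs differ, take 5 * squared error; elsewhere the plain squared error.
per_element = torch.where(mismatch, 5 * squared, squared)   # values: 1, 45, 80, 4
loss = torch.sum(per_element)                   # 1 + 45 + 80 + 4 = 130

print(per_element)
print(loss)
```

Because the selection is done with tensor operations rather than Python floats, the result stays connected to the autograd graph, which is what the learner needs for backpropagation.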

Thank you for the reply!

When running this, I get “name ‘FunLoss’ is not defined”, even after running

from torch import nn
import torch

Where does FunLoss come from, so I can look further into this?

Oh, that was from when I wrote this small part of code

class SignLoss(nn.Module):
    def __init__(self):
        super(SignLoss, self).__init__()

This is how it should be; now there should be no FunLoss anymore.
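For reference, a sketch of the whole class with the corrected name, followed by a quick sanity check. The test tensors and the reference_loss helper are hypothetical and only there to compare the result against a plain Python loop; they are not part of fastai or PyTorch.

```python
import torch
from torch import nn


class SignLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y, yhat):
        # Flatten both tensors so they line up element by element.
        y = y.view(-1)
        yhat = yhat.view(-1)
        squared = torch.pow(yhat - y, 2)
        # 5x penalty where the signs differ, plain squared error otherwise.
        result = torch.where(torch.sign(y) != torch.sign(yhat), 5 * squared, squared)
        return torch.sum(result)


def reference_loss(a_vals, b_vals):
    # Hypothetical helper: the same rule written as a plain Python loop.
    total = 0.0
    for a, b in zip(a_vals, b_vals):
        sign_a = (a > 0) - (a < 0)
        sign_b = (b > 0) - (b < 0)
        factor = 5 if sign_a != sign_b else 1
        total += factor * (a - b) ** 2
    return total


# Made-up values shaped like a batch of single-column outputs.
preds = torch.tensor([[0.5], [-1.0], [2.0]], requires_grad=True)
targets = torch.tensor([[1.0], [1.0], [-2.0]])

loss = SignLoss()(preds, targets)
loss.backward()  # gradients flow back to preds

expected = reference_loss([0.5, -1.0, 2.0], [1.0, 1.0, -2.0])
print(loss.item(), expected)  # both are 0.25 + 20 + 80 = 100.25
```

Since the rule is symmetric in its two arguments, it gives the same value whichever of the two tensors is the prediction.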

This one works - Thank you for the guidance on this!