I’m working on a tabular learning model to predict a float value. The target values include both positive and negative numbers. I’d like to implement a custom loss function that applies a greater loss when the predicted value is a different sign from the target. Here’s what I have so far, but this does not work:
```python
import fastai.tabular.all as fta

def Sign(x):
    if x < 0:
        return -1
    elif x > 0:
        return 1
    return 0

def CustomLoss(y, yhat):
    y, yhat = [float(x) for x in list(y)], [float(x) for x in list(yhat)]
    loss = 0
    for i in range(len(y)):
        a, b = y[i], yhat[i]
        if Sign(a) != Sign(b):
            loss += 5*(a - b)**2
        else:
            loss += (a - b)**2
    return loss

learn = fta.tabular_learner(dls, metrics=fta.mae, loss_func=CustomLoss)
```
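For reference, the penalty I intend behaves like this on plain Python floats (hypothetical values; `Sign` is assumed to return -1 / 0 / +1):

```python
def Sign(x):
    # assumed helper: -1 for negative, +1 for positive, 0 for zero
    if x < 0:
        return -1
    elif x > 0:
        return 1
    return 0

def custom_loss(y, yhat):
    # squared error per sample, multiplied by 5 when the signs disagree
    loss = 0.0
    for a, b in zip(y, yhat):
        sq = (a - b) ** 2
        loss += 5 * sq if Sign(a) != Sign(b) else sq
    return loss

custom_loss([1.0, -2.0], [0.5, -1.0])   # signs agree: 0.25 + 1.0 = 1.25
custom_loss([1.0, -2.0], [-0.5, -1.0])  # first sign flipped: 5*2.25 + 1.0 = 12.25
```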
```python
import torch
from torch import nn

class SignLoss(nn.Module):
    def __init__(self):
        super(FunLoss, self).__init__()
    def forward(self, y, yhat):
        y = y.view(-1)
        yhat = yhat.view(-1)
        residuals = yhat - y
        result = torch.where(torch.sign(y) != torch.sign(yhat),
                             torch.multiply(torch.pow(residuals, 2), 5),
                             torch.pow(residuals, 2))
        return result.sum()
```
This should do what you need.
Call it like this:
learn = fta.tabular_learner(dls, metrics=fta.mae, loss_func = SignLoss())
Ok, now I have more time, so some explanation:
- it is encoded as a Torch loss function (but the behavior should be rather similar). The loss function accepts self, the input values (here: y) and the target values (here: yhat), so it might even be that I swapped y and yhat; please be aware of that!
- .view can be interpreted as a reshaping, so both tensors are reshaped into the same flat form.
- The residuals are your elementwise a - b, but in vector form.
- now, to make the next part easier to follow, it may help to unwrap it a little:
torch.multiply(torch.pow(residuals, 2), 5)
- torch.where takes, at every position where the condition (first argument) is true, the second argument, and otherwise the third, and merges the results accordingly
- meaning: it computes the two signum values and compares them; wherever they differ, the second part is taken, otherwise the third part
- second part: inner square, then *5
- third part: just square it
- and, as your code did, return the sum
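The steps above can be sketched on a tiny example (made-up numbers, assuming PyTorch is installed):

```python
import torch

y = torch.tensor([1.0, -2.0, 3.0])
yhat = torch.tensor([-0.5, -1.0, 2.0])

residuals = yhat - y           # values: -1.5, 1.0, -1.0
sq = torch.pow(residuals, 2)   # values: 2.25, 1.0, 1.0
# sign(y) = [1, -1, 1] and sign(yhat) = [-1, -1, 1] disagree only at index 0,
# so only that squared residual gets multiplied by 5
result = torch.where(torch.sign(y) != torch.sign(yhat), 5 * sq, sq)
# result values: 11.25, 1.0, 1.0 -- and result.sum() gives 13.25
```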
Thank you for the reply!
When running this, I get “name ‘FunLoss’ is not defined”, even after running
from torch import nn
Where does FunLoss come from, so I can look further into this?
Oh, that was left over from an earlier draft of this small piece of code.
This is how it should be; now there should be no FunLoss anymore.
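The corrected snippet did not survive in the thread; reconstructed, it presumably looks like this, with the stray FunLoss reference gone (since the class holds no state, the __init__ override can simply be dropped and nn.Module's default is used):

```python
import torch
from torch import nn

class SignLoss(nn.Module):
    def forward(self, y, yhat):
        # flatten both tensors to the same 1-D shape
        y = y.view(-1)
        yhat = yhat.view(-1)
        residuals = yhat - y
        # squared error, scaled by 5 wherever the signs disagree
        result = torch.where(torch.sign(y) != torch.sign(yhat),
                             5 * torch.pow(residuals, 2),
                             torch.pow(residuals, 2))
        return result.sum()
```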
This one works - Thank you for the guidance on this!