Weighted rows loss function

Let’s say I want to do a regression task with an MSE loss function, but I also want to give more weight to certain observations. How could I pass the weights for each batch to the loss function in fastai?

For example if I have this custom loss function:

import torch

def weighted_mse_loss(input, target, weight):
    return torch.sum(weight * (input - target) ** 2)
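For reference, here is that loss checked on a tiny batch in plain PyTorch. The second, normalized variant (dividing by the sum of the weights) is my own addition, not part of the question: the plain sum grows with batch size, while the weighted average does not, and with all weights equal to 1 it reduces to ordinary MSE.

```python
import torch

def weighted_mse_loss(input, target, weight):
    # Sum of squared errors, each scaled by its row weight
    return torch.sum(weight * (input - target) ** 2)

def weighted_mse_loss_mean(input, target, weight):
    # Hypothetical variant: weighted average, so the scale does not depend on batch size
    return torch.sum(weight * (input - target) ** 2) / torch.sum(weight)

pred = torch.tensor([1.0, 2.0, 3.0])
true = torch.tensor([1.0, 0.0, 0.0])
w = torch.tensor([1.0, 1.0, 0.5])

summed = weighted_mse_loss(pred, true, w)        # 0 + 4 + 0.5 * 9 = 8.5
averaged = weighted_mse_loss_mean(pred, true, w)  # 8.5 / 2.5 = 3.4
```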



If you write a custom collate function that returns input, [target, weight], you can use a custom loss function that handles this. I don’t know when your weights are determined, but building the batch seems like a good time.
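A minimal plain-PyTorch sketch of that collate idea, assuming each dataset item is a tuple (x, y, w). The function name collate_with_weights is hypothetical, and I have only shown a standard PyTorch DataLoader here, not how fastai wires the collate function into its own training loop.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def collate_with_weights(batch):
    # Each item is a tuple (x, y, w); regroup the batch as input, [target, weight]
    xs, ys, ws = zip(*batch)
    return torch.stack(xs), [torch.stack(ys), torch.stack(ws)]

def weighted_mse_loss(input, target, weight):
    return torch.sum(weight * (input - target) ** 2)

torch.manual_seed(0)
x = torch.randn(6, 3)               # 6 rows, 3 features
y = torch.randn(6)                  # one regression target per row
w = torch.linspace(0.5, 1.0, 6)     # one weight per row
dl = DataLoader(TensorDataset(x, y, w), batch_size=4, collate_fn=collate_with_weights)

model = torch.nn.Linear(3, 1)
batches = list(dl)
for xb, (yb, wb) in batches:
    pred = model(xb).squeeze(1)     # (bs, 1) -> (bs,) to match yb and wb
    loss = weighted_mse_loss(pred, yb, wb)
    loss.backward()
```

Here the loop unpacks [target, weight] by hand; a training loop that passes the target list straight to the loss would need to do the equivalent.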


Let’s say I already know the weights of my observations in advance; for example, I want to put less weight on older observations. So I know all the weights before putting the data through fastai.
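Since the weights are known up front, they can be computed once and stored as an extra column (for example a 'Weight' column in the DataFrame). A sketch assuming an exponential decay by age; the function name recency_weights and the 12-month half-life are hypothetical choices, not something from the thread.

```python
import numpy as np

def recency_weights(age_in_months, half_life=12.0):
    # Hypothetical exponential decay: the weight halves every `half_life` months
    age = np.asarray(age_in_months, dtype=np.float64)
    return 0.5 ** (age / half_life)

w = recency_weights([0, 12, 24])  # weights 1.0, 0.5, 0.25
```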

In this situation, I am not sure I understand how I would do that with a custom collate function.

The only way I can think of right now is to write a custom Module for my model whose forward method takes two tensor arguments, one for the actual data and one for the weights:

def forward(self, data: Tensor, weights: Tensor):
    ...

Then I guess I would have access to that in the loss function?
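Passing the weights into forward does not by itself make them visible to the loss, because a loss function only receives the model output and the target. One way the idea could work, sketched in plain PyTorch with hypothetical names, is to return the weights alongside the prediction and unpack them in the loss.

```python
import torch
from torch import nn, Tensor

class WeightPassthrough(nn.Module):
    # Hypothetical wrapper: forward receives the weights and hands them on with the prediction
    def __init__(self, body: nn.Module):
        super().__init__()
        self.body = body

    def forward(self, data: Tensor, weights: Tensor):
        return self.body(data).squeeze(1), weights

def weighted_mse_from_output(output, target):
    # The loss only sees (output, target), so the weights ride along inside `output`
    pred, weight = output
    return torch.sum(weight * (pred - target) ** 2)

torch.manual_seed(0)
model = WeightPassthrough(nn.Linear(3, 1))
x, y, w = torch.randn(5, 3), torch.randn(5), torch.ones(5)
loss = weighted_mse_from_output(model(x, w), y)
```

The catch is that metrics and prediction code then have to cope with a tuple output, which is one reason adding the weights to the targets, as suggested in the next reply, tends to be simpler.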


Or you could add them to your targets. The only downside is that each target becomes a tensor (target and weight together) instead of a single value, but you can extract the first/last column in the loss function.
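To make "extract the first/last" concrete: if the target and weight are stacked into one (n, 2) target, the loss can slice the columns back apart. A plain-PyTorch sketch with made-up values.

```python
import torch

target = torch.tensor([3.0, 1.0, 2.0])
weight = torch.tensor([1.0, 0.5, 0.2])

# Combined target of shape (n, 2): column 0 = target, column 1 = weight
target_weight = torch.stack([target, weight], dim=1)

# Inside the loss, slice the columns back out; 0:1 and 1:2 keep an (n, 1) shape
t = target_weight[:, 0:1]
w = target_weight[:, 1:2]
```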


Makes sense! I’m trying to add the weights to my targets using label_from_func, but I realized that the func gets called with the index of the DataFrame row instead of the row itself. I see in the fastai code that there’s a _label_from_list method, but since it begins with an underscore, I guess it is private. What would be the best way to add the weights to my targets? Here is my code for reference:

data = TabularList.from_df(data, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)
data = data.split_from_df(col='IsValidation')

def getLabel(x):
    # Problem: x turns out to be the row index, not the DataFrame row
    return x[['item_cnt_month', 'Weight']]

data = (data.label_from_func(func=getLabel, label_cls=FloatList)
            .databunch(bs=bs, num_workers=4))

learn = tabular_learner(data, layers=[500,250], ps=[0.01,0.1], emb_drop=0.08, metrics=[rmse], y_range=[0,22])

def weighted_mse_loss(input, target, weight):
    return torch.sum(weight * (input - target) ** 2)

learn.loss_func = weighted_mse_loss


Got this working using the private method _label_from_list, but I was wondering if I should use one of the public methods or maybe create a new public method to handle this scenario?

data = (TabularList.from_df(pets, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)
        ._label_from_list(pets[['AdoptionSpeed', 'Weight']].values, label_cls=FloatList))

data.c = 1

learn = tabular_learner(data, layers=[200,100], y_range=[0, 4])

def weighted_mse_loss(input, target_weight):
    # Column 0 holds the target, column 1 the weight; [:, None] keeps shape (bs, 1)
    target = target_weight[:, 0][:, None]
    weight = target_weight[:, 1][:, None]
    return torch.sum(weight * (input - target) ** 2)

learn.loss_func = weighted_mse_loss

learn = learn.to_fp16()

learn.fit_one_cycle(1, 1e-3)
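The [:, None] in that loss matters. Assuming the model returns predictions of shape (bs, 1), which is how I read the data.c = 1 setting, subtracting a flat (bs,) target would silently broadcast to a (bs, bs) matrix instead of raising an error. A small PyTorch check with random values:

```python
import torch

preds = torch.randn(4, 1)            # model output: (bs, 1)
target_weight = torch.rand(4, 2)     # column 0 = target, column 1 = weight

target_flat = target_weight[:, 0]            # shape (4,)
target_col = target_weight[:, 0][:, None]    # shape (4, 1)

wrong = preds - target_flat   # broadcasts to (4, 4): every prediction vs every target
right = preds - target_col    # (4, 1): element-wise, as intended
```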


label_from_df with cols=['AdoptionSpeed', 'Weight'] and label_cls=FloatList should also work.
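I have not run the fastai call itself here. The sketch below only uses pandas and a hypothetical toy DataFrame to show that selecting the two columns gives one [target, weight] pair per row, the (n, 2) layout the loss function above slices apart.

```python
import numpy as np
import pandas as pd
import torch

# Hypothetical toy frame using the column names from the thread
df = pd.DataFrame({"AdoptionSpeed": [0, 2, 4], "Weight": [1.0, 0.5, 0.25]})

# Selecting both label columns: one [target, weight] pair per row
target_weight = torch.tensor(df[["AdoptionSpeed", "Weight"]].values.astype(np.float32))

def weighted_mse_loss(input, target_weight):
    target = target_weight[:, 0][:, None]
    weight = target_weight[:, 1][:, None]
    return torch.sum(weight * (input - target) ** 2)

pred = torch.tensor([[0.0], [1.0], [2.0]])
loss = weighted_mse_loss(pred, target_weight)  # 0 + 0.5 * 1 + 0.25 * 4 = 1.5
```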


Indeed! I didn’t know we could pass a list of columns to cols.


How would this work for classification (categorical) targets?
In that case I can’t use FloatList, and I think I need a CategoryList so that the classes are determined correctly.

Hi, did you manage to figure this out? I couldn’t get it to work with CategoryList. Any suggestions, @sgugger? Thanks
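One possible direction, not tested with CategoryList: encode the classes as integer codes yourself (for example with pandas category codes), keep them next to the weight in the same (n, 2) float target as in the FloatList approach above, and cast the class column back to integers inside the loss. This gives up fastai’s automatic class bookkeeping (presumably data.c would then have to be set to the number of classes instead of 1), so treat it as a workaround sketch. The function name weighted_cross_entropy is hypothetical, and averaging by the weight sum is my choice; a plain weighted sum would also work.

```python
import torch
import torch.nn.functional as F

def weighted_cross_entropy(input, target_weight):
    # input: (bs, n_classes) logits
    # target_weight: (bs, 2), column 0 = class index stored as float, column 1 = weight
    target = target_weight[:, 0].long()
    weight = target_weight[:, 1]
    per_row = F.cross_entropy(input, target, reduction="none")  # shape (bs,)
    # Weighted average over rows (undefined if every weight in the batch is 0)
    return torch.sum(weight * per_row) / torch.sum(weight)

logits = torch.tensor([[2.0, 0.0, 0.0],
                       [0.0, 2.0, 0.0]])
target_weight = torch.tensor([[0.0, 1.0],
                              [2.0, 0.0]])  # second row has weight 0, so it is ignored
loss = weighted_cross_entropy(logits, target_weight)
```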