Update weights of a custom parameter inside a loss function

Hi guys,
I’m trying to implement a custom loss function for multi-task learning, with learnable parameters that weight each task. I’m following this blog post. Here’s the loss function:

import torch
from torch import nn
from fastai.losses import BCEWithLogitsLossFlat  # fastai v2; adjust the import for your version

class MultiTaskLossWrapper(nn.Module):
    def __init__(self, task_num, **kwargs):
        super().__init__()
        self.task_num = task_num
        # one learnable log-variance per task
        self.log_vars = nn.Parameter(torch.zeros(task_num).normal_(0, 0.01))

    def forward(self, inputs, target):
        criterion = [BCEWithLogitsLossFlat() for _ in range(self.task_num)]
        losses = torch.zeros(1, device=target.device)
        for j, crit in enumerate(criterion):
            # negative labels mark missing targets for that task
            mask = target[:, j] >= 0.0
            if mask.any():
                scaling = torch.exp(-self.log_vars[j])
                losses += scaling * crit(inputs[:, j][mask], target[:, j][mask]) + self.log_vars[j]
        return losses.sum()
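
For reference, each task contributes exp(-self.log_vars[j]) * task_loss + self.log_vars[j] to the total, which (if I follow the blog post correctly) is the homoscedastic-uncertainty weighting, with log_vars[j] playing the role of log σ²ⱼ.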

The log_vars variable holds the custom learnable weights I’m trying to update. The gradients are computed just fine: I can see them change from None to actual values after the backward pass. But when I access log_vars after training, the values are still the same. I’m not sure if I made a mistake or if this is just the expected behaviour in PyTorch. Here’s a sample of what I did:

Before training

loss_func = MultiTaskLossWrapper(task_num = len(targets))
print(loss_func.log_vars)

Parameter containing:
tensor([ 0.0096, 0.0197, -0.0107, 0.0072, -0.0099, -0.0091, 0.0079, 0.0079,
-0.0022, -0.0134, 0.0019, -0.0003, -0.0062, 0.0001, -0.0010, -0.0034,
0.0174, 0.0008, -0.0119, 0.0042, 0.0050, -0.0005, -0.0110, 0.0151,
-0.0095, 0.0089, -0.0082, 0.0051, 0.0104, -0.0164, 0.0190, 0.0030,
0.0147, 0.0098, 0.0031, -0.0106, 0.0067, -0.0159, -0.0148, 0.0160,
-0.0032, 0.0129, -0.0064, 0.0010, -0.0177, 0.0057, 0.0171, -0.0217,
0.0054, 0.0072, -0.0002, 0.0008, -0.0141, 0.0022, -0.0155, 0.0038,
0.0043, -0.0091, -0.0095, -0.0153, 0.0118, 0.0117, 0.0124, 0.0043,
0.0134, -0.0050, 0.0085, -0.0009, -0.0060, 0.0127, 0.0079, -0.0013,
-0.0094, -0.0085, -0.0030, -0.0067, 0.0103, -0.0046, 0.0074, -0.0003,
0.0132, -0.0152, 0.0090, -0.0061, -0.0137, -0.0168, 0.0032, 0.0177,
-0.0031, -0.0075, -0.0192], requires_grad=True)

After training:

print(learn.loss_func.log_vars)
Parameter containing:
tensor([ 0.0096, 0.0197, -0.0107, 0.0072, -0.0099, -0.0091, 0.0079, 0.0079,
-0.0022, -0.0134, 0.0019, -0.0003, -0.0062, 0.0001, -0.0010, -0.0034,
0.0174, 0.0008, -0.0119, 0.0042, 0.0050, -0.0005, -0.0110, 0.0151,
-0.0095, 0.0089, -0.0082, 0.0051, 0.0104, -0.0164, 0.0190, 0.0030,
0.0147, 0.0098, 0.0031, -0.0106, 0.0067, -0.0159, -0.0148, 0.0160,
-0.0032, 0.0129, -0.0064, 0.0010, -0.0177, 0.0057, 0.0171, -0.0217,
0.0054, 0.0072, -0.0002, 0.0008, -0.0141, 0.0022, -0.0155, 0.0038,
0.0043, -0.0091, -0.0095, -0.0153, 0.0118, 0.0117, 0.0124, 0.0043,
0.0134, -0.0050, 0.0085, -0.0009, -0.0060, 0.0127, 0.0079, -0.0013,
-0.0094, -0.0085, -0.0030, -0.0067, 0.0103, -0.0046, 0.0074, -0.0003,
0.0132, -0.0152, 0.0090, -0.0061, -0.0137, -0.0168, 0.0032, 0.0177,
-0.0031, -0.0075, -0.0192], requires_grad=True)

EDIT:

Ok, so the parameters inside the loss function are not among the model parameters, so the optimizer never sees them. I think that’s why they are not being updated. But the question stands: how do I make this loss function update its weights during training?
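
In plain PyTorch I believe the fix would be to hand the optimizer the loss function’s parameters alongside the model’s, something like this (model here just stands in for whatever network you’re training):

loss_func = MultiTaskLossWrapper(task_num=len(targets))
# include the loss parameters in the optimizer, next to the model's own
opt = torch.optim.Adam(
    list(model.parameters()) + list(loss_func.parameters()), lr=1e-3
)

With fastai the optimizer is built from the model’s parameters by default, which is presumably why log_vars get skipped.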


@marcossantana Did you have any luck with getting this working?

Judging from the original post, a simple way to do it would be to include the loss module in your model, move the targets into xb, and set the Learner’s loss_func to a dummy that just returns the loss computed by the model. There are multiple ways to do it technically; most probably making heavier use of callbacks would be best.
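
A rough sketch of that wrapping idea (MultiTaskLossWrapper is from the original post; the other names here are just made up):

class ModelWithLoss(nn.Module):
    # wraps model + loss so log_vars show up in model.parameters()
    def __init__(self, model, task_num):
        super().__init__()
        self.model = model
        self.loss_wrapper = MultiTaskLossWrapper(task_num)

    def forward(self, x, target):
        preds = self.model(x)
        return self.loss_wrapper(preds, target)

def passthrough_loss(model_output, *args):
    # dummy loss_func for the Learner: the forward pass already computed the loss
    return model_output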

I think I have got this working (maybe in a slightly crude way) by passing the params from the loss function into the Learner’s splitter. I’m not sure this is the best way of doing it, so I will keep digging. I can see the loss parameters updating during training, but I’m not convinced it’s working exactly as I think it is :grimacing:
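
Roughly what I did (dls and model come from my own setup, so treat this as a sketch):

loss_func = MultiTaskLossWrapper(task_num=len(targets))

def split_params(model):
    # two parameter groups: the model's own params plus the loss params
    return [list(model.parameters()), list(loss_func.parameters())]

learn = Learner(dls, model, loss_func=loss_func, splitter=split_params)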