Variable value cannot be updated in custom loss

I defined a variable (self.A) in MyLoss for a segmentation task, but its value is not updated during training and no gradient is generated for it. What should I do to get this variable updated?

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyLoss(nn.Module):

    def __init__(self):
        super().__init__()
        # nn.Parameter sets requires_grad=True itself; building the tensor with
        # requires_grad=True and then calling .cuda() yields a non-leaf tensor,
        # so create the parameter from a plain tensor instead.
        self.A = nn.Parameter(torch.ones(16, 1, 62, 126).cuda())

    def my_loss(self, input, target):
        # Finite differences along the two spatial dimensions.
        x = input[:, :, 1:, :] - input[:, :, :-1, :]
        y = input[:, :, :, 1:] - input[:, :, :, :-1]

        delta_x = x[:, :, 1:, :-2] ** 2
        delta_y = y[:, :, :-2, 1:] ** 2
        delta_u = torch.abs(delta_x + delta_y)

        # Smoothness term weighted by the learnable map A.
        loss = torch.mean(delta_u * self.A)
        return loss

    def forward(self, input, target):
        input_1 = F.softmax(input, dim=1)[:, 1:, :, :]
        my_loss = self.my_loss(input_1, target)

        print(self.A)  # prints the same values every batch
        return my_loss

data = get_databunch(img_path='...', lbl_path='...')
learn = unet_learner(data, models.resnet34, metrics=metrics, loss_func=MyLoss(), wd=1e-2)
learn.fit_one_cycle(1, slice(3e-3), pct_start=0.8)
```
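
To make the symptom concrete, this quick check (just a sketch) shows A is bit-for-bit the same before and after a training run:

```python
# Snapshot A, train, and compare: A never changes.
before = learn.loss_func.A.detach().clone()
learn.fit_one_cycle(1, slice(3e-3), pct_start=0.8)
print(torch.equal(before, learn.loss_func.A.detach()))  # True -> not updated
```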

@andrew_tal
self.A is not a parameter of your model; it is a parameter of the loss class.
Optimizers only update the parameters of the model (the ones that model.parameters() yields).
That's the reason the tensor MyLoss.A is not updating.
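
Here is a minimal illustration in plain PyTorch (made-up shapes): the optimizer only steps the tensors it was built from, so a parameter that lives on the loss object is invisible to it.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
extra = nn.Parameter(torch.ones(3))  # lives outside the model, like MyLoss.A

opt = torch.optim.SGD(model.parameters(), lr=0.1)

in_opt = {id(p) for group in opt.param_groups for p in group["params"]}
print(id(extra) in in_opt)  # False -> opt.step() will never touch `extra`
```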

Thanks a lot! Yes, I understand. But is it possible to add the variables in the loss to the set of parameters the fastai optimizer updates?

Unfortunately no! You’ll need to find a workaround and include all the parameters in your model. By default, it is the parameters of the model that get optimized; the loss is simply a function mapping between the target and the actual output.

By the way, I don’t know if this would work, but you could try for sure! What if you initialize the loss class beforehand and append loss_func.A to the model.parameters() generator? That might work (again, not sure, just an idea :stuck_out_tongue:)
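
Something like this, roughly (a sketch only; `model` here is a stand-in for your U-Net, it assumes a CUDA device since MyLoss puts A on the GPU, and I haven’t tried it with fastai’s layer groups):

```python
import itertools
import torch
import torch.nn as nn

model = nn.Conv2d(3, 2, kernel_size=3).cuda()  # stand-in for the U-Net
loss_func = MyLoss()

# "Append" loss_func.A to the parameters the optimizer sees:
opt = torch.optim.Adam(
    itertools.chain(model.parameters(), loss_func.parameters()),
    lr=3e-4,
)

# Alternatively, register A on the model itself; then model.parameters()
# (which the fastai optimizer is built from) already includes it:
# model.register_parameter("loss_A", loss_func.A)
```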

Thank you very much, I will give it a try. If there is any progress, I will post the solution here. Thanks again! :+1: :pray:
