Questions on custom loss function and gradients

I am working on an image segmentation problem where classes that are close to one another in numeric value are similar. I am trying to implement a custom loss function that uses a cost matrix to penalize predictions that are far from the target class more heavily than predictions that are near it.

I am trying to model the custom loss function on the PyTorch documentation here: https://pytorch.org/docs/stable/notes/extending.html?highlight=extend%20autograd

In the custom loss function the cost matrix will be of this form:

tensor([[0., 1., 2., 3., 4., 5.],
        [1., 0., 1., 2., 3., 4.],
        [2., 1., 0., 1., 2., 3.],
        [3., 2., 1., 0., 1., 2.],
        [4., 3., 2., 1., 0., 1.],
        [5., 4., 3., 2., 1., 0.]], device='cuda:0', requires_grad=True)
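
For context, this matrix is just the absolute difference between class indices. A minimal sketch of building one (not necessarily how my create_cost_matrix below works):

import torch

def abs_diff_cost(num_classes, device=None):
    # cost[i, j] = |i - j|, so classes further apart are penalized more
    idx = torch.arange(num_classes, dtype=torch.float32, device=device)
    return (idx.view(-1, 1) - idx.view(1, -1)).abs()

# abs_diff_cost(6) reproduces the 6 x 6 matrix shown above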

During training, the valid_loss and metrics never update, which I believe means that the autograd framework cannot compute the gradient on its own. I tried adding an implementation of backward, but that does not appear to be called.

I would appreciate some guidance on the following questions:

  1. Is there a way to update the implementation below, or an equivalent alternative approach, so that the framework can compute the gradient automatically?
  2. If not, what changes are needed to provide a gradient calculation that the framework will call?

Details below:

Loss function

class CostMatrixLoss(torch.nn.Module):
    def __init__(self, cost):
        super(CostMatrixLoss, self).__init__()
        self.cost = cost

    def forward(self, input, target):
        cmf = CostMatrixFunction()
        score = cmf.forward(input, target, self.cost)
        return score
      
class CostMatrixFunction(torch.autograd.Function):

    def __init__(self):
        super(CostMatrixFunction, self).__init__()

    # Note that both forward and backward are @staticmethods
    @staticmethod
    def forward(input, target, cost):
        #ctx.save_for_backward(input, cost)
  
        preds = torch.softmax(input, 1)
  
        outsize = preds.size()
        num_classes = outsize[1]
        targ_hot = to_one_hot(target, outsize, num_classes)
        
        # matrix-multiply the predictions by the cost matrix to penalize classes further from the target
        preds_cost = preds @ cost

        # multiply by the one-hot encoding to keep only the one cell per row that gives that row's cost
        final_cost = preds_cost * targ_hot
        score = final_cost.sum() * cost.size(0)/outsize[0]                       
        return score

    @staticmethod
    def backward(grad_out):
        print('Custom backward called!')
        #input, cost = ctx.saved_tensors
        
        grad_input = grad_out @ cost_tensor
        return grad_input, None

Invoking code

cost_matrix = create_cost_matrix(1, num_classes, epsilon)
cost_tensor = torch.from_numpy(cost_matrix).cuda()
#cost_tensor.requires_grad=True

learn = unet_learner(data, models.resnet34, metrics=metrics, wd=wd)
learn.loss_func = FlattenedLoss(CostMatrixLoss, axis=1, cost=cost_tensor)
learn.fit_one_cycle(10, slice(lr))

def to_one_hot(y, outsize=None, n_dims=None):
    y = y.view(-1, 1)
    n_dims = n_dims if n_dims is not None else int(torch.max(y)) + 1
    zeroes_t = torch.zeros(y.size()[0], n_dims, device=y.device)
    y_one_hot = zeroes_t.scatter_(1, y, 1)
        
    y_one_hot = y_one_hot.cuda()
    y_one_hot = y_one_hot.view(outsize)
    return y_one_hot

You can check that the backward pass is called by putting a set_trace in it (import pdb; pdb.set_trace()). I don’t think your cost matrix should have requires_grad=True unless you plan to have it updated (in which case you should make it part of your model so that the optimizer knows it has to update it).
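
For the second point, a rough sketch of what “part of your model” means (the class name here is just illustrative):

import torch
import torch.nn as nn

class ModelWithCost(nn.Module):
    # illustrative wrapper: holding the cost matrix on the model itself
    def __init__(self, cost):
        super().__init__()
        # learnable: appears in model.parameters(), so the optimizer will update it
        self.cost = nn.Parameter(cost)
        # fixed instead: register a buffer, which moves with .cuda()/.to() but is not trained
        # self.register_buffer('cost', cost)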

Thank you for your initial response.

I added the set_trace to the backward pass and can confirm it is not called. I didn’t think the requires_grad=True was needed, but tried that as a last resort because I was out of other ideas.

I also checked the score tensor being returned in the forward pass, and it does have a grad_fn associated with it, in case that is relevant.

Taking a step back, do you believe the framework should be able to compute the gradient for this loss function automatically, and that there is an implementation issue in the loss function preventing it? Or is there something fundamental about the design of the loss function (such as multiplying by the one-hot encoded tensor) that makes it impossible for the framework to compute the gradient on its own?

I also tried using the CostMatrixLoss function (with a 32 x 32 cost tensor) in the lesson3-camvid notebook. While this loss function is not appropriate for that data set and its classes, I was able to verify that the behavior is similar there, with the valid_loss and metrics not updating during training.

And does it train with the standard loss? It seems really weird that you don’t get gradients.

Yes, the camvid notebook trains with the standard loss.

I noticed that in the PyTorch examples they use a context object to save information from the forward pass for the backward pass. I’m guessing this is something that fast.ai takes care of behind the scenes? I was poking around the fast.ai code to see if I could find where this happens, but haven’t found anything yet. Could it be that the cost matrix isn’t known to the backward call, and that is why no gradient is being used?
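
For reference, the pattern from those docs, adapted to my loss, would look roughly like this. This is an untested sketch (the class name is mine, and the backward is my hand-derived gradient through the softmax), but it shows where the ctx object and the .apply call fit in:

class CostMatrixFunctionSketch(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, target, cost):
        # input: (M, C) logits, target: (M,) class indices, cost: (C, C)
        preds = torch.softmax(input, dim=1)
        targ_hot = torch.zeros_like(preds).scatter_(1, target.view(-1, 1), 1.0)
        scale = cost.size(0) / input.size(0)
        score = ((preds @ cost) * targ_hot).sum() * scale
        ctx.save_for_backward(preds, targ_hot, cost)
        ctx.scale = scale
        return score

    @staticmethod
    def backward(ctx, grad_output):
        preds, targ_hot, cost = ctx.saved_tensors
        # d(score)/d(preds): each row picks up the cost column of its target class
        grad_preds = (targ_hot @ cost.t()) * ctx.scale
        # chain rule through the row-wise softmax: p * (g - sum_k g_k * p_k)
        dot = (grad_preds * preds).sum(dim=1, keepdim=True)
        grad_input = preds * (grad_preds - dot)
        # one return value per forward argument; target and cost get no gradient
        return grad_output * grad_input, None, None

# the Function would then be invoked through .apply rather than by calling forward directly:
# score = CostMatrixFunctionSketch.apply(input, target, self.cost)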

There is a reason it’s called autograd: the gradients are computed automatically, without you having to do anything. Check that your forward pass is called properly too (maybe the loss function isn’t being called at all). Then try manually computing your custom loss on an output and then calling backward.

I have verified that the loss function is called by setting breakpoints and stepping through the code.

Thank you for the suggestion. Would you please explain a bit more what is meant by “manually computing your custom loss on an output and then calling backward”? Is there a code snippet you can point me to?

Well, call your model on a batch of inputs to get the output, then feed that to your loss function along with the targets, and then call backward.
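
Something along these lines (an untested sketch; it assumes the learn and data objects from your snippet and a fastai v1 DeviceDataLoader):

xb, yb = next(iter(learn.data.train_dl))   # one batch, already on the right device
out = learn.model(xb)                      # raw model output
loss = learn.loss_func(out, yb)            # your CostMatrixLoss via FlattenedLoss
print(loss)                                # should show a grad_fn
loss.backward()                            # then check that gradients actually appear
print(next(learn.model.parameters()).grad)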

@turntwo463 were you able to get it to work?