I am working on an image segmentation problem where classes that are close in numeric value are semantically similar. I am trying to implement a custom loss function that uses a cost matrix to penalize a prediction in proportion to its distance from the target class, so that predictions far from the target are penalized more than predictions near it.
I am modelling the custom loss function on the PyTorch documentation for extending autograd here: https://pytorch.org/docs/stable/notes/extending.html?highlight=extend%20autograd
In the custom loss function, the cost matrix will be of this form:
tensor([[0., 1., 2., 3., 4., 5.],
        [1., 0., 1., 2., 3., 4.],
        [2., 1., 0., 1., 2., 3.],
        [3., 2., 1., 0., 1., 2.],
        [4., 3., 2., 1., 0., 1.],
        [5., 4., 3., 2., 1., 0.]], device='cuda:0', requires_grad=True)
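(For context, the matrix is just the absolute distance between class indices; a minimal sketch of a helper that produces it, with the epsilon handling of my real create_cost_matrix omitted, would be:)

import numpy as np

def create_cost_matrix_minimal(scale, num_classes):
    # Minimal sketch only: cost[i, j] = scale * |i - j|, which
    # reproduces the matrix shown above for scale=1.
    idx = np.arange(num_classes, dtype=np.float32)
    return scale * np.abs(idx[:, None] - idx[None, :])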
During training with my implementation, the valid_loss and the metrics never update, which I believe means the autograd framework cannot compute the gradient on its own. I tried adding an implementation of backward, but it does not appear to be called.
I would appreciate guidance on the following questions:
- Is there a way to update the implementation below, or an equivalent alternative approach, so that the framework can compute the gradient automatically?
- If not, what changes are needed to provide a gradient calculation that the framework will actually call?
Details below:
Loss function
class CostMatrixLoss(torch.nn.Module):
    def __init__(self, cost):
        super(CostMatrixLoss, self).__init__()
        self.cost = cost

    def forward(self, input, target):
        cmf = CostMatrixFunction()
        score = cmf.forward(input, target, self.cost)
        return score
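My understanding of the docs is that if forward is built entirely from differentiable tensor ops, no custom Function is needed at all: a plain nn.Module should let autograd derive backward by itself. An untested sketch of what I think that would look like (the class name is mine, and I am assuming FlattenedLoss hands me input of shape (N, C) and target of shape (N,)):

class CostMatrixLossAutograd(torch.nn.Module):
    # Untested sketch: the same computation as above, expressed with
    # differentiable ops only so autograd can handle backward itself.
    def __init__(self, cost):
        super(CostMatrixLossAutograd, self).__init__()
        # A buffer moves with .cuda()/.to() but is not a learnable parameter.
        self.register_buffer('cost', cost)

    def forward(self, input, target):
        preds = torch.softmax(input, 1)             # (N, C)
        row_cost = self.cost[target]                # cost row for each target, (N, C)
        expected_cost = (preds * row_cost).sum(1)   # (N,)
        # Same scaling as my original: sum * num_classes / batch_size.
        return expected_cost.mean() * self.cost.size(0)

Indexing cost by target relies on the matrix being symmetric, so cost[target] picks out the same column that my preds @ cost followed by the one-hot mask does.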
class CostMatrixFunction(torch.autograd.Function):
    def __init__(self):
        super(CostMatrixFunction, self).__init__()

    # Note that both forward and backward are @staticmethods
    @staticmethod
    def forward(input, target, cost):
        #ctx.save_for_backward(input, cost)
        preds = torch.softmax(input, 1)
        outsize = preds.size()
        num_classes = outsize[1]
        targ_hot = to_one_hot(target, outsize, num_classes)
        # matrix-multiply predictions by cost to penalize classes further from the target
        preds_cost = preds @ cost
        # multiply by the one-hot encoding to keep only the one cell per row that gives the cost
        final_cost = preds_cost * targ_hot
        score = final_cost.sum() * cost.size(0) / outsize[0]
        return score

    @staticmethod
    def backward(grad_out):
        print('Custom backward called!')
        #input, cost = ctx.saved_tensors
        grad_input = grad_out @ cost_tensor
        return grad_input, None
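From re-reading the extending-autograd notes, I believe the problem with the version above is twofold: the Function has to be invoked through .apply rather than by calling forward directly, and both static methods take ctx as their first argument, with backward returning one gradient per input to forward. My best guess at a corrected skeleton, with the gradient through the softmax worked out by hand (untested, and the V2 name is just mine):

class CostMatrixFunctionV2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, target, cost):
        preds = torch.softmax(input, 1)          # (N, C)
        row_cost = cost[target]                  # (N, C); relies on the symmetric matrix
        per_sample = (preds * row_cost).sum(1)   # (N,)
        ctx.save_for_backward(preds, row_cost, per_sample)
        ctx.scale = cost.size(0) / input.size(0)
        return per_sample.sum() * ctx.scale

    @staticmethod
    def backward(ctx, grad_output):
        preds, row_cost, per_sample = ctx.saved_tensors
        # Hand-derived gradient through the softmax:
        # d score / d input[i, j] = scale * s_ij * (c_ij - sum_k s_ik * c_ik)
        grad_input = ctx.scale * preds * (row_cost - per_sample.unsqueeze(1))
        # One return value per forward input: input, target, cost.
        return grad_output * grad_input, None, None

With that, CostMatrixLoss.forward would call it as score = CostMatrixFunctionV2.apply(input, target, self.cost).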
Invoking code
cost_matrix = create_cost_matrix(1, num_classes, epsilon)
cost_tensor = torch.from_numpy(cost_matrix).cuda()
#cost_tensor.requires_grad=True
learn = unet_learner(data, models.resnet34, metrics=metrics, wd=wd)
learn.loss_func = FlattenedLoss(CostMatrixLoss, axis=1, cost=cost_tensor)
learn.fit_one_cycle(10, slice(lr))
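If I go the custom-backward route, my plan for sanity checking is torch.autograd.gradcheck (which I believe the same extending docs recommend): it compares the analytic backward against finite differences and wants double-precision inputs, roughly like this (using the hypothetical CostMatrixFunctionV2 from above):

x = torch.randn(4, 6, dtype=torch.double, requires_grad=True)
t = torch.randint(0, 6, (4,))
c = torch.tensor([[abs(i - j) for j in range(6)] for i in range(6)],
                 dtype=torch.double)
# Compares the hand-written backward against numerical gradients.
torch.autograd.gradcheck(
    lambda inp: CostMatrixFunctionV2.apply(inp, t, c), (x,))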
def to_one_hot(y, outsize=None, n_dims=None):
    y = y.view(-1, 1)
    n_dims = n_dims if n_dims is not None else int(torch.max(y)) + 1
    zeroes_t = torch.zeros(y.size()[0], n_dims, device=y.device)
    y_one_hot = zeroes_t.scatter_(1, y, 1)
    y_one_hot = y_one_hot.cuda()
    y_one_hot = y_one_hot.view(outsize)
    return y_one_hot
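As an aside, I suspect torch.nn.functional.one_hot could replace most of this helper, along the lines of:

import torch.nn.functional as F

# Equivalent sketch: one_hot returns int64, so cast to float before
# multiplying with the float predictions; reshape to match outsize.
targ_hot = F.one_hot(target.view(-1), num_classes).float().view(outsize)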