Hi all,
I was wondering if there would be support for auxiliary inputs which aren’t used in the forward pass of the neural network but only in the backward pass.
For example, when doing segmentation, we might want a weight map that weights pixels lying between two adjacent masks more heavily than pixels far away from any boundary (as in the U-Net paper), and we might also want to combine two different kinds of losses, e.g. SoftDice and CrossEntropy. More generally, the relationship between model outputs and losses is many-to-many, which is difficult for the current fastai framework to handle.
I’ve been working on a local copy of a fastai-like callback framework (forked about 3 weeks ago, so it’s a bit outdated) where I have implemented the above functionality by injecting the input data from the DataLoader (the input image, weight maps, and masks) and the output of the model (the predicted probability map for each class per pixel) into the state_dict. I also created a new callback method called on_loss_calculate, which takes the place of loss = loss_fn(out, *yb) in line 25 of basic_train.py. So after the forward pass the state_dict would have something like:
{
    ... other key-value pairs, e.g. 'epoch', 'num_iter' ...
    'input':      <input image for model>,
    'weight_map': <input weight map for model>,
    'output':     <output of model>,
    'target':     <ground truth>,
}
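For context, here is a minimal sketch of what the modified training step might look like. The names cb_handler, state_dict, on_loss_calculate, and on_backward_begin follow my local copy; the exact loop differs from the current basic_train.py, so treat this as illustrative rather than the real code:

for xb, weight_map, yb in dataloader:
    out = model(xb)                          # forward pass uses only xb
    cb_handler.state_dict.update({
        'input':      xb,
        'weight_map': weight_map,            # auxiliary input, used only in the loss
        'output':     out,
        'target':     yb,
    })
    cb_handler.on_loss_calculate()           # each loss callback computes its loss
    loss = cb_handler.on_backward_begin()    # weighted sum across loss callbacks
    loss.backward()
    opt.step()
    opt.zero_grad()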
Then instead of using loss_fn, I’ve made each loss a Callback, for example:
import torch

class CrossEntropyLossCallback(Callback):
    def __init__(self, weight=1):
        # weight of this loss when calculating the weighted sum of losses
        self.weight = weight

    def on_loss_calculate(self, **kwargs):
        output = kwargs['output']
        weight_map = kwargs['weight_map']
        target = kwargs['target']
        loss = calculate_ce_loss(output, target, weight_map, weight=self.weight)
        self.loss = loss
        return loss

    def on_backward_begin(self, **kwargs):
        return torch.mean(self.loss)
Then the output of on_loss_calculate would be appended to an array in the CallbackHandler, and the loss tensors would be summed up in on_backward_begin by the CallbackHandler to be returned for the backward pass.
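In case it’s useful, here’s roughly the CallbackHandler side I have in mind. This is a sketch with hypothetical names, simplified from the real bookkeeping:

class CallbackHandler:
    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.state_dict = {}
        self.losses = []

    def on_loss_calculate(self):
        # collect each loss callback's loss tensor into an array
        self.losses = [cb.on_loss_calculate(**self.state_dict)
                       for cb in self.callbacks
                       if hasattr(cb, 'on_loss_calculate')]

    def on_backward_begin(self):
        # sum the reduced, already-weighted losses into the single
        # scalar that gets sent backward
        return sum(cb.on_backward_begin(**self.state_dict)
                   for cb in self.callbacks
                   if hasattr(cb, 'on_backward_begin'))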
That said, the approach above makes the code kind of brittle, since a change to one of the keys in the output dict of the Dataset would break it. So I’ve also experimented with a more Redux-like implementation where the attributes are injected directly into the class by the CallbackHandler:
class CrossEntropyLossCallback(Callback):
    def __init__(self, state_to_attr_dict, weight=1):
        super().__init__(state_to_attr_dict)
        self.weight = weight

    def on_loss_calculate(self):
        loss = calculate_ce_loss(self.output, self.target, self.weight_map,
                                 weight=self.weight)
        self.loss = loss
        return loss

    def on_backward_begin(self):
        return torch.mean(self.loss)
Here state_to_attr_dict is a dictionary that maps the required keys in the state to our instance variables, so we don’t need to use **kwargs.
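The injection itself could then be something simple like this (a sketch; inject_state is a hypothetical helper the CallbackHandler would call before dispatching each event):

class Callback:
    def __init__(self, state_to_attr_dict=None):
        # e.g. {'output': 'output', 'target': 'target', 'weight_map': 'weight_map'}
        self.state_to_attr_dict = state_to_attr_dict or {}

def inject_state(callback, state_dict):
    # copy only the requested state values onto the callback instance
    for state_key, attr_name in callback.state_to_attr_dict.items():
        setattr(callback, attr_name, state_dict[state_key])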
I haven’t had a chance to read the new fastai code yet, nor have I familiarized myself with the lessons, so I’m not sure of the best way to incorporate these changes into the framework. Though it seems like replacing loss_fn with a list of loss callbacks, e.g. loss_fn=[CrossEntropyLossCallback], in the initialization of the Learner class in line 95 of basic_train.py may be sufficient?
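For example, with a hypothetical SoftDiceLossCallback alongside the one above, the Learner construction might look like:

learn = Learner(data, model,
                loss_fn=[CrossEntropyLossCallback(weight=1),
                         SoftDiceLossCallback(weight=0.5)])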
Would love to know everyone’s thoughts. Also, the same idea could be used to calculate metrics, since the relationship between outputs and metrics is also many-to-many.
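For instance, a rough sketch of a metric in the same callback style (calculate_dice is a hypothetical helper, and a real version would accumulate over batches rather than keeping only the last one):

class DiceMetricCallback(Callback):
    def on_loss_calculate(self, **kwargs):
        # reads only the state it needs; free to ignore e.g. the weight map
        self.dice = calculate_dice(kwargs['output'], kwargs['target'])

    def on_batch_end(self, **kwargs):
        print(f'dice: {self.dice}')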