You are correct. This is a good workaround.
Alright, that makes sense! Could you clarify what you mean by "full callbacks" before I start working on the PR?
BTW, for an example of one way to set up an ordering of classes, see how we do that with subclasses of `Transform` in fastai.
For now, the list of callbacks is `[callback_fns(learner)] + [self.callbacks] + [callbacks passed to fit]`, with `Recorder` being the first in the list of `callback_fns`. By full callbacks I meant the elements of the last two lists.
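(For context, here is a minimal sketch of how `_order`-based sorting could arrange that combined list, in the spirit of the `Transform` ordering mentioned above; the `_order` attribute appears later in this thread, but the sorting site and the values used here are assumptions, not the actual implementation:)

```python
# Hypothetical sketch: order callbacks by an `_order` attribute, lowest first.
class Callback:
    _order = 0

class Recorder(Callback):
    _order = -10  # e.g. make the Recorder run before user callbacks

def build_callbacks(callback_fns, learner_callbacks, fit_callbacks, learner):
    cbs = [fn(learner) for fn in callback_fns] + learner_callbacks + fit_callbacks
    return sorted(cbs, key=lambda cb: getattr(cb, '_order', 0))
```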
Great, thanks!
@sgugger I've written a callback priority implementation, though I can't seem to find any tests. Is there a way to test whether I've broken anything? Or I'd be happy to write some tests if nothing exists.
To check you didn't break anything, I'd test the examples notebook. Unit tests are coming later, but for now we didn't have time to write any.
Yes please!
After using the callback priority implemented by @jeremy and the suggestion from @sgugger to combine two loss functions for backpropagation, there are still some limitations. Below I explain my scenario, the workarounds I've tried with the existing framework, and a proposed solution. I'm not sure whether I can use the framework in another way to achieve what I want.
TL;DR: Add `on_loss_end` to have a place to record multiple losses and combine them for backpropagation.
Scenario:
This is a segmentation task, and the workflow is roughly:
- Get data from the `DataLoader` in the `DataBunch`, which returns `input`, `target_mask`, `loss_weight_map`
- Feed `input` into the model; the `model` outputs the `predicted_mask` and an `image_embedding`
- Record `image_embedding` in the `Recorder`
- Calculate the unreduced losses; in this scenario I'm using a `WeightedCELoss` with the `loss_weight_map` and a `SoftDiceLoss` (a rough sketch of these loss helpers follows this list)
- Record the unreduced losses in the `Recorder`; by unreduced loss I mean the per-pixel loss for the `WeightedCELoss` and `SoftDiceLoss`
- Reduce the losses
- Backpropagate and update the weights
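(The helpers `calculate_ce_loss` and `calculate_dice_loss` used in the callbacks below are not shown in this thread; here is a rough sketch of what they might look like, with the exact signatures and reductions being assumptions:)

```python
import torch
import torch.nn.functional as F

def calculate_ce_loss(prediction, target, weight_map):
    # Per-pixel cross entropy, weighted by the loss_weight_map and left unreduced.
    # prediction: (bs, n_classes, h, w), target: (bs, h, w) long, weight_map: (bs, h, w)
    ce = F.cross_entropy(prediction, target, reduction='none')  # (bs, h, w)
    return ce * weight_map

def calculate_dice_loss(prediction, target, weights=None, n_classes=2):
    # Per-class soft Dice loss; returns the unreduced per-class losses plus the
    # per-class Dice scores (hence the `loss, _ = ...` unpacking used later).
    probs = F.softmax(prediction, dim=1)
    losses, scores = [], []
    for c in range(n_classes):
        p = probs[:, c]                      # (bs, h, w)
        t = (target == c).float()            # (bs, h, w)
        inter = (p * t).sum(dim=(1, 2))
        union = p.sum(dim=(1, 2)) + t.sum(dim=(1, 2))
        dice = (2 * inter + 1e-6) / (union + 1e-6)
        scores.append(dice)
        w = 1. if weights is None else weights[c]
        losses.append(w * (1 - dice))
    return torch.stack(losses, dim=1), torch.stack(scores, dim=1)
```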
Workaround 1: Using suggestions from Jeremy and Sylvain:
I have defined two new callbacks, `SegmentationTrainer` and `LossAggregator`:
```python
class SegmentationTrainer(fastai.Callback):
    def __init__(self, n_classes=2):
        self.n_classes = n_classes

    def on_batch_begin(self, last_input, last_target, **kwargs):
        last_target, self.weight_maps = last_target
        return last_input, last_target

    def on_loss_begin(self, last_output, last_target, **kwargs):
        # last_output also has the image_embeddings
        model_output = last_output.get("model_output")
        self.ce_loss = calculate_ce_loss(model_output, last_target, self.weight_maps)
        self.soft_dice_loss, _ = calculate_dice_loss(model_output, last_target, n_classes=self.n_classes)
        return model_output

    def on_backward_begin(self, **kwargs):
        last_loss = dict(
            ce_loss=self.ce_loss,
            soft_dice_loss=self.soft_dice_loss
        )
        return last_loss


class LossAggregator(fastai.Callback):
    _order = 10

    def on_backward_begin(self, last_loss, **kwargs):
        ce_loss = last_loss["ce_loss"]
        soft_dice_loss = last_loss["soft_dice_loss"]
        total_loss = torch.mean(ce_loss) + torch.mean(soft_dice_loss)
        return total_loss
```
Note here that I've defined the `loss_fn` passed into the `fit` function to be `lambda out, *yb: torch.zeros(1)` to work around the `loss.detach()` in the `on_backward_begin` function of the `CallbackHandler`, so I don't actually use it at all. So I haven't completely followed Sylvain's suggestion, as I felt that not having both loss functions in the same place was confusing. As a result, I needed a `LossAggregator` to reduce the losses before backpropagation.
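(A rough sketch of how this workaround might be wired up; the `Learner` construction and `fit` signature here are assumptions about the fastai version under discussion, and `data`/`model` are presumed to already exist:)

```python
import torch

# Dummy loss: returns a constant so the loss.detach() in on_backward_begin has
# something harmless to act on; the real losses come from the callbacks.
dummy_loss = lambda out, *yb: torch.zeros(1)

callbacks = [
    SegmentationTrainer(n_classes=2),  # computes the two unreduced losses
    LossAggregator(),                  # _order = 10, so it runs after and reduces them
]

learn = Learner(data, model, loss_fn=dummy_loss)  # hypothetical construction
learn.fit(1, callbacks=callbacks)                 # callbacks passed to fit, as in the thread
```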
The problem with the above implementation is that, since the `Recorder` is called before the `SegmentationTrainer`, it is unable to record the unreduced losses calculated in `on_backward_begin` of the `SegmentationTrainer`. However, if we choose to prioritize the `SegmentationTrainer` before the `Recorder`, then the `Recorder` isn't able to record the `image_embeddings`, as the `on_loss_begin` of `SegmentationTrainer` only returns the output of the model.
Moreover, the `LossAggregator` seemed hacky, and is a big price to pay for making the code more understandable by putting both loss functions in the same place.
Workaround 2: Defining a callable Loss class
Instead of calculating the losses in the `SegmentationTrainer`, we could define a callable class that wraps all the losses, for example:
```python
class SoftDiceLoss:
    def __init__(self, dice_loss_weights=None, n_classes=2):
        self.dice_loss_weights = dice_loss_weights
        self.n_classes = n_classes

    def __call__(self, out, *yb):
        prediction = out.get("model_output")
        target = yb[0]
        loss = calculate_dice_loss(prediction, target, self.dice_loss_weights, self.n_classes)
        return loss


class WeightedCELoss:
    def __call__(self, out, *yb):
        prediction = out.get("model_output")
        target = yb[0]
        weight_map = yb[1]
        ce_loss = calculate_ce_loss(prediction, target, weight_map)
        return ce_loss


class LossWrapper:
    def __init__(self, losses):
        self.losses = losses

    def __call__(self, out, *yb):
        ret = {}
        for loss in self.losses:
            ret[loss.__class__.__name__] = loss(out, *yb)
        return ret
```
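(A short sketch of how the wrapper could then be handed in as the loss function; exactly where it gets passed is an assumption:)

```python
# Wrap both losses; the resulting loss_fn returns a dict keyed by class name.
loss_fn = LossWrapper([WeightedCELoss(), SoftDiceLoss(n_classes=2)])

# loss_fn(out, *yb) -> {"WeightedCELoss": <tensor>, "SoftDiceLoss": (<tensor>, <scores>)}
```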
If we do this, everything works perfectly, with one exception: the `self.smoothener.add_value(loss.detach())` in `on_backward_begin`. Since the loss wrapper returns a list (or, in my case, a dict) of tensors, it doesn't make sense to call `loss.detach()` on it.
I propose that we add another callback function called `on_loss_end`, which is called right after the loss is calculated. This would not affect any of the existing code, but it would allow us to handle the recording of multiple losses, and even the combination of multiple loss functions:
```python
if not loss_fn: return out.detach(), yb[0].detach()
loss = loss_fn(out, *yb)
loss = cb_handler.on_loss_end(loss)  # new line
mets = [f(out, *yb).detach().cpu() for f in metrics] if metrics is not None else []
```
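(For concreteness, the corresponding `CallbackHandler` method could look roughly like the existing `on_*` dispatchers; this is only a sketch and the handler's internals are assumed:)

```python
# Hypothetical CallbackHandler.on_loss_end: let each callback see (and optionally
# replace) the loss before the metrics and the backward pass run.
def on_loss_end(self, loss):
    for cb in self.callbacks:
        new_loss = cb.on_loss_end(last_loss=loss, **self.state_dict)
        if new_loss is not None:
            loss = new_loss
    return loss
```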
For example, in our `SegmentationTrainer` we just need to define the `on_loss_end` method to reduce the losses:
```python
class SegmentationTrainer(fastai.Callback):
    def on_loss_end(self, last_loss, **kwargs):
        ce_loss = last_loss["WeightedCELoss"]
        soft_dice_loss, _ = last_loss["SoftDiceLoss"]
        total_loss = torch.mean(ce_loss) + torch.mean(soft_dice_loss)
        print(f"Loss is: {total_loss.item()}")
        return total_loss
```
And in the `Recorder` we could save the unreduced losses.
I have made a PR, and made sure that all the notebooks still run correctly!
Hopefully this will make the framework a lot more flexible and make it easier to handle @TheShadow29's case of recording the losses for SSD.
I still think this is unnecessary (and I'm in the middle of an implementation of the SSD myself). Let's keep it simple and create one loss function like this:
```python
class CombinedLoss(nn.Module):
    def forward(self, output, target):
        self.loss1 = ...
        self.loss2 = ...
        return self.loss1 + self.loss2
```
Since this returns a rank-0 tensor, there is no need for an `on_loss_end` callback, but it also records the losses so we can access them… in a callback.
We can then define a `RecordSeparateLosses` callback that takes the learner, and in `on_batch_end`, for instance, we can access `learn.loss_fn.loss1` and `learn.loss_fn.loss2`.
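(A minimal sketch of what such a callback could look like; the `LearnerCallback` base class and the `loss_fn` attribute name are assumptions about the fastai version being discussed:)

```python
class RecordSeparateLosses(fastai.LearnerCallback):
    # Assumes the CombinedLoss above stashes its components on itself as
    # self.loss1 / self.loss2 on each forward pass.
    def on_train_begin(self, **kwargs):
        self.loss1s, self.loss2s = [], []

    def on_batch_end(self, **kwargs):
        self.loss1s.append(self.learn.loss_fn.loss1.detach().cpu())
        self.loss2s.append(self.learn.loss_fn.loss2.detach().cpu())
```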
Yes, I found this to be a very neat way. In fact, an even better way is to use the `SmoothenValue` class (https://github.com/fastai/fastai/blob/d16dd170f8ca36b90627c7235c1635e817a78384/fastai/callback.py#L152).
So my loss looks like:
```python
class CombinedLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss1_smooth = SmoothenValue(0.98)
        self.loss2_smooth = SmoothenValue(0.98)

    def forward(self, output, target):
        self.loss1 = ...
        self.loss2 = ...
        self.loss1_smooth.add_value(self.loss1)
        self.loss2_smooth.add_value(self.loss2)
        return self.loss1 + self.loss2
```
In the callback, in `on_batch_end`, I can simply do something like `self.learn.loss_fn.loss1_smooth.smooth.item()`.
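(For reference, `SmoothenValue` keeps an exponentially weighted, debiased moving average; here is a tiny standalone usage sketch based on the methods used above:)

```python
from fastai.callback import SmoothenValue

sv = SmoothenValue(0.98)      # beta for the exponential moving average
for v in (1.0, 0.5, 0.25):
    sv.add_value(v)           # add the latest raw value
print(sv.smooth)              # debiased smoothed value so far
```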
@sgugger maybe I am misunderstanding you, but I think this only partially solves my problem.
What you're suggesting is that we save the unreduced losses inside the `CombinedLoss` so that we can record them later in the `on_batch_end` method, is that correct? This would work for the training loop, but it would fail in the validation loop, as there is no `on_batch_end` inside the validation loop.
Moreover, and maybe this is just personal preference, I feel that having the loss function record its own losses to be accessed later still seems kind of hacky, as it violates the single responsibility principle (SRP).
Though the idea of having a separate callback record the inputs/outputs of the model makes a lot of sense!
The answer might very well be to put `on_batch_end` at the end of `loss_batch` so that it's also called during validation (with an `is_training` parameter to know the mode). I just don't see the point of adding `on_loss_end`, as it's the same as `on_backward_begin`.
As for the loss function keeping track of internal parameters that led to its result, I don't see how it violates the SRP.
Ah yes, I guess putting `on_batch_end` into `loss_batch` would make sense, as `on_loss_end`'s only job there would be to record the losses.
I guess you're right; in my mind I just thought that the loss function should be stateless and only be responsible for calculating the loss, not for recording any state.
I'll work on it later this evening, and we'll see if it solves this issue, then.
I can put in a PR to move the `on_batch_end` into the `validate` function.
I'd like to handle it, as I have other things in mind that go with it (being able to print the two losses, for instance, or new metrics that aren't means over batches).
Cool, looking forward to it!