Allow for more than one output for loss and metric


(Arka Sadhu) #21

You are correct. This is a good workaround


#22

Alright that makes sense! Could you clarify what you mean by full callbacks before I start working on the PR?


(Jeremy Howard) #23

BTW for an example of one way to set up an ordering of classes, see how we do that with subclasses of Transform in fastai.


#24

For now, the list of callbacks is

[callback_fns(learner)] + [self.callbacks] + [callbacks passed to fit]

with Recorder being the first in the list of callback_fns. I meant the elements of the two last lists as full callbacks.
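
To make the ordering concrete, here is a minimal sketch of how the three sources could be concatenated and then sorted by priority. It is not the fastai implementation: the `build_callbacks` helper, the `MyCallback` class, and the `_order` value given to `Recorder` are assumptions made only for illustration.

```python
class Callback:
    _order = 0  # assumed priority attribute: lower values run first


class Recorder(Callback):
    _order = -10  # assumed value, chosen only so that Recorder sorts first


class LossAggregator(Callback):
    _order = 10


class MyCallback(Callback):
    pass  # keeps the default priority of 0


def build_callbacks(callback_fns, learner_callbacks, fit_callbacks, learner=None):
    """Concatenate the three sources, then sort by _order.

    sorted() is stable, so callbacks with equal priority keep list order.
    """
    cbs = [fn(learner) for fn in callback_fns]
    cbs += list(learner_callbacks) + list(fit_callbacks)
    return sorted(cbs, key=lambda cb: getattr(cb, "_order", 0))


cbs = build_callbacks(
    callback_fns=[lambda learn: Recorder()],
    learner_callbacks=[LossAggregator()],
    fit_callbacks=[MyCallback()],
)
names = [type(cb).__name__ for cb in cbs]
print(names)
```

Here the "full callbacks" are the already-instantiated objects in the last two lists, while `callback_fns` holds factories that are called with the learner.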


#25

Great thanks!


#26

@sgugger I’ve written a callback priority implementation, but I can’t seem to find any tests. Is there a way to check whether I’ve broken anything? I’d be happy to write some tests if none exist.


#27

To check you didn’t break anything, I’d test the examples notebook. Unit tests are coming later; we haven’t had time to write any yet.


(Jeremy Howard) #28

Yes please!


#29

@sgugger I made a PR and tested everything on the notebooks. I haven’t written the tests yet.


#30

Even with the callback priority implemented by @jeremy and @sgugger’s suggestion for combining two loss functions in backpropagation, there are still some limitations. Below I explain my scenario, the workarounds I’ve tried with the existing framework, and a proposed solution. That said, there may be another way to use the framework to achieve what I want that I’ve missed.

TL;DR
Add on_loss_end to have a place to record multiple losses and combine multiple losses for backpropagation

Scenario:

This is a segmentation task, and the workflow is roughly:

  1. Get data from DataLoader in DataBunch which returns input, target_mask, loss_weight_map
  2. Feed input into the model and the model will output the predicted_mask and an image_embedding
  3. Record image_embedding in Recorder
  4. Calculate the unreduced losses; in this scenario I’m using a WeightedCELoss with the loss_weight_map and a SoftDiceLoss
  5. Record the unreduced losses in Recorder (by unreduced I mean the per-pixel loss for the WeightedCELoss and SoftDiceLoss)
  6. Reduce losses
  7. Backpropagate, and update weights

Workaround 1: Using suggestions from Jeremy and Sylvain:

I have defined two new callbacks, SegmentationTrainer and LossAggregator

class SegmentationTrainer(fastai.Callback):
    def __init__(self, n_classes=2):
        self.n_classes = n_classes
        
    def on_batch_begin(self, last_input, last_target, **kwargs):
        last_target, self.weight_maps = last_target
        return last_input, last_target
    
    def on_loss_begin(self, last_output, last_target, **kwargs):
        # last_output also has the image_embeddings
        model_output = last_output.get("model_output")
        self.ce_loss = calculate_ce_loss(model_output, last_target, self.weight_maps)
        self.soft_dice_loss, _ = calculate_dice_loss(model_output, last_target, n_classes=self.n_classes)
        return model_output

    def on_backward_begin(self, **kwargs):
        last_loss = dict(
            ce_loss=self.ce_loss,
            soft_dice_loss=self.soft_dice_loss
        )
        return last_loss
       
class LossAggregator(fastai.Callback):
    _order = 10
    
    def on_backward_begin(self, last_loss, **kwargs):
        ce_loss = last_loss["ce_loss"]
        soft_dice_loss = last_loss["soft_dice_loss"]
        total_loss = torch.mean(ce_loss) + torch.mean(soft_dice_loss)
        return total_loss

Note that I’ve set the loss_fn passed into the fit function to lambda out, *yb: torch.zeros(1), so it is effectively unused; this works around the loss.detach() call in on_backward_begin of the CallbackHandler. I haven’t completely followed Sylvain’s suggestion, because keeping the two loss functions in different places felt confusing. As a result, I needed a LossAggregator to reduce the losses before backpropagation.

The problem with this implementation is that, since the Recorder is called before the SegmentationTrainer, it cannot record the unreduced losses returned from on_backward_begin of the SegmentationTrainer. However, if we prioritize the SegmentationTrainer over the Recorder, then the Recorder can’t record the image_embeddings, as on_loss_begin of SegmentationTrainer returns only the model output.

Moreover, the LossAggregator seemed hacky, and is a big price to pay for making the code more understandable by putting both loss functions in the same place.

Workaround 2: Defining a callable Loss class

Instead of calculating the loss in the SegmentationTrainer, we could define a callable class that wraps all the losses, for example:

class SoftDiceLoss:
    def __init__(self, dice_loss_weights=None, n_classes=2):
        self.dice_loss_weights = dice_loss_weights
        self.n_classes = n_classes
    
    def __call__(self, out, *yb):
        prediction = out.get("model_output")
        target = yb[0]
        loss = calculate_dice_loss(prediction, target, self.dice_loss_weights, self.n_classes)
        return loss

class WeightedCELoss:
    def __call__(self, out, *yb):
        prediction = out.get("model_output")
        target = yb[0]
        weight_map = yb[1]
        ce_loss = calculate_ce_loss(prediction, target, weight_map)
        return ce_loss
        
class LossWrapper:
    def __init__(self, losses):
        self.losses = losses
    def __call__(self, out, *yb):
        ret = {}
        for loss in self.losses:
            ret[loss.__class__.__name__] = loss(out, *yb)
        return ret

If we do this, everything works perfectly with one exception: the self.smoothener.add_value(loss.detach()) call in on_backward_begin. Since the loss wrapper returns a collection of Tensors (a dict in my case) rather than a single Tensor, calling loss.detach() on it doesn’t work.
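
To illustrate the failure, here is a small sketch with no PyTorch dependency. `FakeTensor` is a hypothetical stand-in that models only `.detach()`, and `detach_all` is a hypothetical helper showing what per-entry detaching would look like; neither exists in fastai.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: only models .detach()."""

    def __init__(self, value, requires_grad=True):
        self.value = value
        self.requires_grad = requires_grad

    def detach(self):
        # Like the real method, return a new object cut off from the graph.
        return FakeTensor(self.value, requires_grad=False)


def detach_all(loss):
    """Detach a single loss, or every entry of a dict/list/tuple of losses."""
    if isinstance(loss, dict):
        return {name: detach_all(v) for name, v in loss.items()}
    if isinstance(loss, (list, tuple)):
        return type(loss)(detach_all(v) for v in loss)
    return loss.detach()


# Shape of what LossWrapper returns: one entry per wrapped loss class.
losses = {
    "WeightedCELoss": FakeTensor(0.7),
    "SoftDiceLoss": (FakeTensor(0.2), FakeTensor(0.9)),
}

try:
    losses.detach()  # what loss.detach() amounts to when loss is a dict
    failed = False
except AttributeError:
    failed = True

detached = detach_all(losses)
print(failed, detached["WeightedCELoss"].requires_grad)
```

A plain dict has no `detach` method, so the call raises `AttributeError`; any code that smooths or records the loss would have to handle each entry separately.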

I propose that we add another callback function called on_loss_end, which is called right after the loss is calculated. This would not affect any of the existing code, but it would allow us to record multiple loss functions, and even to combine multiple loss functions.

if not loss_fn: return out.detach(),yb[0].detach()
loss = loss_fn(out, *yb)
loss = cb_handler.on_loss_end(loss) # new line
mets = [f(out,*yb).detach().cpu() for f in metrics] if metrics is not None else []

For example, in our SegmentationTrainer we just need to define the on_loss_end method to reduce the losses:

class SegmentationTrainer(fastai.Callback):
    def on_loss_end(self, last_loss, **kwargs):
        ce_loss = last_loss["WeightedCELoss"]
        soft_dice_loss, _ = last_loss["SoftDiceLoss"]
        total_loss = torch.mean(ce_loss) + torch.mean(soft_dice_loss)
        print(f"Loss is: {total_loss.item()}")
        return total_loss

And in the Recorder we could save the unreduced losses.
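
As a rough sketch of how the proposed hook could behave, the handler below threads the loss through each callback's `on_loss_end`, so one callback can record the unreduced losses and a later one can reduce them. All class names here are hypothetical, plain lists of floats stand in for per-pixel loss tensors, and the "return None to leave the loss unchanged" convention is an assumption.

```python
class Callback:
    def on_loss_end(self, last_loss, **kwargs):
        return None  # assumed convention: None leaves the loss unchanged


class CallbackHandler:
    def __init__(self, callbacks):
        self.callbacks = callbacks

    def on_loss_end(self, loss):
        # Each callback sees the loss as left by the previous one.
        for cb in self.callbacks:
            result = cb.on_loss_end(last_loss=loss)
            if result is not None:
                loss = result
        return loss


class UnreducedLossRecorder(Callback):
    def __init__(self):
        self.history = []

    def on_loss_end(self, last_loss, **kwargs):
        self.history.append(last_loss)  # record only, change nothing


class LossReducer(Callback):
    def on_loss_end(self, last_loss, **kwargs):
        # Mean over "pixels" for each loss, then sum the means.
        means = [sum(px) / len(px) for px in last_loss.values()]
        return sum(means)


recorder = UnreducedLossRecorder()
handler = CallbackHandler([recorder, LossReducer()])
per_pixel = {"WeightedCELoss": [0.5, 1.5], "SoftDiceLoss": [0.25, 0.75]}
total = handler.on_loss_end(per_pixel)
print(total)
```

The recorder has to come before the reducer in the list, which is where callback priority matters.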

I have made a PR, and made sure that all the notebooks still run correctly!

Hopefully this will make the framework a lot more flexible and make it easier to deal with @TheShadow29’s case of recording the losses for SSD :slight_smile:


#31

I still think this is unnecessary (and I’m in the middle of implementing SSD myself). Let’s keep it simple and create one loss function like this:

class CombinedLoss(nn.Module):    
    def forward(self, output, target):
        self.loss1 = ...
        self.loss2 = ...
        return self.loss1 + self.loss2

Since this returns a rank-0 tensor, there is no need for an on_loss_end callback, but it also stores the individual losses so we can access them… in a callback :wink:
We can then define a RecordSeparateLosses callback that takes the learner and, in on_batch_end for instance, accesses learn.loss_fn.loss1 and learn.loss_fn.loss2.
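
A minimal sketch of that idea follows, written without PyTorch so it runs standalone. The `Learner` here is a hypothetical stub exposing only a `loss_fn` attribute, floats stand in for tensors, and the two loss terms are made up for illustration.

```python
class CombinedLoss:
    """Plain-Python stand-in for the nn.Module above; floats replace tensors."""

    def __call__(self, output, target):
        self.loss1 = abs(output - target)    # made-up first term
        self.loss2 = (output - target) ** 2  # made-up second term
        return self.loss1 + self.loss2


class Learner:
    """Hypothetical stub: only exposes the loss function."""

    def __init__(self, loss_fn):
        self.loss_fn = loss_fn


class RecordSeparateLosses:
    def __init__(self, learn):
        self.learn = learn
        self.loss1s, self.loss2s = [], []

    def on_batch_end(self, **kwargs):
        # Read back the terms the loss function stored on its last call.
        self.loss1s.append(self.learn.loss_fn.loss1)
        self.loss2s.append(self.learn.loss_fn.loss2)


learn = Learner(CombinedLoss())
cb = RecordSeparateLosses(learn)
for output, target in [(3.0, 1.0), (2.0, 2.5)]:
    total = learn.loss_fn(output, target)
    cb.on_batch_end()

print(cb.loss1s, cb.loss2s)
```

The callback only works if it runs after the loss has been computed for the current batch, since it reads whatever the loss function stored last.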


(Arka Sadhu) #32

Yes, I found this to be a very neat way. In fact, an even better way is to use the SmoothenValue class (https://github.com/fastai/fastai/blob/d16dd170f8ca36b90627c7235c1635e817a78384/fastai/callback.py#L152).

So my loss looks like

class CombinedLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss1_smooth = SmoothenValue(0.98)
        self.loss2_smooth = SmoothenValue(0.98)
    def forward(self, output, target):
        self.loss1 = ...
        self.loss2 = ...
        # detach so the running averages don't hold on to the autograd graph
        self.loss1_smooth.add_value(self.loss1.detach())
        self.loss2_smooth.add_value(self.loss2.detach())
        return self.loss1 + self.loss2

In the callback’s on_batch_end, I can then simply do something like self.learn.loss_fn.loss1_smooth.smooth.item()
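
For readers unfamiliar with the smoothing, here is a standalone sketch of a bias-corrected exponential moving average, assuming that is what SmoothenValue computes. The class name `SmoothedValue` is hypothetical and plain floats are used instead of tensors.

```python
class SmoothedValue:
    """Bias-corrected exponential moving average of a stream of values."""

    def __init__(self, beta):
        self.beta = beta
        self.n = 0          # number of values seen so far
        self.mov_avg = 0.0  # raw (biased) moving average
        self.smooth = 0.0   # bias-corrected average

    def add_value(self, val):
        self.n += 1
        self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val
        # Dividing by (1 - beta**n) removes the bias towards the zero start.
        self.smooth = self.mov_avg / (1 - self.beta ** self.n)


first = SmoothedValue(0.98)
first.add_value(5.0)

constant = SmoothedValue(0.98)
for v in [2.0, 2.0, 2.0]:
    constant.add_value(v)

print(first.smooth, constant.smooth)
```

With the correction, the first smoothed value is (up to rounding) the first input, and a constant stream stays at that constant instead of slowly ramping up from zero.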


#33

@sgugger maybe I am misunderstanding you, but I think this only partially solves my problem.

What you’re suggesting is that we save the unreduced losses inside the CombinedLoss so that we can record them later in the on_batch_end method, is that correct? This would work for the training loop, but it would fail in the validation loop, as there is no on_batch_end inside the validation loop.

Moreover, and again maybe this is just personal preference, having the loss function record its own losses to be accessed later still seems kind of hacky, as it violates the single responsibility principle (SRP).

Though the idea of having a separate callback record the inputs/outputs of the model makes a lot of sense!


#34

The answer might very well be to put on_batch_end at the end of loss_batch so that it’s also called during validation (with an is_training parameter to know the mode). I just don’t see the point of adding on_loss_end, as it’s the same as on_backward_begin.
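
Roughly, that could look like the sketch below. This `loss_batch` is a heavily simplified, hypothetical version (no PyTorch, no metrics), where passing `opt_step=None` is assumed to mean validation; `SplitRecorder` is also made up for illustration.

```python
class SplitRecorder:
    """Hypothetical callback that keeps training and validation losses apart."""

    def __init__(self):
        self.train_losses, self.valid_losses = [], []

    def on_batch_end(self, last_loss, is_training, **kwargs):
        target = self.train_losses if is_training else self.valid_losses
        target.append(last_loss)


def loss_batch(model, xb, yb, loss_fn, callbacks, opt_step=None):
    """Simplified sketch: opt_step is None during validation."""
    loss = loss_fn(model(xb), yb)
    is_training = opt_step is not None
    if is_training:
        opt_step(loss)  # stands in for backward() plus the optimizer step
    # Called at the end of every batch, in both modes.
    for cb in callbacks:
        cb.on_batch_end(last_loss=loss, is_training=is_training)
    return loss


def model(x):
    return 2 * x


def loss_fn(out, y):
    return (out - y) ** 2


rec = SplitRecorder()
steps = []
loss_batch(model, 1.0, 3.0, loss_fn, [rec], opt_step=steps.append)  # training
loss_batch(model, 2.0, 1.0, loss_fn, [rec])                          # validation
print(rec.train_losses, rec.valid_losses)
```

With the hook inside `loss_batch`, a recording callback no longer needs a separate code path for validation.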

As for the loss function keeping track of internal parameters that led to its result, I don’t see how it violates the SRP.


#35

Ahh yes, I guess putting on_batch_end into loss_batch would make sense :slight_smile: since on_loss_end’s only job there would be to record the losses.

I guess you’re right, in my mind I just thought that the loss function should be stateless and only be responsible for calculating the loss and not recording any state.


#36

I’ll work on it later this evening and we’ll see if it solves this issue, then.


#37

I can put in a PR to move the on_batch_end into the validate function


#38

I’d like to handle it, as I have other things that go a bit with it in mind (to be able to print the two losses for instance, or new metrics that aren’t means over batches).


#39

Cool looking forward to it!


#40

Ok finished working on this. Examples of use here!