Allow for more than one output for loss and metric

(Arka Sadhu) #1

Proposal: Currently the loss is simply a one-dimensional tensor. However, networks like SSD combine multiple loss functions, such as a regression loss and a classification loss. As far as I know, it is not currently possible to print both losses (please correct me if I am wrong). So the proposal is to support multiple outputs, with the first output being the loss that is actually considered for training. The other outputs could then be printed via callbacks. The same would apply to metrics.

Code to be changed:
In, need to add a condition that checks for multiple outputs and, if there are several, takes the first one. The next line needs the same change for metrics.
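A rough sketch of the check I have in mind. This is not the actual library code: `split_loss_output` and `loss_batch_sketch` are hypothetical names, and the optimizer step is left out. Plain floats stand in for tensors so the snippet runs on its own.

```python
def split_loss_output(out):
    """Split a loss/metric result into (primary, extras).

    Hypothetical helper: if `out` is a tuple or list, the first element is
    treated as the value to optimise and the rest are kept for reporting.
    """
    if isinstance(out, (tuple, list)):
        if len(out) == 0:
            raise ValueError("loss function returned an empty sequence")
        return out[0], tuple(out[1:])
    return out, ()


def loss_batch_sketch(model, xb, yb, loss_fn, metrics=()):
    """Simplified stand-in for a loss_batch-style function (no optimizer step)."""
    preds = model(xb)
    loss, extra_losses = split_loss_output(loss_fn(preds, yb))
    # Same rule for metrics: the first value is "the" metric, extras go to callbacks.
    metric_vals = [split_loss_output(m(preds, yb)) for m in metrics]
    return loss, extra_losses, metric_vals


# Toy SSD-like loss: returns (total, regression part, classification part).
def ssd_like_loss(preds, target):
    reg = abs(preds[0] - target[0])
    clas = abs(preds[1] - target[1])
    return reg + clas, reg, clas


loss, extras, _ = loss_batch_sketch(lambda x: x, (1.0, 2.0), (0.5, 2.5), ssd_like_loss)
```

Here `loss` is what would be backpropagated, and `extras` holds the regression and classification parts that a callback could print.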


I’ve also been thinking about this. There are cases where one model can have many outputs (such as in MaskRCNN), and each output can be associated with many losses (for example when doing image segmentation, combining CE loss and Soft Dice Loss). Moreover, we might also want to have a different LR scheme for different losses.

It seems like creating a Callback class to calculate the losses would make sense. Instead of providing the loss_fn parameter to the loss_batch function, we would pass in a function or dict that maps each output of the model to the correct loss Callback class.
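Something along these lines, as a very rough sketch. `LossCallback`, `total_loss` and the output names `"mask"` and `"box"` are all made up for illustration, and the lambdas are placeholders on plain floats, not real CE or Dice implementations.

```python
class LossCallback:
    """Hypothetical callback: wraps one loss function and remembers its last value."""

    def __init__(self, name, fn, weight=1.0):
        self.name, self.fn, self.weight = name, fn, weight
        self.last = None  # most recent unweighted value, available for printing

    def __call__(self, pred, target):
        self.last = self.fn(pred, target)
        return self.weight * self.last


def total_loss(outputs, targets, loss_map):
    """Sum the weighted losses; `loss_map` maps an output name to its callbacks."""
    total = 0.0
    for key, callbacks in loss_map.items():
        for cb in callbacks:
            total = total + cb(outputs[key], targets[key])
    return total


# One model output ("mask") with two losses, another ("box") with one.
loss_map = {
    "mask": [
        LossCallback("ce_placeholder", lambda p, t: abs(p - t)),
        LossCallback("dice_placeholder", lambda p, t: (p - t) ** 2, weight=0.5),
    ],
    "box": [LossCallback("l1", lambda p, t: abs(p - t))],
}
outputs = {"mask": 0.8, "box": 2.0}
targets = {"mask": 1.0, "box": 1.5}
total = total_loss(outputs, targets, loss_map)
```

Each callback keeps its own `last` value, so the individual losses can still be reported separately even though only `total` is optimized.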

(Arka Sadhu) #3

The idea of different lr schemes for different loss functions didn’t occur to me, and I haven’t really seen it being used anywhere in particular. Would be a nice experiment to see if that actually gives some better results.

(Kerem Turgutlu) #4

Wouldn't having different LR settings have a similar effect to SoftDice + alpha*BCE, where alpha is a hyperparameter you can tune?

(Arka Sadhu) #5

This would be true only if the schedule is the same. Say you have cosine annealing for one loss and linear decay for the other. There is no way to reproduce that via the alpha parameter alone.
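A quick numeric check of this point, as a sketch. It assumes plain SGD, where the update is linear in the gradient, so a per-loss learning rate acts like a per-loss weight. The two schedule functions are written out by hand here and are not taken from any library.

```python
import math


def cosine_annealing(step, total, lr_max=1.0, lr_min=0.0):
    """Cosine annealing from lr_max at step 0 down to lr_min at step == total."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * step / total))


def linear_decay(step, total, lr_max=1.0, lr_min=0.0):
    """Linear decay from lr_max at step 0 down to lr_min at step == total."""
    return lr_max - (lr_max - lr_min) * step / total


total_steps = 100
# Effective relative weight of loss A (cosine) versus loss B (linear) at each step.
# The last step is excluded because the linear schedule reaches zero there.
ratios = [
    cosine_annealing(s, total_steps) / linear_decay(s, total_steps)
    for s in range(total_steps)
]
spread = max(ratios) - min(ratios)
```

With a single schedule and a fixed alpha, this ratio would be the same constant at every step. Here it starts at 1, rises above 1 in the first half and falls well below 1 towards the end, so no fixed alpha matches it.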

(Kerem Turgutlu) #6

That might make sense in a multitask setting, I guess. For a single task like segmentation, adding the losses should be fine, but it's an interesting area to dig deeper into for sure.