ClassificationInterpretation with a learner using FocalLoss

I’m playing with the tutorial in the beginner section for tabular data. I tried to change the loss function, so I used focal loss like this:

learn_fl = tabular_learner(dls_fl, metrics=[accuracy],
                          loss_func=FocalLossFlat())

When I train the model using the same procedure as the tutorial and then try to execute the following line:

interp = ClassificationInterpretation.from_learner(learn_fl)

I get this error: RuntimeError: shape '[1024, -1]' is invalid for input of size 1

I think the error has something to do with the reduction parameter of loss functions, but I don’t know how to get around it.

As an aside, I tried to execute one of the lines shown in the error message:

learn_fl.get_preds(dl=dl, with_input=True, with_loss=True, with_decoded=True, act=None)

and I get this error: IndexError: tuple index out of range. When I set with_loss=False, the line works just fine.

Moreover, when I replace FocalLossFlat() with CrossEntropyLossFlat(), everything works as expected.

Does anyone know how to solve this? Thanks.


I have the same error… I will try to find a solution.

I recently posted a solution for this in another thread.

I tried with_loss=True but it didn’t work.

In the linked post, I included some code for a new class FocalLossFlatten() that works with both ClassificationInterpretation() and learn.get_preds(with_loss=True).
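
In short, the idea is a focal loss whose reduction attribute is actually honoured, so fastai can switch it to 'none' when it needs per-item losses for the interpreter. The class name is the one from that post, but the implementation below is just a simplified sketch of the approach, not the exact code:

import torch
import torch.nn.functional as F

class FocalLossFlatten(torch.nn.Module):
    "Focal loss that honours `reduction`, so fastai can flip it to 'none' for per-item losses"
    y_int = True  # tell fastai the targets are integer class indices

    def __init__(self, gamma=2.0, weight=None, reduction='mean'):
        super().__init__()
        self.gamma, self.weight, self.reduction = gamma, weight, reduction

    def forward(self, inp, targ):
        # flatten like fastai's *Flat losses: (bs, n_classes) preds, (bs,) or (bs, 1) targets
        inp, targ = inp.view(-1, inp.shape[-1]), targ.view(-1)
        ce = F.cross_entropy(inp, targ, weight=self.weight, reduction='none')
        loss = (1 - torch.exp(-ce)) ** self.gamma * ce
        if self.reduction == 'mean': return loss.mean()
        if self.reduction == 'sum':  return loss.sum()
        return loss  # reduction='none' -> one loss value per item

    # hooks fastai uses for get_preds(with_decoded=True) and ClassificationInterpretation
    def activation(self, out): return F.softmax(out, dim=-1)
    def decodes(self, out): return out.argmax(dim=-1)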


Hi, thank you so much. I will check it and let you know.

Would you mind trying FocalLossFlat(reduction=None), please? For me it fixes get_preds and the interpreter.
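
In your snippet that would look something like this (untested on your exact setup; dls_fl and the metric are taken from your first post):

learn_fl = tabular_learner(dls_fl, metrics=[accuracy],
                           loss_func=FocalLossFlat(reduction=None))

# train as in the tutorial, then:
interp = ClassificationInterpretation.from_learner(learn_fl)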