chsafouane
(Safouane Chergui)
March 8, 2021, 10:55pm
1
I’m playing with the tutorial in the beginner section for tabular data. I tried to change the loss function, so I used the focal loss like this:
learn_fl = tabular_learner(dls_fl, metrics=[accuracy],
                           loss_func=FocalLossFlat())
When I train the model using the same procedure as the tutorial and I try to execute the following line
interp = ClassificationInterpretation.from_learner(learn_fl)
I get this error: RuntimeError: shape '[1024, -1]' is invalid for input of size 1
I think the error has something to do with the reduction parameter of loss functions but I don’t know how to get around it.
As an aside, I tried to execute one of the lines shown in the error message:
learn_fl.get_preds(dl=dl, with_input=True, with_loss=True, with_decoded=True, act=None)
and I get this error: IndexError: tuple index out of range. When I set with_loss=False, this line works just fine.
Moreover, when I replace FocalLossFlat() with CrossEntropyLossFlat(), everything works just fine.
Does anyone know the way to solve this? Thanks
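The first error message can be reproduced in plain PyTorch. The sketch below assumes that, with `with_loss=True`, fastai asks the loss function for one loss value per item and then reshapes the result to `(batch_size, -1)`. If the loss function ignores that request and still returns a single averaged scalar, the reshape fails with exactly this message. The batch size of 1024 is taken from the error; the tensors are random stand-ins, not real losses.

```python
import torch

bs = 1024  # batch size taken from the error message

# What the reshape expects: one (unreduced) loss value per item.
per_item = torch.rand(bs)
print(per_item.view(bs, -1).mean(1).shape)  # torch.Size([1024])

# What a loss that keeps reducing to the mean returns: a single scalar.
scalar = per_item.mean()
try:
    scalar.view(bs, -1)
except RuntimeError as e:
    print(e)  # shape '[1024, -1]' is invalid for input of size 1
```

If this assumption holds, the fix is a loss function that actually returns per-item values whenever its reduction is switched to `'none'`.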
I have the same error; I will try to find a solution.
neuradai
(Walter Wiggins)
September 18, 2021, 2:02am
3
I recently posted a solution for this in another thread.
Ran into this same issue while using FocalLossFlat with cnn_learner.
Agree with @jhunter that the problem with ClassificationInterpretation arises in the call to learn.get_preds(with_loss=True).
Furthermore, I think the issue may trace to the definition of FocalLossFlat, which is a subclass of CrossEntropyLossFlat.
My inspection of the following code may hint at the root of the problem:
learn.loss_func
>>> Flattened loss of CrossEntropyLoss
learn.loss_func.func
>>> CrossEntropyLoss
W…
I tried with_loss=True, but it didn't work.
neuradai
(Walter Wiggins)
September 18, 2021, 8:44pm
5
In the linked post, I included some code for a new class FocalLossFlatten() that works with both ClassificationInterpretation() and learn.get_preds(with_loss=True).
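For anyone who can't follow the link, here is a minimal sketch of a focal loss that should cooperate with `get_preds(with_loss=True)`. It is not the code from the linked post. The class name `FocalLossSketch` is hypothetical. The key idea is that it reads the public `reduction` attribute on every call instead of caching it at construction time, so if fastai temporarily sets `reduction='none'`, the loss returns one value per item. The `activation` and `decodes` methods mirror what `CrossEntropyLossFlat` offers (softmax probabilities, argmax classes). This assumes fastai looks those methods up on the loss function when computing and decoding predictions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLossSketch(nn.Module):
    """Hypothetical flattened focal loss that honours `self.reduction` at call time."""

    y_int = True  # targets are integer class indices, like CrossEntropyLossFlat

    def __init__(self, gamma: float = 2.0, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction  # public attribute, may be overwritten later

    def forward(self, inp: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
        # Flatten to (N, C) logits and (N,) targets, e.g. tabular targets of shape (bs, 1)
        inp = inp.reshape(-1, inp.shape[-1])
        targ = targ.reshape(-1)
        ce = F.cross_entropy(inp, targ, reduction="none")  # per-item cross-entropy
        pt = torch.exp(-ce)                                 # probability of the true class
        fl = (1 - pt) ** self.gamma * ce                    # per-item focal loss
        if self.reduction == "mean":
            return fl.mean()
        if self.reduction == "sum":
            return fl.sum()
        return fl  # 'none' (or None): one loss value per item

    def activation(self, x: torch.Tensor) -> torch.Tensor:
        # Turn logits into class probabilities
        return F.softmax(x, dim=-1)

    def decodes(self, x: torch.Tensor) -> torch.Tensor:
        # Turn logits or probabilities into predicted class indices
        return x.argmax(dim=-1)
```

It would be passed in the same way as in the first post, e.g. `tabular_learner(dls_fl, metrics=[accuracy], loss_func=FocalLossSketch())`. With `gamma=0` it reduces to ordinary cross-entropy, which is a quick sanity check.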
Hi, thank you so much! I will check it and let you know.
tcapelle
(Thomas)
September 27, 2021, 11:37am
7
Would you mind trying FocalLossFlat(reduction=None), please? For me it fixes get_preds and the interpreter: