CrossEntropyFlat not defined? FastAi-2. Any idea?

Hello guys,
I tried to classify Arabic characters and achieved approximately 95% accuracy.
As this is a multiclass classification problem, I used cross-entropy as my loss function.

   learn = cnn_learner(dls, resnet18, pretrained=False,
                    loss_func=F.cross_entropy, metrics=accuracy, model_dir="/tmp/model/")

Now, when I wanted to look at the results, I saw that the prediction contains all classes:

[Image: arabic_char]

I figured out that the problem is the wrong loss function. Instead of cross_entropy, I should use CrossEntropyFlat() (which fastai picks by default).

Questions:

  1. Could someone please explain why I cannot use cross entropy?
  2. How do I set CrossEntropyFlat manually? I tried:
  • learn.loss_func = CrossEntropyFlat()
  • learn.loss_func = torch.CrossEntropyFlat()
  • learn.loss_func = nn.CrossEntropyFlat()
    -> All of these throw errors.

Best regards
Anel

If it’s multi-class, you should use BCEWithLogitsLossFlat and accuracy_multi, as the two you’re describing only work if one label is present, not multiple. An example is the Planet dataset.

It is multi-class but not multi-label: each target has exactly one valid class. In total I’ve got 28 classes. The dataset is basically MNIST with Arabic characters, so each image contains one character.
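To make the multi-label vs. single-label distinction concrete, here is a minimal pure-Python sketch (no fastai involved) of how each setup reads a model's raw outputs. The helper names and the 0.5 threshold are assumptions for illustration: the multi-label (BCE-style) path scores every class independently with a sigmoid, while the single-label (cross-entropy-style) path picks exactly one class with an argmax.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def decode_multi_label(logits, threshold=0.5):
    """Multi-label: every class whose sigmoid score exceeds the threshold is predicted."""
    return [i for i, z in enumerate(logits) if sigmoid(z) > threshold]


def decode_single_label(logits):
    """Single-label (multi-class): exactly one class, the highest-scoring one."""
    return max(range(len(logits)), key=lambda i: logits[i])


logits = [2.0, -1.0, 0.5, 3.0]  # raw scores for four hypothetical classes
print(decode_multi_label(logits))   # [0, 2, 3]
print(decode_single_label(logits))  # 3
```

For a dataset where every image holds exactly one of 28 characters, only the single-label path makes sense.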


Got it, I follow :slight_smile: It’s CrossEntropyLossFlat(), so you may be calling the wrong thing. Pass that as loss_func to cnn_learner and you’re golden (though it should do this by default).


Yes, you are right. FastAI picks CrossEntropyFlat by default; however, I’d like to set it manually.
I really appreciate you taking the time to help, and I don’t mean to be rude, but please read the whole question first. I tried to formulate it in as much detail as possible:

What I would love to know is:

  • Why can’t I use CrossEntropyLoss, but MUST use CrossEntropyFlat?
    (CrossEntropyLoss worked for the pet breeds problem)
  • How do I set CrossEntropyFlat() manually?
    I tried:
    learn.loss_func = CrossEntropyFlat()
    learn.loss_func = torch.CrossEntropyFlat()
    learn.loss_func = nn.CrossEntropyFlat()
    -> All of these throw errors.

@muellerzr could I provide you with any additional information? :grinning:

So I found out why I could not set the loss function.
You’ve got to make sure that at the top of your notebook you run the following cell:

from fastai.vision import *

whereas running

from fastai2.vision.all import *

will not work.

@sgugger should I create an issue on GitHub, or is this known?

.all is the proper way to import for now. If you check the init file in fastai.vision (which is what runs when you do this), you’ll notice it’s blank, so you don’t import anything at all.
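The mechanism here is plain Python rather than anything fastai-specific: `from package import *` only pulls in names defined in (or already imported into) the package's `__init__.py`, so a blank `__init__.py` exposes nothing, while a submodule such as `all.py` can re-export everything. A minimal sketch with a throwaway package; the name `fakevision` and its contents are hypothetical stand-ins, not fastai's real files:

```python
import importlib
import pathlib
import sys
import tempfile

# Build a throwaway package "fakevision" (hypothetical) on disk:
#   fakevision/__init__.py -> empty, like the blank init described above
#   fakevision/all.py      -> defines the names users actually want
root = pathlib.Path(tempfile.mkdtemp())
pkg = root / "fakevision"
pkg.mkdir()
(pkg / "__init__.py").write_text("")
(pkg / "all.py").write_text("def CrossEntropyLossFlat():\n    return 'loss'\n")
sys.path.insert(0, str(root))
importlib.invalidate_caches()  # recommended after creating modules at runtime

# Star-import from the package itself: the empty __init__.py exports nothing.
ns_package = {}
exec("from fakevision import *", ns_package)

# Star-import from the .all submodule: its public names come through.
ns_all = {}
exec("from fakevision.all import *", ns_all)

print("CrossEntropyLossFlat" in ns_package)  # False
print("CrossEntropyLossFlat" in ns_all)      # True
```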

How come CrossEntropyFlat() is only defined if I run:
from fastai.vision import *

I restarted the kernel and tried it multiple times.

I’m not sure. I can get this to work just fine on my end:

from fastai2.vision.all import *

CrossEntropyLossFlat()

Are you sure you’re doing CrossEntropyLossFlat and not just CrossEntropyFlat?

Also, our CrossEntropyLossFlat simply calls nn.CrossEntropyLoss, with an argmax and a softmax applied depending on the scenario.
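A rough sketch of what such a "flat" wrapper around cross-entropy does, written in pure Python so it runs without PyTorch. This illustrates the idea and is not fastai's source: the class name `FlatCrossEntropySketch` is hypothetical, and it assumes outputs shaped (N, C) or nested (N, ..., C) with class scores on the last axis. It flattens everything except the class axis, averages the standard cross-entropy, and exposes a softmax `activation` and an argmax `decodes` like the ones mentioned above.

```python
import math


def softmax(scores):
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def flatten_rows(x):
    """Flatten nested lists shaped (N, ..., C) into a flat list of C-length rows."""
    if x and isinstance(x[0][0], list):
        return [row for sub in x for row in flatten_rows(sub)]
    return x


def flatten_targets(t):
    """Flatten nested integer targets shaped (N, ...) into a flat list."""
    if t and isinstance(t[0], list):
        return [v for sub in t for v in flatten_targets(sub)]
    return t


class FlatCrossEntropySketch:
    """Hypothetical stand-in for a flattened cross-entropy loss; not fastai's code."""

    def __call__(self, outputs, targets):
        rows = flatten_rows(outputs)
        labels = flatten_targets(targets)
        # Mean negative log-probability of the correct class over every row.
        losses = [-math.log(softmax(r)[y]) for r, y in zip(rows, labels)]
        return sum(losses) / len(losses)

    def activation(self, row):
        # Raw scores -> probabilities that sum to 1.
        return softmax(row)

    def decodes(self, row):
        # Keep only the index of the highest-scoring class.
        return max(range(len(row)), key=lambda i: row[i])


loss = FlatCrossEntropySketch()
outputs = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]  # two images, three classes
targets = [0, 2]
print(loss(outputs, targets))
print(loss.decodes([0.1, 2.5, -1.0]))  # 1
```

The loss value itself is ordinary cross-entropy; the flattening and the activation/decodes pair are what the wrapper adds on top.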


If I use CrossEntropyLossFlat(), it works without running:

from fastai.vision import *

That is fine for me. I still wonder why CrossEntropyFlat works only with the import mentioned above.
However, I will keep on using CrossEntropyLossFlat().

Could you please clear up one last point of confusion: why, when I use CrossEntropyLoss(), does the interpretation look like this:

[Image: arabic_char]

Whereas when I use CrossEntropyLossFlat(), it looks correct:

Is this nn.CrossEntropyLoss?

Edit: if so, it’s most likely because CrossEntropyLossFlat contains a decodes that takes the argmax over the (softmaxed) outputs, giving you a single class instead of all of them, as in your screenshot.
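A minimal sketch of why the two interpretations look different, assuming (as described above) that the key difference is whether an argmax decodes step is applied. Without it, each "prediction" stays a vector with one score per class; with it, each image collapses to one class index, which can then be mapped to a label. Plain Python; the `vocab` names are hypothetical placeholders, and 28 classes match the Arabic character dataset in this thread.

```python
NUM_CLASSES = 28  # matches the 28 Arabic characters in this thread
vocab = [f"char_{i}" for i in range(NUM_CLASSES)]  # hypothetical class names


def decodes(batch_scores):
    # Argmax along the class axis: one class index per image.
    return [max(range(len(row)), key=lambda i: row[i]) for row in batch_scores]


# Raw model outputs for two images: 28 scores each.
batch = [[0.0] * NUM_CLASSES for _ in range(2)]
batch[0][5] = 4.0   # image 0 scores highest on class 5
batch[1][17] = 2.5  # image 1 scores highest on class 17

print(len(batch[0]))                       # 28 scores per image without decodes
print([vocab[i] for i in decodes(batch)])  # ['char_5', 'char_17']
```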


Yes, it is:
loss_func = nn.CrossEntropyLoss()

Okay, cool, your answer makes sense.
So that means whenever I want to classify multiple classes (single label per image), I use CrossEntropyLossFlat to rank one class against the others.
I should never use CrossEntropyLoss() if I have more than two classes.

This also explains why it worked in the 3 vs 7 MNIST notebook: in that case we only had two classes.