Fastai RNNLearner.classifier: wrong size of the classifier head

Yesterday, I made the following PR on GitHub:

As Jeremy pointed out, there was not enough information in my PR, so I decided to write this post to describe the issue more precisely. I created a notebook that reproduces the issue; you can find it at this link:

The issue is the following. Imagine you're dealing with a multilabel/multiclass text classification problem, and your dataset has one column for each label. When creating a TextClasDataBunch, you can specify the number of labels, but when creating the RNNLearner.classifier, you get an error: the classifier head is built for only two labels.

You can see where the error comes from in the source code (in text/): the size of the classifier's output layer is n_class, where
n_class = (len(ds.classes) if (not is_listy(lbl) or (len(lbl) == 1))
           else len(lbl))
with ds = data.train_ds (data being your DataBunch) and lbl = ds.labels[0].

As the example in my notebook shows, the labels are stored in an array, not a list, so is_listy(lbl) is False and
n_class = len(ds.classes)

And ds.classes has only two elements. This is because in text/, when a TextDataSet is created and classes is either not specified or not found in the files:

self.classes = np.unique(self.labels)

Since our labels only take the values 0 and 1, this array has size 2, no matter how many label columns there are.
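To make the two steps above concrete, here is a small standalone sketch, without fastai itself. It assumes that is_listy behaves like fastai v1's helper (only lists and tuples count as "listy"); n_class_current is a hypothetical name for the quoted n_class expression, and the label matrix is made up.

```python
import numpy as np

def is_listy(x):
    # Assumption: mirrors fastai v1's helper, where only lists and tuples are "listy"
    return isinstance(x, (list, tuple))

def n_class_current(labels, classes):
    # Same expression as the quoted n_class code, with lbl = labels[0]
    lbl = labels[0]
    return (len(classes) if (not is_listy(lbl) or (len(lbl) == 1))
            else len(lbl))

# Hypothetical multilabel data: 4 texts, 6 label columns (one 0/1 column per label)
labels = np.array([[0, 1, 0, 0, 1, 0],
                   [1, 0, 0, 0, 0, 0],
                   [0, 0, 1, 1, 0, 0],
                   [0, 0, 0, 0, 0, 1]])

classes = np.unique(labels)                        # the default when classes is not given
print(classes)                                     # [0 1]
print(n_class_current(labels, classes))            # 2: a NumPy row is not listy
print(n_class_current(labels.tolist(), classes))   # 6: the same row as a list is listy
```

The only difference between the last two calls is the container type of each row, which is why the head ends up with 2 outputs instead of 6.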

There are several possible fixes:

  • Change how classes is initialized in TextDataSet when it is not specified, so that it has the right size when there is one column per label.
  • Accept an array-typed lbl when computing n_class.
  • Add an n_labels attribute to either the TextDataSet class or the TextClasDataBunch class and use it in the classifier.
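For comparison, the first option could look roughly like this. It is only a sketch: default_classes is a hypothetical helper standing in for the `self.classes = np.unique(self.labels)` line, and column indices are used as placeholder class names.

```python
import numpy as np

def default_classes(labels):
    # Hypothetical replacement for `self.classes = np.unique(self.labels)`
    labels = np.asarray(labels)
    if labels.ndim == 2 and labels.shape[1] > 1:
        # One column per label: one class per column (index used as a placeholder name)
        return np.arange(labels.shape[1])
    # One label per text (1-D, or a single column): keep the distinct values as before
    return np.unique(labels)

multi = np.array([[1, 0, 0, 1], [0, 1, 0, 0]])   # hypothetical: 4 label columns
single = np.array([0, 2, 1, 2])                  # hypothetical: 3 distinct classes
print(len(default_classes(multi)))    # 4
print(len(default_classes(single)))   # 3
```

With classes built this way, the existing n_class expression falls into its len(ds.classes) branch and gets 4 for the multilabel case, without touching is_listy.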

The last fix seemed the most natural to me, and it's what I submitted in my PR.
I would love to hear your thoughts about this issue and about the best fix for it.


My preference is always to infer the number of labels from the data instead of asking the user to pass another attribute. Maybe the problem lies in the is_listy() test, and it needs something more general like isinstance(..., Iterable)? A NumPy array would pass that test.


We actually just closed a PR suggesting that; it would break our code. If we want to test somewhere whether something is iterable, we should simply use isinstance(o, Iterable) or similar.

I didn’t suggest to change the code inside is_listy, merely to not use is_listy in this instance :wink:
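A rough sketch of what "not using is_listy here" might look like, with isinstance(lbl, Iterable) swapped into the quoted n_class expression. n_class_iterable is a hypothetical name and the data is made up; this is not the code from either PR.

```python
from collections.abc import Iterable
import numpy as np

def n_class_iterable(labels, classes):
    # Quoted n_class expression with isinstance(lbl, Iterable) in place of is_listy(lbl)
    lbl = labels[0]
    return (len(classes) if (not isinstance(lbl, Iterable) or (len(lbl) == 1))
            else len(lbl))

multi = np.array([[0, 1, 0, 1], [1, 0, 0, 0]])   # hypothetical: 4 label columns
single = np.array([0, 1, 1])                     # hypothetical: one integer label per text
print(n_class_iterable(multi, np.unique(multi)))     # 4: a NumPy row is Iterable
print(n_class_iterable(single, np.unique(single)))   # 2: a NumPy integer scalar is not
```

One caveat with this test: strings are Iterable too, so if a single-label dataset ever stored raw string labels such as "pos", this expression would return len("pos") == 3 instead of the number of classes.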


I have opened a consolidated PR for this, containing both:

  1. a test showing what worked before and what didn't
  2. the fix

The fix is debatable (a lot of things could be Iterable), but I think this is the cleanest, most minimal change, and we can revisit it if it causes a problem in the future.

<Mentioning, as he already had an open PR for a test; I just made it generic, testing both the single-label and multilabel cases>