Multi-label text classification

(Rafael Valdivia) #1

Hey guys,

Do you know if fastai can handle multi-label text data, as we did using CSVs for the ‘Planet: Understanding the Amazon from Space’ Kaggle competition?

In other words, what I have is multiple chat sessions with labels indicating the topics that were discussed in them. I would like to arrange them in such a way that a ‘from_XXX’ method can load my data and attempt classification.

Let me know if anything is unclear.

Thank you.

(Yeshar Hadi) #2

I’m unsure why this hasn’t gotten any responses. I’d imagine someone had to find a solution for the toxic comment Kaggle competition, so I’m going to look there now.

(Yeshar Hadi) #3

Looks like this person did it. :)


Hello! I have the same question about multi-label text classification, but I would like to use fastai.text.

In the "Classifier tokens" section of Lesson 10, I replaced the number of classes:

# tok_trn, trn_labels = get_all(df_trn, 1)
tok_val, val_labels = get_all(df_val, 166)

and in the "Classifier" section I set:

c = int(trn_labels.shape[1])

I get an error:

RuntimeError: multi-target not supported at /opt/conda/conda-bld/pytorch_1512387374934/work/torch/lib/THCUNN/generic/

Then I tried to replace the loss function by adding:

learn.crit = F.binary_cross_entropy_with_logits

and I get another error:

RuntimeError: Expected object of type Variable[torch.cuda.FloatTensor] but found type Variable[torch.cuda.LongTensor] for argument #1 'other'

Any ideas?

How can I do multi-label text classification with fastai.text?

Did you solve your problem? I am having the same problem now.


No, I switched to another task.


@Haotian - two more changes fix the errors:

  • convert the labels to floats to remove the last error:

    trn_labels = np.squeeze(np.load(CLAS_PATH/'tmp'/'trn_labels.npy')).astype(float)
    val_labels = np.squeeze(np.load(CLAS_PATH/'tmp'/'val_labels.npy')).astype(float)

  • change the accuracy metric from accuracy to accuracy_thresh to account for the different format of the labels:

    learn.metrics = [accuracy_thresh(0.5)]
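
To illustrate the two changes above outside of fastai, here is a minimal NumPy sketch. It is not fastai's implementation: the helper name thresholded_accuracy and the toy arrays are made up for illustration, and it assumes the labels are stored as a multi-hot matrix (one row per sample, one column per class).

```python
import numpy as np

def sigmoid(x):
    # Map raw model outputs (logits) to per-class probabilities.
    return 1.0 / (1.0 + np.exp(-x))

def thresholded_accuracy(logits, targets, thresh=0.5):
    # Hypothetical helper: fraction of (sample, class) entries where the
    # thresholded probability agrees with the 0/1 target.
    preds = sigmoid(logits) > thresh
    return float((preds == targets.astype(bool)).mean())

# Toy multi-hot labels: 3 samples, 4 classes, stored as integers.
int_labels = np.array([[1, 0, 1, 0],
                       [0, 1, 0, 0],
                       [1, 1, 0, 1]])

# The cast that removes the Float/Long mismatch: targets become floats.
float_labels = int_labels.astype(np.float32)

# Made-up logits standing in for the classifier output.
logits = np.array([[ 2.0, -1.0,  0.5, -3.0],
                   [-2.0,  1.5,  1.0, -0.5],
                   [ 3.0, -0.2, -1.0,  2.0]])

acc = thresholded_accuracy(logits, float_labels)
print(acc)  # 10 of the 12 entries match
```

Plain accuracy compares one predicted class index per sample with one target index, which no longer fits once every sample carries a vector of 0/1 labels. That is why a per-entry, thresholded comparison is used here.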

For further improvement, the model created with get_rnn_classifier can be adjusted, or the loss function F.binary_cross_entropy_with_logits can be swapped for another one.
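
For anyone wondering what this loss does with multi-label targets, below is a rough NumPy sketch of binary cross-entropy on logits, averaged over all entries. It is written from the standard formula (sigmoid followed by binary cross-entropy), not copied from PyTorch, and the function name bce_with_logits is only illustrative.

```python
import numpy as np

def bce_with_logits(logits, targets):
    # Mean binary cross-entropy computed directly from logits, using the
    # numerically stable form: max(x, 0) - x * t + log(1 + exp(-|x|)).
    x = np.asarray(logits, dtype=np.float64)
    t = np.asarray(targets, dtype=np.float64)  # targets must be floats (0.0 / 1.0)
    per_entry = np.maximum(x, 0) - x * t + np.log1p(np.exp(-np.abs(x)))
    return float(per_entry.mean())

# With all logits at 0 the model predicts 0.5 for every class,
# so every entry contributes log(2) to the loss.
logits = np.zeros((2, 2))
targets = np.array([[1.0, 0.0],
                    [0.0, 1.0]])
loss = bce_with_logits(logits, targets)
print(loss)  # log(2), about 0.693

# Confident, correct predictions drive the loss toward 0.
confident = bce_with_logits([[10.0, -10.0]], [[1.0, 0.0]])
print(confident)
```

Each class is scored independently as its own yes/no decision, so a sample can have several positive labels at once. The targets are multiplied with the float logits, which is why they need to be floats as well.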