Multilabel Classification with ULMFiT


(Brian Muhia) #21

Alright, I’ll update this thread with my results later. I’m currently working on a multilabel classification problem myself, and will write a new notebook in fastai v1 once I’m done with the current phase of fine-tuning and testing.


(Jan) #22

I’m trying to do multi-label classification with the pretrained language model approach. These are the important parts of my code:

from fastai.text import *   # fastai v1 text API

# Build the language-model data from the train/valid CSVs
train_ds = TextDataset.from_csv(DATA_PATH, name='train',
                                classes=classes, n_labels=len(classes))
valid_ds = TextDataset.from_csv(DATA_PATH, name='valid',
                                classes=classes, n_labels=len(classes))
data_lm = lm_data([train_ds, valid_ds], DATA_PATH)

# Rebuild the datasets with the LM vocab so token ids match the encoder
train_ds = TextDataset.from_csv(DATA_PATH, name='train', vocab=data_lm.train_ds.vocab,
                                classes=classes, n_labels=len(classes))
valid_ds = TextDataset.from_csv(DATA_PATH, name='valid', vocab=data_lm.train_ds.vocab,
                                classes=classes, n_labels=len(classes))

data_clas = classifier_data([train_ds, valid_ds], DATA_PATH)

# Fine-tune the WikiText-103 pretrained LM, then save its encoder
learn = RNNLearner.language_model(data_lm, pretrained_fnames=['lstm_wt103', 'itos_wt103'], drop_mult=0.5)
learn.fit_one_cycle(1, 1e-2)
learn.save_encoder('ft_enc')

# Train the classifier on top of the fine-tuned encoder
learn = RNNLearner.classifier(data_clas, drop_mult=0.5)
learn.load_encoder('ft_enc')
learn.fit_one_cycle(1, 1e-2)

I have a couple of questions regarding the classification part (language modeling works fine):

  1. Which loss function should we choose, and where do we specify it?
    I have tried setting learn.loss_fn = F.binary_cross_entropy_with_logits, but this function does not accept the data type of the targets. I think it wants float32, and currently they are ints. Should I change this somewhere?

  2. How do I get predictions once the model is trained?
    If I call learn.get_preds(), I get the following error:

    ~/fastai/fastai/basic_train.py in loss_batch(model, xb, yb, loss_fn, opt, cb_handler, metrics)
    22 out = model(*xb)
    23 out = cb_handler.on_loss_begin(out)
    ---> 24 if not loss_fn: return out.detach(),yb[0].detach()
    25 loss = loss_fn(out, *yb)
    26 mets = [f(out,*yb).detach().cpu() for f in metrics] if metrics is not None else []

    AttributeError: 'tuple' object has no attribute 'detach'
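For question 1, here is a minimal plain-PyTorch sketch (not fastai-specific; the toy shapes and the helper name multilabel_bce_loss are hypothetical). F.binary_cross_entropy_with_logits expects a floating-point target with the same shape as the logits, so integer multi-hot targets need a .float() cast. Whether assigning such a helper to learn.loss_fn, as attempted above, is enough in this fastai version is untested here.

```python
import torch
import torch.nn.functional as F

def multilabel_bce_loss(logits, targets):
    """Hypothetical helper: BCE-with-logits that casts integer multi-hot
    targets to float before computing the loss."""
    return F.binary_cross_entropy_with_logits(logits, targets.float())

# Toy batch: 4 examples, 3 classes (made-up sizes for illustration).
logits = torch.randn(4, 3)
targets_int = torch.tensor([[1, 0, 1],
                            [0, 0, 1],
                            [1, 1, 0],
                            [0, 0, 0]])   # int64, like the int targets in question 1

targets = targets_int.float()             # float32, same shape as the logits
loss = F.binary_cross_entropy_with_logits(logits, targets)

# The helper gives the same value while accepting the integer targets directly.
print(loss.item(), multilabel_bce_loss(logits, targets_int).item())
```

Each class gets its own independent sigmoid/BCE term, which is what makes this loss suitable when an example can carry several labels at once.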

Maybe I’m missing something crucial, in which case these questions may not be appropriate, but I would love for someone to point out what I’m doing wrong. I’ve been looking through the docs and examples but still couldn’t solve it.

Thanks in advance.
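For question 2, the traceback shows that the model returns a tuple, and loss_batch calls .detach() on it directly when no loss function is passed. A minimal sketch of turning such an output into multi-label predictions, assuming (not confirmed here) that the first element of the tuple holds the per-class logits of shape (batch_size, n_classes): apply a sigmoid to each class and threshold it. The helper name multilabel_predict and the 0.5 threshold are hypothetical choices.

```python
import torch

def multilabel_predict(model_output, threshold=0.5):
    """Hypothetical helper: convert raw model output into multi-label predictions.

    Assumes that if the output is a tuple, its first element is the logits
    tensor of shape (batch_size, n_classes).
    """
    logits = model_output[0] if isinstance(model_output, tuple) else model_output
    probs = torch.sigmoid(logits.detach())       # independent per-class probabilities
    return probs, (probs > threshold).long()     # multi-hot predicted labels

# Usage with a fake tuple output standing in for the classifier's return value.
fake_logits = torch.tensor([[2.0, -1.0, 0.3],
                            [-3.0, 0.0, 4.0]])
probs, preds = multilabel_predict((fake_logits, [], []))
print(preds.tolist())  # → [[1, 0, 1], [0, 0, 1]]
```

Sigmoid is used instead of softmax because the classes are not mutually exclusive; the threshold can be tuned per task.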


(William Collins) #23

I’ve also been struggling to add a multi-label classifier (with around 2000 classes) on top of the language model. I’ve seen widely varying recommendations, but none of them seem to work.

Here are my main questions:

  • Does the target value need to be one-hot encoded, or do I just pass the index of the correct target? Different losses seem to require different target representations. My data is not in CSV form and can’t be loaded via the prepackaged classes, so I need to understand what formats are expected downstream.

  • What “criterion” do I use?
    If I leave it as is (I think “cross_entropy” is used inside RNN_Learner), I get a multi-target error:
    RuntimeError: multi-target not supported at /pytorch/torch/lib/THCUNN/generic/ClassNLLCriterion.cu:16
    If I change it to “binary_cross_entropy”, I get an error saying my input and target have different numbers of elements:
    ValueError: Target and input must have the same number of elements. target nelement (48) != input nelement (100368)
    This seems to imply that one-hot targets are needed, since 100368 = my_batch_size x num_classes.

  • The large input number above makes me think that the PoolingLinearClassifier used in “get_rnn_classifier” is flattening tensors, but I can’t find documentation that clarifies this. Do I need to swap it out for something else?

  • Are there any complete, working examples of using a multi-target classifier on top of the language model? All I’ve been able to find are binary classifiers (IMDB sentiment, etc.).
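A minimal plain-PyTorch sketch of the two target formats (the helper to_multi_hot and the toy sizes are hypothetical). F.cross_entropy takes one class index per example, which is why it rejects multi-target input, while the binary cross-entropy losses take a float target with the same shape as the input. The counts in the ValueError fit this reading: assuming 48 is the batch size, 100368 / 48 = 2091 classes. Because the check compares total element counts, the mismatch does not by itself show that the classifier flattens anything.

```python
import torch
import torch.nn.functional as F

def to_multi_hot(label_lists, n_classes):
    """Hypothetical helper: turn per-example lists of class indices into a
    float multi-hot matrix of shape (batch_size, n_classes)."""
    targets = torch.zeros(len(label_lists), n_classes)
    for row, labels in enumerate(label_lists):
        targets[row, torch.tensor(labels, dtype=torch.long)] = 1.0
    return targets

n_classes = 5
logits = torch.randn(3, n_classes)              # (batch_size, n_classes) model output

# Single-label: cross_entropy wants one class index per example, shape (batch_size,).
single_targets = torch.tensor([4, 0, 2])
ce = F.cross_entropy(logits, single_targets)

# Multi-label: BCE-with-logits wants a float target shaped like the logits.
multi_targets = to_multi_hot([[0, 3], [1], [2, 3, 4]], n_classes)
bce = F.binary_cross_entropy_with_logits(logits, multi_targets)

# Index targets hold batch_size elements; logits hold batch_size * n_classes.
print(single_targets.numel(), logits.numel())   # → 3 15
```

For a model that outputs raw logits, binary_cross_entropy_with_logits is generally preferable to binary_cross_entropy, which expects inputs that are already probabilities in [0, 1].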

Thanks!