I am now much closer to making things work. Hopefully. I will write some updates here for those interested. Any feedback is obviously welcome!
So the first step is obviously to follow the notebook on text transfer learning right up to the Classifier section. Make sure to follow the latest version running on v1!
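For reference, the pieces the notebook's classifier section reuses are the fine-tuned encoder and the vocabulary. A minimal sketch of that hand-off, assuming `learn_lm` and `data_lm` are the language-model learner and DataBunch from the notebook (those names are mine, not the notebook's):

```python
# Hypothetical names: learn_lm / data_lm come from the language-model fine-tuning step.
learn_lm.save_encoder('fine_tuned_enc')  # encoder weights to reuse in the classifier
vocab = data_lm.vocab                    # vocabulary to pass to the classifier's TextList
```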
It is possible to use `text_classifier_learner`. Note that a model for multi-class (single-label) classification and a model for multi-label classification only differ in the number of outputs, because the softmax or sigmoid activation is applied by the loss function. `n_class` in `text_classifier_learner` will take the correct value (the number of labels, not the number of “classes”: this variable should probably be renamed).
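To make the “only the outputs differ” point concrete, here is a small self-contained PyTorch sketch (not fastai-specific): the logits have the same shape in both cases, and it is the loss function that applies the log-softmax or the sigmoid.

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 3)  # raw model outputs: batch of 4, 3 classes/labels

# Single-label (multi-class): CrossEntropyLoss applies log-softmax internally;
# targets are class indices.
ce_loss = nn.CrossEntropyLoss()(logits, torch.tensor([0, 2, 1, 2]))

# Multi-label: BCEWithLogitsLoss applies a sigmoid internally;
# targets are one float (0. or 1.) per label.
bce_targets = torch.tensor([[1., 0., 1.],
                            [0., 1., 0.],
                            [1., 1., 0.],
                            [0., 0., 1.]])
bce_loss = nn.BCEWithLogitsLoss()(logits, bce_targets)

print(ce_loss.item(), bce_loss.item())
```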
I update the loss function of the learner returned by `text_classifier_learner` by adding:

```python
learner.loss_func = BCEWithLogitsFlat()
```

This is probably not needed [not needed indeed!], but printing the type is not specific enough [it is: print `learner.loss_func.func`], so I will leave this in until I am sure it is not needed.
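If you just want to see what fastai picked (rather than overriding it), a check along these lines should be specific enough; this assumes the learner has already been built on the multi-label data shown below:

```python
# FlattenedLoss wraps the underlying torch loss; .func exposes it,
# so you can confirm it is BCEWithLogitsLoss for multi-label data.
print(learner.loss_func)       # e.g. FlattenedLoss of BCEWithLogitsLoss()
print(learner.loss_func.func)  # the wrapped torch loss instance
```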
Next we need a data loader. I am using the data block API for this:
```python
self.data = (TextList.from_csv(path, 'multi_label.csv',
                               cols='text',          # column with the raw text
                               vocab=self.vocab)     # vocab from the language model
             .random_split_by_pct(valid_pct=0.2)     # random 80/20 train/valid split
             .label_from_df(cols=[0, 1])             # label columns -> multi-label
             .databunch(bs=self.batch_size))
```
My data is in a csv file where (in this example) the labels are in the first two columns: the column name is the label name, and the cell value is a float (0 or 1 in my case). The column “text” contains the plain, unprocessed text. Don’t forget to pass your vocabulary from the language model here! I am using a random split (you can also use a boolean column for this).
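For clarity, here is a hypothetical example of what such a multi_label.csv could look like (the `label_a`/`label_b` column names are made up; only the “text” column name matters to the snippet above):

```python
import pandas as pd

# Made-up example data: one float column per label, plus the raw text.
df = pd.DataFrame({
    'label_a': [1.0, 0.0, 1.0],
    'label_b': [0.0, 1.0, 1.0],
    'text': ['first document ...', 'second document ...', 'a document with both labels ...'],
})
df.to_csv('multi_label.csv', index=False)
```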
With these ingredients I am able to at least instantiate the classifier:
```python
classifier = text_classifier_learner(self.data, drop_mult=0.5, metrics=[fbeta])
```
And this is as far as I’ve got. Now I need to check that training actually works, and I will need to look into the metric, and probably create a custom one (see the sketch below). [After the edits this code should be good enough!]
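On the metric: fastai’s `fbeta` is already the multi-label version (it applies a sigmoid and a threshold to the predictions), so a `partial` with explicit `beta`/`thresh` may be enough before writing a custom metric. A sketch continuing from the code above, with made-up hyperparameter values:

```python
from functools import partial
from fastai.text import *  # brings in fbeta, text_classifier_learner, etc.

# thresh and beta are arbitrary examples here, not tuned values.
f1 = partial(fbeta, beta=1, thresh=0.3)

classifier = text_classifier_learner(self.data, drop_mult=0.5, metrics=[f1])
classifier.fit_one_cycle(1, 1e-2)  # quick run to check that training happens
```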
Note: I have edited this post to include suggestions by @sgugger, thanks!