NLP: Multi Target Extraction

Hi everyone, I’ve been out of the ML scene for a few years now and am trying to get back in, which is harder than I thought with all the changes between fastai v1 and v2. I flew through the new course and checked out the new notebooks, but I can’t quite find what I need.
My current project is about extracting data from text publications. The texts come in various formats and are therefore hard to handle with regex etc., so I thought it might be possible to extract this content with ML.
So far I've managed to choose a single text block as a label and train a simple classifier, which quickly reached 98% accuracy, already better than the regex solution. But that was the easy case: a single target with only 4 possible outcomes. Overall I would have around 20 labels/data points, some of which are not finite (they can't be drawn from a fixed set of values). Is it possible to address this kind of task with ML? If so, how would I approach it: one model per label, or one model for all labels? And how would extraction work for non-finite labels like names?
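A minimal sketch of the "one model for all labels" option, assuming the finite labels can be treated as independent yes/no decisions (multi-label classification): a shared encoder feeds a single linear head with one logit per label, and each logit goes through a sigmoid and a threshold. The hidden size, the label names, and the random stand-in for the encoder output are all hypothetical, and this does not cover non-finite labels such as names.

```python
import torch
from torch import nn

# Hypothetical label names; the post has 19 labels, 4 are used here for brevity.
LABELS = ["label_a", "label_b", "label_c", "label_d"]

class MultiLabelHead(nn.Module):
    """One linear layer producing one logit per label from a pooled encoder vector."""
    def __init__(self, hidden_size: int, n_labels: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, n_labels)

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        # pooled: (batch, hidden_size) -> logits: (batch, n_labels)
        return self.proj(pooled)

def decode(logits: torch.Tensor, labels, thresh: float = 0.5):
    """Turn logits into a list of predicted label names per item.
    Each label is an independent decision: sigmoid, then threshold."""
    probs = torch.sigmoid(logits)
    return [[labels[j] for j in range(len(labels)) if probs[i, j] > thresh]
            for i in range(probs.shape[0])]

torch.manual_seed(0)
head = MultiLabelHead(hidden_size=8, n_labels=len(LABELS))
pooled = torch.randn(2, 8)   # hypothetical stand-in for the encoder output of 2 documents
logits = head(pooled)
print(logits.shape)          # torch.Size([2, 4])
print(decode(logits, LABELS))
```

Because every label gets its own sigmoid, an item can end up with zero, one, or several labels; the 0.5 threshold matches the default of fastai's `accuracy_multi`.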

tl;dr
Can I classify text into multiple targets? If so, how, and how do I handle non-finite labels?
Thanks for the help!

Okay, so I figured out that I need to use MultiCategoryBlock instead of CategoryBlock for this, but now my input and output batch sizes don’t match anymore. I’m working with bs=4 and len(labels)=19, which results in an expected output size of 76. From other forum posts, I assume I need a linear decoder in my model to adjust the output, but I don’t quite see how to add one yet.
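On the 76: MultiCategoryBlock turns each item's list of labels into a multi-hot vector of length len(vocab), so a batch of targets is 4 × 19, which already matches a model that outputs 19 logits per item. The 76 = 4 × 19 only appears once that tensor is flattened. A rough sketch of the encoding in plain PyTorch (the label names and the helper `one_hot_labels` are hypothetical; fastai does this with `MultiCategorize` followed by `OneHotEncode`):

```python
import torch

# Hypothetical vocabulary of 19 label names (the post uses len(labels) == 19).
vocab = [f"label_{i}" for i in range(19)]

def one_hot_labels(label_lists, vocab):
    """Encode each item's list of label names as a multi-hot float vector."""
    index = {name: i for i, name in enumerate(vocab)}
    out = torch.zeros(len(label_lists), len(vocab))
    for row, names in enumerate(label_lists):
        for name in names:
            out[row, index[name]] = 1.0
    return out

# A batch of bs=4 items, each with a different number of active labels.
batch = [["label_0", "label_3"], ["label_5"], [], ["label_1", "label_2", "label_18"]]
targets = one_hot_labels(batch, vocab)
print(targets.shape)           # torch.Size([4, 19])
print(targets.view(-1).shape)  # torch.Size([76]), i.e. 4 * 19 values once flattened
```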
This is my current setup:

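from fastai.text.all import *
from blurr.data.all import *       # older blurr API used in this post (HF_TextBlock, BLURR)
from blurr.modeling.all import *   # HF_BaseModelWrapper, HF_BaseModelCallback
from transformers import AutoModelForSequenceClassification

model_name = "facebook/mbart-large-50"
model_cls = AutoModelForSequenceClassification
hf_tok_kwargs = {'src_lang': 'de_DE', 'tgt_lang': 'de_DE'}

# num_labels sets the size of the classification head: one logit per label
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(model_name,
                                                                  model_cls=model_cls,
                                                                  tokenizer_kwargs=hf_tok_kwargs,
                                                                  config_kwargs={'num_labels': len(labels)})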

DataBlock:

blocks = (HF_TextBlock(hf_arch, hf_config, hf_tokenizer, hf_model, max_length=256), MultiCategoryBlock)
dblock = DataBlock(blocks=blocks, get_x=ColReader('publication_content'), get_y=get_y, splitter=ColSplitter())

dls = dblock.dataloaders(df, bs=4)
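The post doesn't show `get_y`. For MultiCategoryBlock it has to return a list of label names per row. A hypothetical version, assuming each row stores its labels as a ";"-separated string in a column called `labels` (both the column name and the delimiter are assumptions):

```python
# Hypothetical: assumes a "labels" column holding ";"-separated label names.
def get_y(row):
    """Return the list of label names for one DataFrame row (MultiCategoryBlock expects a list)."""
    raw = row["labels"]
    if not isinstance(raw, str) or not raw.strip():
        return []  # missing or empty cell: no labels for this publication
    return [part.strip() for part in raw.split(";") if part.strip()]

print(get_y({"labels": "label_a; label_c"}))  # ['label_a', 'label_c']
print(get_y({"labels": ""}))                  # []
```

fastai's `ColReader('labels', label_delim=';')` does a similar split, though without stripping whitespace around each name.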

Shapes:

b = dls.one_batch()
len(b),b[0]["input_ids"].shape, b[1].shape

(2, torch.Size([4, 256]), torch.Size([4, 19]))

BART splitter, copied from blurr:

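def mbart_splitter(m):
    "Split the model into parameter groups for discriminative learning rates and freezing."
    # unwrap blurr's HF_BaseModelWrapper to get the underlying Hugging Face model
    model = m.hf_model if hasattr(m, 'hf_model') else m

    # group 1: all embedding layers of encoder and decoder
    embeds_modules = [
        model.model.encoder.embed_positions,
        model.model.encoder.embed_tokens,
        model.model.decoder.embed_positions,
        model.model.decoder.embed_tokens,
    ]
    embeds = nn.Sequential(*embeds_modules)

    # groups 2-4: encoder, decoder, classification head (the last group stays trainable after freeze)
    groups = L(embeds, model.model.encoder, model.model.decoder, model.classification_head)
    # turn each group into its parameter list and drop groups without parameters
    return groups.map(params).filter(lambda el: len(el) > 0)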
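A splitter returns a list of parameter groups; fastai uses them for discriminative learning rates, and `learn.freeze()` leaves only the last group (here `classification_head`) trainable. A toy sketch of that mechanism, using a hypothetical model whose module names only mirror the structure above (the real `freeze` also keeps BatchNorm layers trainable by default, which this toy model doesn't have):

```python
import torch
from torch import nn

# Hypothetical toy stand-in for an encoder-decoder classifier.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 16)
        self.encoder = nn.Linear(16, 16)
        self.decoder = nn.Linear(16, 16)
        self.classification_head = nn.Linear(16, 19)

def toy_splitter(m):
    """Return one list of parameters per group, earliest layers first."""
    groups = [m.embed, m.encoder, m.decoder, m.classification_head]
    return [list(g.parameters()) for g in groups]

def freeze_all_but_last(param_groups):
    """Roughly what learn.freeze() does: only the last group stays trainable."""
    for group in param_groups[:-1]:
        for p in group:
            p.requires_grad_(False)
    for p in param_groups[-1]:
        p.requires_grad_(True)

model = ToyModel()
groups = toy_splitter(model)
freeze_all_but_last(groups)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(len(groups), trainable)  # 4 323  (classification_head: 16*19 weights + 19 biases)
```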

Metrics

precision = PrecisionMulti(average='macro')
recall = RecallMulti(average='macro')
f1 = F1ScoreMulti(average='macro')

learn_metrics = [accuracy_multi, precision, recall, f1]
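For reference, `accuracy_multi` scores every (item, label) pair separately: it applies a sigmoid, thresholds at 0.5 by default, and averages the matches. A sketch re-implementing that logic (the function name is hypothetical). When most labels are 0 for most items, this element-wise accuracy can look high even if rare labels are missed, which is where the macro-averaged precision, recall, and F1 help.

```python
import torch

def accuracy_multi_sketch(logits, targets, thresh=0.5, sigmoid=True):
    """Element-wise multi-label accuracy: every (item, label) pair counts once."""
    probs = logits.sigmoid() if sigmoid else logits
    preds = probs > thresh
    return (preds == targets.bool()).float().mean()

logits = torch.tensor([[3.0, -2.0, 0.5], [-1.0, 4.0, -3.0]])
targets = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
print(accuracy_multi_sketch(logits, targets))  # tensor(0.6667)
```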
learn_cbs = [HF_BaseModelCallback]

Learner

model = HF_BaseModelWrapper(hf_model)
learn = Learner(dls, model, opt_func=Adam, loss_func=CrossEntropyLossFlat(), metrics=learn_metrics, cbs=learn_cbs, splitter=mbart_splitter)
learn.freeze()

Now when I check the shapes:

xb, yb = dls.one_batch()
out = model(xb)
xb['input_ids'].shape, out[0].shape

(torch.Size([4, 256]), torch.Size([4, 19]))

Interestingly, when I run learn.lr_find() I get this error:

Expected input batch_size (4) to match target batch_size (76).

Shouldn’t I get an expected target batch_size of 19 here?

Am I making some grave mistake here, or is it just the missing linear decoder?

For multi-label classification you have to use BCEWithLogitsLoss rather than CrossEntropyLoss (in fastai, BCEWithLogitsLossFlat). CrossEntropy expects a single label per input item, so the multi-hot targets are flattened and treated as if they were a batch of 76 items.
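A small sketch reproducing the mismatch in plain PyTorch, assuming random logits and random multi-hot targets with the shapes from the post. Flattening the (4, 19) target the way CrossEntropyLossFlat does yields 76 class indices for only 4 rows of logits, while BCEWithLogitsLoss compares the two (4, 19) tensors element-wise. The exact error text varies between PyTorch versions, so the code only checks that an error is raised.

```python
import torch
from torch import nn

torch.manual_seed(0)
bs, n_labels = 4, 19
logits = torch.randn(bs, n_labels)                  # model output: (4, 19)
targets = (torch.rand(bs, n_labels) > 0.8).float()  # multi-hot targets: (4, 19)

# CrossEntropyLossFlat-style flattening: 4 rows of logits,
# but the multi-hot target flattens into 76 "class indices".
flat_logits = logits.view(-1, n_labels)             # (4, 19)
flat_targets = targets.view(-1).long()              # (76,)
try:
    nn.CrossEntropyLoss()(flat_logits, flat_targets)
    ce_failed = False
except (ValueError, RuntimeError) as err:           # wording depends on the PyTorch version
    ce_failed = True
    print("CrossEntropyLoss:", err)

# BCEWithLogitsLoss: one independent sigmoid + binary cross-entropy per (item, label).
bce = nn.BCEWithLogitsLoss()(logits, targets)
print("BCEWithLogitsLoss:", bce.item())
```

In the Learner above, that means passing `loss_func=BCEWithLogitsLossFlat()` instead of `CrossEntropyLossFlat()`; no extra linear decoder is needed, since the model already outputs one logit per label.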


Thank you so much, that fixed it! :slight_smile: