Hi! I’m having a little trouble understanding how to create a torchtext dataset that supports multiple labels. Per this thread (Creating a ModelData object without torchtext splits?), I tried creating a custom dataset, but I’m not sure I’m doing the right thing with the labels.
In the code below, I’ve created a separate field and entry for each label. Is this correct, or should I just create one Label field and pass a list of labels into the dataset Example along with the text? Right now, the code below isn’t working because TextData.from_splits expects a Label field, but I’m wondering if I’m close. Thanks for any pointers! (Note: I’m also pulling the data from a dictionary of dataframes (dfs) rather than from a directory.)
    import torchtext
    from torchtext import data


    class ToxicCommentDataset(torchtext.data.Dataset):
        def __init__(self, path, text_field, label1_field, label2_field,
                     label3_field, dfs, **kwargs):
            fields = [('text', text_field), ('Label1', label1_field),
                      ('Label2', label2_field), ('Label3', label3_field)]
            df = dfs[path]  # `path` is used as a key into the dict of dataframes
            examples = []
            for i in range(len(df)):
                text = df.comment_text.iloc[i]
                # Unlabeled splits (e.g. test) have no label columns
                Label1 = Label2 = Label3 = None
                if 'Label1' in df:
                    # Positional .iloc, like the text column, so a
                    # non-default index doesn't break the lookup
                    Label1 = df.Label1.iloc[i]
                    Label2 = df.Label2.iloc[i]
                    Label3 = df.Label3.iloc[i]
                examples.append(data.Example.fromlist(
                    [text, Label1, Label2, Label3], fields))
            super().__init__(examples, fields, **kwargs)

        @staticmethod
        def sort_key(ex):
            return len(ex.text)

        @classmethod
        def splits(cls, path, text_field, label1_field, label2_field,
                   label3_field, train, val, test, dfs, **kwargs):
            # Dataset.splits calls cls(...) with `path` combined with each
            # split name, so the keys of `dfs` must match those combined paths
            return super().splits(
                path, text_field=text_field, label1_field=label1_field,
                label2_field=label2_field, label3_field=label3_field,
                train=train, validation=val, test=test, dfs=dfs, **kwargs)
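For the single-field option, here's a minimal sketch that pairs each comment with one list holding all three labels. It's written in plain Python so it runs without torchtext. The rows are plain dicts standing in for dataframe rows (e.g. from df.to_dict('records')), and build_examples, LABEL_NAMES and the sample comments are hypothetical.

```python
from typing import Dict, List, Optional, Sequence, Tuple

# Hypothetical label column names; the real dataframe may use others.
LABEL_NAMES = ("Label1", "Label2", "Label3")


def build_examples(
    rows: Sequence[Dict[str, object]],
    label_names: Sequence[str] = LABEL_NAMES,
) -> List[Tuple[str, Optional[List[int]]]]:
    """Pair each comment with ONE list holding all of its labels.

    Rows stand in for dataframe rows (e.g. from df.to_dict("records")).
    If a row lacks any of the label columns (as in an unlabeled test
    split), its label entry is None instead of a list.
    """
    examples = []
    for row in rows:
        text = str(row["comment_text"])
        if all(name in row for name in label_names):
            labels: Optional[List[int]] = [int(row[name]) for name in label_names]
        else:
            labels = None
        examples.append((text, labels))
    return examples


train_rows = [
    {"comment_text": "you are great", "Label1": 0, "Label2": 0, "Label3": 0},
    {"comment_text": "you are awful", "Label1": 1, "Label2": 0, "Label3": 1},
]
test_rows = [{"comment_text": "hello there"}]

print(build_examples(train_rows))
# [('you are great', [0, 0, 0]), ('you are awful', [1, 0, 1])]
print(build_examples(test_rows))
# [('hello there', None)]
```

With legacy torchtext, each (text, labels) pair could then go into data.Example.fromlist with just two fields: the text field and a single label field, named whatever TextData.from_splits expects. The label field would be something like data.Field(sequential=False, use_vocab=False). As far as I can tell, batching such a field stacks the per-example lists into a (batch_size, 3) tensor, but that's worth confirming against your torchtext version.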