```python
def MultiCategoryBlock(encoded=False, vocab=None, add_na=False):
    "`TransformBlock` for multi-label categorical targets"
    tfm = EncodedMultiCategorize(vocab=vocab) if encoded else [MultiCategorize(vocab=vocab, add_na=add_na), OneHotEncode]
    return TransformBlock(type_tfms=tfm)
```
When creating a DataBlock with MultiCategoryBlock, we don't set vocab, so I assume the vocab gets set up from the get_y output behind the scenes.
However, I couldn't find when and where the vocab actually gets set up. If anyone knows, could you please explain when/where the connection between MultiCategoryBlock and get_y is established in DataBlock?
Hi! I had the exact same question as you did (initially I was very happy your question came up in my Google search, because it was literally what I was searching for).
So the first part is here:
```python
def MultiCategoryBlock(encoded=False, vocab=None, add_na=False):
    "`TransformBlock` for multi-label categorical targets"
    tfm = EncodedMultiCategorize(vocab=vocab) if encoded else [MultiCategorize(vocab=vocab, add_na=add_na), OneHotEncode]
    return TransformBlock(type_tfms=tfm)
```
Now the MultiCategorize part (with a vocab=None) initially does this:
```python
class MultiCategorize(Categorize):
    "Reversible transform of multi-category strings to `vocab` id"
    loss_func,order=BCEWithLogitsLossFlat(),1
    def __init__(self, vocab=None, add_na=False): super().__init__(vocab=vocab,add_na=add_na,sort=vocab==None)

    def setups(self, dsets):
        if not dsets: return
        if self.vocab is None:
            vals = set()
            for b in dsets: vals = vals.union(set(b))
            self.vocab = CategoryMap(list(vals), add_na=self.add_na)

    def encodes(self, o):
        if not all(elem in self.vocab.o2i.keys() for elem in o):
            diff = [elem for elem in o if elem not in self.vocab.o2i.keys()]
            diff_str = "', '".join(diff)
            raise KeyError(f"Labels '{diff_str}' were not included in the training dataset")
        return TensorMultiCategory([self.vocab.o2i[o_] for o_ in o])

    def decodes(self, o): return MultiCategory([self.vocab[o_] for o_ in o])
```
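To see that logic in isolation, here is a minimal plain-Python sketch of the `setups`/`encodes` behaviour. This is not fastai code: `MiniMultiCategorize` is an illustrative stand-in, and a sorted list plus an index dict plays the role of `CategoryMap` (which also sorts when no vocab is passed):

```python
# Illustrative sketch of MultiCategorize's setup + encode logic (no fastai).
class MiniMultiCategorize:
    def __init__(self, vocab=None):
        self.vocab = vocab

    def setups(self, dsets):
        # Collect the unique labels across every item, then sort them
        if self.vocab is None:
            vals = set()
            for b in dsets: vals = vals.union(set(b))
            self.vocab = sorted(vals)          # CategoryMap sorts too when vocab is None
        self.o2i = {v: i for i, v in enumerate(self.vocab)}

    def encodes(self, o):
        # Map each label of a multi-label item to its vocab index
        return [self.o2i[o_] for o_ in o]

labels = [["cat", "dog"], ["dog", "fish"]]
tfm = MiniMultiCategorize()
tfm.setups(labels)
print(tfm.vocab)             # ['cat', 'dog', 'fish']
print(tfm.encodes(["dog"]))  # [1]
```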
And the part that sets the vocab is this one:
```python
if self.vocab is None:
    vals = set()
    for b in dsets: vals = vals.union(set(b))
```
This essentially creates a set (remember, we don't want duplicate values) and then, for both the train and the validation datasets, collects every unique label value it finds.
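As for *when* this runs and how it connects to get_y: when DataBlock builds its Datasets, the target pipeline is roughly `[get_y, MultiCategorize(...)]`, and the pipeline's setup calls each transform's setup on items that the earlier transforms have already been applied to, so `MultiCategorize.setups` sees the *outputs of get_y* and builds the vocab from them. Here is a hedged sketch of that wiring; everything below (`MiniPipeline`, `MiniMultiCategorize`, `label_func`) is an illustrative name, not fastai's real internals:

```python
# Sketch of how a setup-aware pipeline feeds get_y's output to the categorizer.
class MiniMultiCategorize:
    def __init__(self): self.vocab = None
    def setups(self, items):
        vals = set()
        for b in items: vals |= set(b)
        self.vocab = sorted(vals)
        self.o2i = {v: i for i, v in enumerate(self.vocab)}
    def __call__(self, o): return [self.o2i[x] for x in o]

class MiniPipeline:
    def __init__(self, tfms): self.tfms = tfms
    def setup(self, items):
        for t in self.tfms:
            # Each transform is set up on items already processed
            # by the transforms before it in the pipeline.
            if hasattr(t, "setups"): t.setups(items)
            items = [t(x) for x in items]
        return items

def label_func(fname):            # stands in for get_y
    return fname.split("_")[:-1]  # "cat_dog_001" -> ["cat", "dog"]

files = ["cat_dog_001", "dog_002", "cat_fish_003"]
cat = MiniMultiCategorize()
encoded = MiniPipeline([label_func, cat]).setup(files)
print(cat.vocab)  # ['cat', 'dog', 'fish']
print(encoded)    # [[0, 1], [1], [0, 2]]
```

So the connection between MultiCategoryBlock and get_y is never stated explicitly anywhere: it falls out of the pipeline setup order, because get_y runs before `MultiCategorize.setups` ever sees the data.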