Is it expected for MultiCategorize to encode a str/tuple?

I'm not sure whether this is a bug, so instead of opening an issue on GitHub I'm posting it here. I've run into a situation where MultiCategorize returns unexpected outputs when the input is a str or a tuple.

My labels are strings of ids. Feeding in a list of these string ids works perfectly fine. However, when I feed in a single string id, MultiCategorize treats the str as an iterable: it iterates over each character of the string and encodes each character separately. The same happens when I feed in a tuple containing a string id. As a result, I get wrong outputs when I pass in a str or a tuple.

I've illustrated my point with the example below:

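```python
from fastai.data.transforms import MultiCategorize

_vocab = list(map(str, range(24)))
_tfms = MultiCategorize(vocab=_vocab)

# expected behavior
_tfms(['21'])
#>>> TensorMultiCategory([21])

# unexpected behavior
_tfms('21')
#>>> TensorMultiCategory([2, 1])

# unexpected behavior
# (note the trailing comma: ('21') without it is just the str '21')
_tfms(('21',))
#>>> (TensorMultiCategory([2, 1]),)
```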
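A note on the tuple case: in plain Python, parentheses alone do not create a tuple, so `('21')` is just the string `'21'`; only a trailing comma gives a one-element tuple `('21',)`. The snippet below needs no fastai. It approximates the vocab lookup with a plain dict (an assumption for illustration, not MultiCategorize's actual internals) to show why a str gets split: iterating over a `str` yields its characters.

```python
single = ('21')      # parentheses only group: this is the str '21'
one_tuple = ('21',)  # the trailing comma is what makes a 1-element tuple

print(type(single).__name__)     # str
print(type(one_tuple).__name__)  # tuple

# Iterating over a str yields its characters
print(list(single))  # ['2', '1']

# Same vocab as in the post: '0', '1', ..., '23' mapped to their positions
vocab = list(map(str, range(24)))
o2i = {v: i for i, v in enumerate(vocab)}

# A per-label lookup over a str therefore sees '2' and '1', not '21'
print([o2i[label] for label in single])  # [2, 1]
print([o2i[label] for label in ['21']])  # [21]
```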

Looking through the docs, MultiCategorize typically expects a list (or a subclass of list) as input. I think silently accepting a str or tuple may cause confusion, because it can return wildly different outputs, as shown above.
In this case, would it be better to raise a warning or an error when the user passes in a str or tuple, rather than silently passing it through?
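The proposed check could look something like the sketch below. It is plain Python, not a patch to fastai: `encode_multi_labels` is a hypothetical helper, and the dict lookup only approximates what MultiCategorize does. A real change would go in `MultiCategorize.encodes`.

```python
from typing import Iterable, List, Sequence


def encode_multi_labels(labels: Iterable[str], vocab: Sequence[str]) -> List[int]:
    """Hypothetical helper: map a collection of labels to vocab indices.

    A bare str is rejected up front, because iterating over it would
    silently encode each character as a separate label.
    """
    if isinstance(labels, str):
        raise TypeError(
            f"expected a list of labels, got the str {labels!r}; "
            f"did you mean [{labels!r}]?"
        )
    o2i = {v: i for i, v in enumerate(vocab)}
    return [o2i[label] for label in labels]


vocab = list(map(str, range(24)))
print(encode_multi_labels(['21'], vocab))  # [21]

try:
    encode_multi_labels('21', vocab)
except TypeError as err:
    print(err)  # expected a list of labels, got the str '21'; did you mean ['21']?
```

Raising is the stricter option. Replacing the `raise` with `warnings.warn(...)` would keep the current behavior while still flagging it, which is the trade-off the question is about.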

The code above is what determines this behavior. fastai treats tuples as a "special" class and recurses into them, which is how it finds '21'. It then treats the str '21' as a list of '2' and '1'.

This tuple handling is actually undocumented here, but it shows up throughout the core of fastai, so it would be good to add it to the docs. Below is an example of applying MultiCategorize to several label lists at once, such as for a model that outputs more than one prediction.

```python
from fastai.data.transforms import MultiCategorize

_vocab = list(map(str, range(24)))
_tfms = MultiCategorize(vocab=_vocab)
_tfms((['21', '3', '9'], ['21', '11', '8']))
#>>> (TensorMultiCategory([21,  3,  9]), TensorMultiCategory([21, 11,  8]))
```

Or if we want to model more complex labels:

```python
from fastai.data.transforms import MultiCategorize

_vocab = list(map(str, range(34)))
_tfms = MultiCategorize(vocab=_vocab)
_tfms((['11'], (['21'], ['22'], ((['31'],),))))
#>>> (TensorMultiCategory([11]),
#>>>  (TensorMultiCategory([21]),
#>>>   TensorMultiCategory([22]),
#>>>   ((TensorMultiCategory([31]),),)))
```
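The tuple recursion described in this reply can be modelled in a few lines of plain Python. This is a simplified sketch, not fastai's actual code (fastai's `Transform` does this internally, along with type retention and other details). `apply_tuplewise` and `make_encoder` are hypothetical names, and plain lists stand in for `TensorMultiCategory`. Tuples are recursed into; anything else, including a `str` or a `list`, is handed to the encoder as a single item.

```python
from typing import Any, Callable, Sequence


def apply_tuplewise(f: Callable[[Any], Any], x: Any) -> Any:
    """Hypothetical model: recurse into tuples, apply f to every non-tuple leaf."""
    if isinstance(x, tuple):
        return tuple(apply_tuplewise(f, item) for item in x)
    return f(x)


def make_encoder(vocab: Sequence[str]) -> Callable[[Any], list]:
    o2i = {v: i for i, v in enumerate(vocab)}
    # Iterates over its input, so a str leaf is encoded character by character
    return lambda labels: [o2i[label] for label in labels]


encode = make_encoder(list(map(str, range(34))))

print(apply_tuplewise(encode, ('21',)))
# ([2, 1],)
print(apply_tuplewise(encode, (['21', '3', '9'], ['21', '11', '8'])))
# ([21, 3, 9], [21, 11, 8])
print(apply_tuplewise(encode, (['11'], (['21'], ['22'], ((['31'],),)))))
# ([11], ([21], [22], (([31],),)))
```

The output keeps the nesting of the input tuples, which matches the shape of the fastai results above.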
