How to extract category names without looking at datasource?

Is there a way to get the category names that the learner is using without going into the data source?

Consider this code:

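from fastai.vision.all import *

# path points to a folder of images with one sub-folder per category (parent_label uses the folder name)
dls = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    item_tfms=[Resize(224, method='squish')]
).dataloaders(path, bs=32)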

learn = vision_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(2)

# I upload a file...

learn.predict(img)

# The above outputs:

('electric',
 tensor(6),
 tensor([7.5948e-04, 2.1046e-03, 4.5287e-04, 6.6661e-02, 2.0419e-03, 5.7200e-04,
         7.8898e-01, 4.6769e-02, 4.2984e-02, 2.3263e-02, 9.3710e-03, 1.7955e-04,
         3.1295e-05, 1.5521e-02, 3.0948e-04]))

I'd like to print every possible category together with the model's prediction in human-readable form, e.g. each category name followed by its probability as a percentage. To do that, I need to know which index of the probability tensor corresponds to which category.

Thanks in advance!

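The probabilities are in the same order as the categories in dls.vocab (also reachable from the learner as learn.dls.vocab), so index i of the probability tensor corresponds to dls.vocab[i]. For example, tensor(6) above is dls.vocab[6], i.e. 'electric'.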
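Here is a minimal sketch of the "name followed by a percentage" printout. It pairs each vocab entry with the probability at the same index. The helper name format_predictions is hypothetical, not part of fastai. To keep it runnable without a trained model, the usage below uses made-up category names. With fastai, you would pass learn.dls.vocab and the third element returned by learn.predict.

```python
def format_predictions(vocab, probs, sort=True):
    """Pair each category name with its probability, formatted as a percentage.

    Assumes `vocab` is in the same order as `probs`, as with dls.vocab and the
    probability tensor returned by learn.predict.
    """
    # float() works for plain numbers and for elements of a 1-D torch tensor
    probs = [float(p) for p in probs]
    if len(vocab) != len(probs):
        raise ValueError("vocab and probs must have the same length")
    pairs = list(zip(vocab, probs))
    if sort:
        # Most likely category first
        pairs.sort(key=lambda pair: pair[1], reverse=True)
    return [f"{name}: {p * 100:.2f}%" for name, p in pairs]


# With fastai (not executed here):
#   pred_class, pred_idx, probs = learn.predict(img)
#   for line in format_predictions(learn.dls.vocab, probs):
#       print(line)

vocab = ["class_a", "class_b", "class_c"]  # hypothetical category names
probs = [0.1, 0.7, 0.2]
for line in format_predictions(vocab, probs):
    print(line)
```

Pass sort=False to keep the original vocab order, which makes it easy to see that index i maps to dls.vocab[i].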


That is exactly what I was looking for. Thank you @karen!