Multi-label classification - Help deciphering multilabel-indicator error

Hey there! I’m working on an image classifier between weeks 2 and 3, and I’m trying to set it up as a multi-label classification problem.

I have a small spreadsheet of Pokémon and their types (apparently I can’t upload .csv files here). For example, a row can look like this:

123,Scyther,Bug Flying

I run the following to read the CSV file and match each row to its sprite in the images folder:

src = ImageDataLoaders.from_csv(path, csv_fname='pokemon_ss_stripped.csv', delimiter=',', folder=path/'images', suff='.png', valid_pct=0.2, label_col=2, label_delim=' ', seed=42, item_tfms=Resize(128))

After running the above, I get what I’d expect. All seems to be going according to plan.
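For reference, label_col=2 and label_delim=' ' tell fastai that the third column holds space-separated labels, so each image ends up with a multi-hot target over the vocab (one slot per type). Below is a minimal plain-Python sketch of that encoding for the Scyther row; the helper name multi_hot and the trimmed five-type vocab are hypothetical, and the alphabetical sort is an assumption about how the vocab is ordered.

```python
import csv
import io


def multi_hot(labels: str, vocab: list, delim: str = " ") -> list:
    """Return a 0/1 list with a 1 at the position of every label present."""
    present = set(labels.split(delim))
    return [1 if name in present else 0 for name in vocab]


# The example row from the post, parsed as one line of a CSV file.
row = next(csv.reader(io.StringIO("123,Scyther,Bug Flying")))
label_string = row[2]  # label_col=2 -> third column

# Hypothetical, trimmed-down vocab (the real one has 16 types), sorted alphabetically.
vocab = sorted(["Water", "Bug", "Grass", "Flying", "Fire"])

print(vocab)  # ['Bug', 'Fire', 'Flying', 'Grass', 'Water']
print(label_string, "->", multi_hot(label_string, vocab))  # Bug Flying -> [1, 0, 1, 0, 0]
```

A single-type Pokémon simply gets one 1 in its row; the metrics further down compare predictions against exactly these 0/1 rows.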

Since I’m trying to do multi-label classification, I need to pull in my metrics:
from fastai.metrics import accuracy_multi, FBetaMulti, partial

acc_02 = partial(accuracy_multi)
f_score = partial(FBetaMulti(1.5))
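One thing worth noting about these lines: the name acc_02 suggests a 0.2 threshold, but partial(accuracy_multi) with no keyword arguments keeps fastai’s default thresh=0.5; partial(accuracy_multi, thresh=0.2) is what would actually set it. As far as I know, accuracy_multi applies a sigmoid to the raw outputs itself, thresholds them, and averages the element-wise matches. Here is a plain-Python sketch of that computation; accuracy_multi_sketch and the numbers are hypothetical, not fastai’s code.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function: maps a raw model output (logit) to (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def accuracy_multi_sketch(logits, targets, thresh=0.5, apply_sigmoid=True):
    """Fraction of (image, label) slots whose thresholded prediction matches the target."""
    correct = total = 0
    for row_logits, row_targets in zip(logits, targets):
        for logit, target in zip(row_logits, row_targets):
            prob = sigmoid(logit) if apply_sigmoid else logit
            pred = 1 if prob > thresh else 0
            correct += int(pred == target)
            total += 1
    return correct / total


# Made-up raw outputs for 2 images x 3 types, with multi-hot targets.
logits = [[2.0, -2.0, -0.5], [-3.0, -0.1, 1.5]]
targets = [[1, 0, 1], [0, 1, 1]]

print(round(accuracy_multi_sketch(logits, targets), 3))              # 0.667 (thresh=0.5)
print(round(accuracy_multi_sketch(logits, targets, thresh=0.2), 3))  # 1.0
```

Because this metric does its own sigmoid and thresholding inside a plain function, wrapping it in partial is harmless; the threshold value is the only thing the wrapper changes.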

I create my learner:
learn = cnn_learner(src, resnet18, metrics=[acc_02, f_score])

I check my vocab to verify it has what I expect:
learn.dls.vocab

Looking good so far: 16 labels, one for each of the 16 types (at the time, anyway).

With everything seeming alright, I try to start training:

learn.fine_tune(4)

I get an error with a fairly long stack trace: Classification metrics can't handle a mix of multilabel-indicator and continuous-multioutput targets. How can I decipher this, and where did I go wrong? Is some conversion needed on my data, or am I just looking at the wrong thing? As it stands, everything seems to behave as I expect up to the point of failure.
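The error text comes from scikit-learn (fastai’s FBetaMulti wraps its fbeta_score). Before scoring, scikit-learn inspects both arrays: a 0/1 multi-hot target matrix counts as "multilabel-indicator", while a matrix of non-integer floats such as raw logits or probabilities counts as "continuous-multioutput". So the message means the metric received predictions that were never thresholded. The sketch below imitates that check in simplified form; simplified_target_type and check_targets are hypothetical helpers, not scikit-learn’s actual implementation, which handles more cases.

```python
def simplified_target_type(matrix):
    """Rough stand-in for scikit-learn's target-type detection on a 2-D label matrix."""
    values = {v for row in matrix for v in row}
    if any(float(v) != int(v) for v in values):
        return "continuous-multioutput"  # non-integer scores, e.g. raw logits
    if values <= {0, 1}:
        return "multilabel-indicator"    # 0/1 multi-hot matrix
    return "multiclass-multioutput"


def check_targets(y_true, y_pred):
    """Raise the same kind of error scikit-learn raises when the two types differ."""
    type_true = simplified_target_type(y_true)
    type_pred = simplified_target_type(y_pred)
    if type_true != type_pred:
        raise ValueError(
            "Classification metrics can't handle a mix of "
            f"{type_true} and {type_pred} targets"
        )


targets = [[1, 0, 1], [0, 1, 1]]                    # multi-hot ground truth
raw_logits = [[2.0, -2.0, -0.5], [-3.0, 1.0, 1.5]]  # unthresholded model outputs
# sigmoid(x) > 0.5 exactly when x > 0, so thresholding at 0.5 is the same as x > 0.
thresholded = [[1 if x > 0 else 0 for x in row] for row in raw_logits]

message = None
try:
    check_targets(targets, raw_logits)
except ValueError as err:
    message = str(err)
print(message)  # same wording as the error in the post

check_targets(targets, thresholded)  # no error: both are multilabel-indicator
```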

I also searched this forum for the error message and couldn’t find anything that seemed relevant. If it has been answered before, feel free to point me to it.

Thanks in advance!


I guess removing partial from the f_score line will fix the problem. So instead of f_score = partial(FBetaMulti(1.5)), just f_score = FBetaMulti(1.5) should work.
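Why dropping partial helps, as far as I can tell from fastai’s source (worth double-checking against your installed version): FBetaMulti(1.5) returns an AccumMetric that, by default (thresh=0.5, sigmoid=True, average='samples', if I’m reading the signature right), applies a sigmoid and the threshold in its accumulate step before calling scikit-learn. Wrapping it in functools.partial hides the Metric type, so the Learner treats it as a plain function, wraps it in AvgMetric, and calls it directly on the raw model outputs. That path skips the sigmoid and threshold, which is exactly the multilabel-indicator vs continuous-multioutput mix in the error. The plain-Python sketch below (fbeta_samples is a hypothetical helper, not fastai’s code) shows what the metric computes once predictions are thresholded: per-image F-beta, averaged over images.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function: maps a raw model output (logit) to (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def fbeta_samples(logits, targets, beta=1.5, thresh=0.5):
    """Per-image F-beta on sigmoid+thresholded predictions, averaged over images."""
    b2 = beta ** 2
    scores = []
    for row_logits, row_targets in zip(logits, targets):
        preds = [1 if sigmoid(x) >= thresh else 0 for x in row_logits]
        pairs = list(zip(preds, row_targets))
        tp = sum(1 for p, t in pairs if p == 1 and t == 1)
        fp = sum(1 for p, t in pairs if p == 1 and t == 0)
        fn = sum(1 for p, t in pairs if p == 0 and t == 1)
        denom = (1 + b2) * tp + b2 * fn + fp
        # Assumption: score 0 when an image has no true and no predicted labels.
        scores.append((1 + b2) * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Made-up raw outputs for 2 images x 3 types, with multi-hot targets.
logits = [[2.0, -2.0, -0.5], [-3.0, 1.0, 1.5]]
targets = [[1, 0, 1], [0, 1, 1]]

print(round(fbeta_samples(logits, targets), 3))  # 0.795
```

With the 0.5 threshold the first image misses its second type, pulling its score below 1 while the second image scores 1. With beta=1.5, recall is weighted more heavily than precision, so missed types cost more than extra ones.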