Confusion about multi-label classification predictions

I’m doing a personal multi-label classification project using the custom metric from the book:

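```python
def accuracy_multi(inp, targ, thresh=0.9, sigmoid=True):
    "Fraction of (item, label) cells where the thresholded prediction matches the target"
    if sigmoid: inp = inp.sigmoid()  # turn raw activations (logits) into probabilities
    return ((inp>thresh) == targ.bool()).float().mean()
```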
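To make concrete what accuracy_multi measures, here is a pure-Python mirror of it that works on nested lists instead of tensors, so it runs without fastai. The name accuracy_multi_py is hypothetical, and the flag is renamed apply_sigmoid only so it doesn't clash with the sigmoid helper. The example values are made up and are already probabilities, hence apply_sigmoid=False.

```python
import math

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

def accuracy_multi_py(inp, targ, thresh=0.5, apply_sigmoid=True):
    """Hypothetical pure-Python mirror of accuracy_multi: the fraction of
    (item, label) cells where (prob > thresh) equals the 0/1 target."""
    matches, total = 0, 0
    for row_inp, row_targ in zip(inp, targ):
        for x, t in zip(row_inp, row_targ):
            p = sigmoid(x) if apply_sigmoid else x
            matches += (p > thresh) == bool(t)  # True counts as 1
            total += 1
    return matches / total

# Two images, three labels each (illustrative probabilities and targets).
probs = [[0.97, 0.70, 0.10], [0.40, 0.96, 0.20]]
targ = [[1, 1, 0], [0, 1, 0]]
acc_050 = accuracy_multi_py(probs, targ, thresh=0.5, apply_sigmoid=False)
acc_095 = accuracy_multi_py(probs, targ, thresh=0.95, apply_sigmoid=False)
```

Raising thresh from 0.5 to 0.95 drops the score from 1.0 to 5/6 because the 0.70 label no longer counts as positive, but the probabilities themselves are only read, never changed.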

But my metric doesn’t seem to have any impact on my model’s predictions. Am I mistaken in thinking that the metric influences what’s predicted? I would have imagined that if I call

```python
learn.metrics = partial(accuracy_multi, thresh=0.95, sigmoid=False)
```

then the probabilities would be compared to 0.95, and a label would count as predicted only if its probability is greater than that. But ClassificationInterpretation treats a label as predicted once its probability is over 0.6 or so. I’m definitely misunderstanding something here.

For reference, here is how I build the model:

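```python
dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),
                  item_tfms = RandomResizedCrop(448, min_scale=0.8))

dls = dblock.dataloaders(murals_df, bs=3)

# Metrics receive the model's raw outputs (logits), so sigmoid=False
# compares logits, not probabilities, against 0.95.
learn = cnn_learner(dls, resnet34, metrics=partial(accuracy_multi, thresh=0.95, sigmoid=False))
```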

It’s a metric, so it doesn’t affect the predictions in any way. Metrics are computed from the predictions after the model has produced them.
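A minimal sketch of that separation, in plain Python with made-up logits: the decoded predictions (what gets shown as "predicted") use one fixed threshold, assumed here to be 0.5, while a metric reads the same probabilities with its own threshold and only produces a score. The names LOSS_THRESH and metric are hypothetical.

```python
import math

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

# Raw model outputs (logits) for one image with three labels; illustrative values.
logits = [2.0, 0.5, -1.0]
probs = [sigmoid(x) for x in logits]  # about [0.88, 0.62, 0.27]

# Decoding: assumed to use a fixed threshold of 0.5, independent of any metric.
LOSS_THRESH = 0.5
decoded = [p > LOSS_THRESH for p in probs]

def metric(probs, targ, thresh):
    """Metric-style score: only reads probs and returns a number."""
    return sum((p > thresh) == bool(t) for p, t in zip(probs, targ)) / len(probs)

targ = [1, 1, 0]
score_lo = metric(probs, targ, thresh=0.5)   # every label matches
score_hi = metric(probs, targ, thresh=0.95)  # only the negative label matches
```

The metric’s threshold moves the score from 1.0 to 1/3, but decoded stays [True, True, False] either way. Note that the 0.62 label counts as predicted, which fits the "over 0.6 or so" observation.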


To add to that:

Metrics: a way for humans to see how a model is performing. They are purely for monitoring.

Loss function: what the gradients are actually computed from, so it is what drives the weight updates (for a MultiCategoryBlock target, fastai defaults to BCEWithLogitsLossFlat, not CrossEntropyLossFlat)
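For a MultiCategoryBlock target, fastai’s default loss is BCEWithLogitsLossFlat, which as far as I know wraps PyTorch’s BCEWithLogitsLoss. The sketch below is a pure-Python version of that loss (mean binary cross-entropy computed straight from logits) and its gradient, written only to show that no prediction threshold appears anywhere in training; the function names are hypothetical.

```python
import math

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over all labels, computed from raw logits
    with the numerically stable form max(x, 0) - x*t + log(1 + exp(-|x|))."""
    total = 0.0
    for x, t in zip(logits, targets):
        total += max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))
    return total / len(logits)

def bce_grad(logits, targets):
    """Gradient of the mean loss w.r.t. each logit: (sigmoid(x) - t) / n."""
    n = len(logits)
    return [(1 / (1 + math.exp(-x)) - t) / n for x, t in zip(logits, targets)]

logits = [2.0, 0.5, -1.0]
targets = [1, 1, 0]
loss = bce_with_logits(logits, targets)
grads = bce_grad(logits, targets)  # negative pushes a logit up, positive pushes it down
```

Each gradient is just the gap between the sigmoid probability and the target, so training nudges probabilities toward 0 or 1; the cut-off used to call a label "predicted" is applied only afterwards.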

So is my model using some preset threshold for its predictions? I guess I don’t understand how or when my model decides an output is confident enough to count as a prediction.

This is coming from interp.plot_top_losses(), by the way; I wasn’t clear about that. Does ClassificationInterpretation treat a label as “predicted” when its probability is above some threshold? I’m not sure where that threshold comes from.


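I believe it’s compared to the output of the sigmoid (not a softmax, since this is multi-label), and a label counts as predicted if its probability is > 0.5. That 0.5 is the default thresh of the loss function, BCEWithLogitsLossFlat, which is what decodes the predictions.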
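The difference between the two activations matters here. Below is a small plain-Python comparison with illustrative logits: sigmoid scores each label independently, so several labels can clear 0.5 at once, while softmax makes the classes compete and sum to 1, which suits single-label problems decoded with an arg-max.

```python
import math

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

logits = [2.0, 0.5, -1.0]
sig = [sigmoid(x) for x in logits]  # independent per-label probabilities
soft = softmax(logits)              # competing probabilities that sum to 1

# Multi-label decoding: every label whose sigmoid probability clears 0.5.
multi_label = [p > 0.5 for p in sig]
# Single-label decoding: only the arg-max class.
single_label = soft.index(max(soft))
```

In fastai itself, I believe the decoding threshold lives on the loss function, so something like learn.loss_func = BCEWithLogitsLossFlat(thresh=0.95), set before building the ClassificationInterpretation, should change what plot_top_losses marks as predicted; worth checking against your fastai version.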

Not sure if this is correct, but I used the following when I wanted a different threshold:

```python
metric_precision = partial(PrecisionMulti, thresh=0.7)()  # same as PrecisionMulti(thresh=0.7)
```
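As far as I can tell, fastai’s PrecisionMulti wraps scikit-learn’s precision_score with average='macro' by default, applying a sigmoid and then the threshold first. The hypothetical precision_multi below mirrors that computation in plain Python, assuming the inputs are already probabilities; labels with no predicted positives score 0, like scikit-learn’s default. Like accuracy_multi, it only changes the reported number, not the predictions.

```python
def precision_multi(probs, targs, thresh=0.5):
    """Macro-averaged precision for multi-label data: threshold each label,
    compute precision per label (column), then average over labels."""
    n_labels = len(probs[0])
    per_label = []
    for j in range(n_labels):
        tp = fp = 0
        for row_p, row_t in zip(probs, targs):
            pred = row_p[j] > thresh
            if pred and row_t[j]:
                tp += 1  # predicted positive, actually positive
            elif pred:
                fp += 1  # predicted positive, actually negative
        per_label.append(tp / (tp + fp) if tp + fp else 0.0)
    return sum(per_label) / n_labels

# Three images, two labels (illustrative values).
probs = [[0.9, 0.6], [0.8, 0.3], [0.2, 0.75]]
targs = [[1, 0], [0, 0], [0, 1]]
p_050 = precision_multi(probs, targs, thresh=0.5)
p_070 = precision_multi(probs, targs, thresh=0.7)
```

At 0.7 the second label’s 0.6 false positive drops out, so its precision rises from 0.5 to 1.0 and the macro average from 0.5 to 0.75.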
