I’m doing a personal multi-label classification project with the custom metric from the book:
def accuracy_multi(inp, targ, thresh=0.9, sigmoid=True):
    if sigmoid: inp = inp.sigmoid()
    return ((inp>thresh) == targ.bool()).float().mean()
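Here is a minimal pure-Python sketch of the same logic that shows what the metric computes. It assumes the outputs and targets are plain nested lists (one row per image, one column per label) instead of tensors, and `accuracy_multi_py` is a hypothetical name. It averages a per-cell correct/incorrect check and only reads its inputs:

```python
import math

def accuracy_multi_py(inp, targ, thresh=0.9, sigmoid=True):
    """Fraction of (image, label) cells where the thresholded output matches the target."""
    matches, total = 0, 0
    for row_out, row_targ in zip(inp, targ):
        for x, t in zip(row_out, row_targ):
            # Optionally squash the raw output (logit) into a probability in (0, 1)
            p = 1 / (1 + math.exp(-x)) if sigmoid else x
            # Count the cell as correct when "above threshold" agrees with the 0/1 target
            matches += int((p > thresh) == bool(t))
            total += 1
    return matches / total

logits = [[3.0, -2.0, 0.5], [-1.0, 4.0, 2.5]]   # made-up raw outputs
targets = [[1, 0, 1], [0, 1, 1]]
print(accuracy_multi_py(logits, targets))              # thresh=0.9
print(accuracy_multi_py(logits, targets, thresh=0.5))
```

The function never modifies `inp`. It returns a number for you to read, and nothing flows back into the model.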
But my metric doesn’t seem to have any impact on my model’s predictions. Am I mistaken in thinking that the metric influences what’s predicted? I would have imagined that if I call
learn.metrics = partial(accuracy_multi, thresh=0.95, sigmoid=False)
then the probabilities would be compared to 0.95, and a label would be considered predicted if its probability is greater than that. But the ClassificationInterpretation is considering a label as predicted if its probability is over 0.6 or so. I’m definitely misunderstanding something here.
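Turning probabilities into a set of predicted labels at a chosen cutoff is a separate post-processing step. A metric does not do it. Below is a minimal sketch that assumes you already have one image's outputs as a plain list. The helper `predicted_labels` and the vocab are hypothetical.

There is also a detail in the setup above worth checking. To my knowledge, fastai passes metrics the raw model outputs (logits, before any sigmoid), which is why the book's version defaults to `sigmoid=True`. With `sigmoid=False`, the 0.95 is compared against logits instead of probabilities. The last call in the sketch imitates that:

```python
import math

def predicted_labels(probs, vocab, thresh=0.95):
    """Return the labels whose score is above thresh."""
    return [label for label, p in zip(vocab, probs) if p > thresh]

vocab = ["bird", "flower", "person"]   # hypothetical label names
logits = [4.0, 1.0, -3.0]              # made-up raw outputs for one image
probs = [1 / (1 + math.exp(-x)) for x in logits]

print(predicted_labels(probs, vocab, thresh=0.95))   # strict cutoff on probabilities
print(predicted_labels(probs, vocab, thresh=0.5))    # looser cutoff on probabilities
print(predicted_labels(logits, vocab, thresh=0.95))  # cutoff applied to logits by mistake
```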
For further information on my model build:
dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),
                   item_tfms=RandomResizedCrop(448, min_scale=0.8))
dls = dblock.dataloaders(murals_df, bs=3)
learn = cnn_learner(dls, resnet34, metrics=partial(accuracy_multi, thresh=0.95, sigmoid=False))
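The `targ.bool()` in the metric works because `MultiCategoryBlock` turns each image's list of labels into a multi-hot vector. That vector has one slot per label in the vocab, so it is the same length as the model's output. Here is a pure-Python sketch of that encoding with a hypothetical vocab and helper name. fastai itself does this with tensors:

```python
def multi_hot(labels, vocab):
    """Encode a list of label names as a 0/1 vector aligned with vocab."""
    index = {label: i for i, label in enumerate(vocab)}
    vec = [0] * len(vocab)
    for label in labels:
        vec[index[label]] = 1   # mark each label present on this image
    return vec

vocab = ["bird", "flower", "person"]   # hypothetical label names
print(multi_hot(["person", "bird"], vocab))  # [1, 0, 1]
```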
accuracy_multi is a metric, so it does not affect the predictions in any way. Metrics are calculated from the predictions after they have been generated.
To add to that:
Metrics: a way for humans to understand how a model is performing. They are purely for monitoring.
Loss functions: what the gradients are actually computed from, and therefore what drives the weight updates (in this case it’s
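The following sketch illustrates the split between loss and metric. It assumes a binary cross-entropy loss computed on the logits, the usual choice for multi-label targets, and it is written in pure Python with hypothetical helper names. As a logit moves, the loss and its gradient change smoothly, while a thresholded, accuracy-style check stays flat. That is why only the loss can drive training:

```python
import math

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

def bce_with_logits(logit, target):
    """Binary cross-entropy for one label, computed from the raw logit."""
    p = sigmoid(logit)
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

def bce_grad(logit, target):
    """Derivative of the loss with respect to the logit: this is what backprop uses."""
    return sigmoid(logit) - target

def correct(logit, target, thresh=0.5):
    """Metric-style check: 1 if the thresholded prediction matches the target, else 0."""
    return int((sigmoid(logit) > thresh) == bool(target))

for logit in (0.2, 0.5, 1.0):
    # Loss shrinks and the gradient stays informative, but the metric is already 1
    print(logit, round(bce_with_logits(logit, 1), 4),
          round(bce_grad(logit, 1), 4), correct(logit, 1))
```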
So my model is using some preset threshold for predictions? I guess I don’t understand how or when my model decides an output is confident enough to count as a prediction.
This is coming from interp.plot_top_losses(), by the way; I guess I wasn’t clear about that. Does the ClassificationInterpretation consider a label “predicted” if its probability is above some threshold? Not sure where that threshold is coming from.
I believe it’s compared to the output of the sigmoid (each label is scored independently in multi-label classification), and a label counts as predicted if that value is above 0.5.
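This is where I understand the threshold to live. In fastai, the default multi-label loss (`BCEWithLogitsLossFlat`) takes a `thresh` argument that defaults to 0.5. Its `activation` applies a sigmoid, and its `decodes` marks a label as predicted when the result is above `thresh`. The decoded predictions used by `get_preds(with_decoded=True)` and the interpretation plots come from that step. Below is a pure-Python mimic of that pair of steps. The class name is hypothetical and is not fastai's API:

```python
import math

class ThresholdedBCE:
    """Hypothetical stand-in mimicking the activation/decodes pair of a multi-label loss."""

    def __init__(self, thresh=0.5):
        self.thresh = thresh

    def activation(self, logits):
        # Independent sigmoid per label (not a softmax across labels)
        return [1 / (1 + math.exp(-x)) for x in logits]

    def decodes(self, probs):
        # A label counts as "predicted" when its probability is above the threshold
        return [p > self.thresh for p in probs]

logits = [0.4, -0.2, 2.0]            # made-up raw outputs; sigmoid(0.4) is about 0.6
loss = ThresholdedBCE()              # default threshold of 0.5
print(loss.decodes(loss.activation(logits)))
strict = ThresholdedBCE(thresh=0.95)
print(strict.decodes(strict.activation(logits)))
```

If this matches fastai's behaviour, then passing something like `loss_func=BCEWithLogitsLossFlat(thresh=0.95)` to `cnn_learner` should make the decoded predictions use the same 0.95 cutoff as the metric.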
Not sure if this is correct, but I used the following when trying different thresholds:
metric_precision = PrecisionMulti(thresh=0.7)  # equivalent to partial(PrecisionMulti, thresh=0.7)()
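As far as I know, `PrecisionMulti` wraps scikit-learn's `precision_score` and by default takes a macro average over labels. Here is a pure-Python sketch of macro-averaged precision at a given threshold that shows what changing `thresh` does. It assumes the inputs are already probabilities, so no sigmoid is applied. The function name is hypothetical. A label with no predictions is scored as 0, which is a simplifying assumption:

```python
def precision_multi_py(probs, targ, thresh=0.7):
    """Macro-averaged precision over labels at a given probability threshold."""
    n_labels = len(targ[0])
    per_label = []
    for j in range(n_labels):
        tp = fp = 0
        for row_p, row_t in zip(probs, targ):
            if row_p[j] > thresh:        # label j is predicted for this image
                if row_t[j]:
                    tp += 1              # correctly predicted
                else:
                    fp += 1              # predicted but not actually present
        # No predictions for this label: treat its precision as 0 (assumption)
        per_label.append(tp / (tp + fp) if tp + fp else 0.0)
    return sum(per_label) / n_labels

probs = [[0.9, 0.2], [0.8, 0.75], [0.6, 0.95]]   # made-up probabilities
targ = [[1, 0], [0, 1], [1, 1]]
print(precision_multi_py(probs, targ, thresh=0.7))
print(precision_multi_py(probs, targ, thresh=0.5))
```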