I propose a different variant of `accuracy_thresh` that can be used to calculate what fraction of examples are classified perfectly in multi-label classification problems. For the following, I will assume `thresh=0.5`.

The current `accuracy_thresh` function calculates, for each example, what fraction of the classes were identified correctly. For example, in a face attribute classification problem, if the true labels are "bald, not smiling" and the predicted labels are "bald, smiling", then 50% of the classes were identified correctly. `accuracy_thresh` then computes the mean of these fractions over all examples.
Instead, it can be useful to consider a classification a success only if all the classes are identified correctly. Effectively, this means setting the above-mentioned fractions to 0 if they are not 1, and then calculating the mean.
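A small PyTorch sketch of the difference between the two metrics, using a made-up batch of two examples with the labels [bald, smiling]. It computes the per-example fractions, averages them as described for the current behaviour, and then zeroes out every fraction that is not 1 before averaging:

```python
import torch

# Made-up batch: two examples, labels [bald, smiling]
y_true = torch.tensor([[1., 0.],     # bald, not smiling
                       [0., 1.]])    # not bald, smiling
y_pred = torch.tensor([[0.9, 0.8],   # predicted bald, smiling: 1 of 2 correct
                       [0.2, 0.7]])  # predicted not bald, smiling: 2 of 2 correct

# Per-label hits after thresholding at 0.5
correct = (y_pred > 0.5) == y_true.bool()

# Fraction of labels identified correctly for each example: [0.5, 1.0]
fractions = correct.float().mean(dim=1)
per_label_acc = fractions.mean()  # mean of the fractions: 0.75

# Set every fraction that is not 1 to 0, then average: 0.5
exact_acc = torch.where(fractions == 1, fractions, torch.zeros_like(fractions)).mean()

# Equivalent to requiring all labels of an example to match
assert exact_acc == correct.all(dim=1).float().mean()
print(per_label_acc.item(), exact_acc.item())  # 0.75 0.5
```

The zeroing step is exactly what `.all(1)` does in one call, which is why the function below can skip computing the fractions.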
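The function would look like this:

```python
def multilabel_accuracy(y_pred, y_true, thresh=0.5):
    # Binarize the predicted probabilities at the threshold
    y_pred = y_pred > thresh
    # Turn the 0/1 targets into a boolean tensor (use .byte() on PyTorch < 1.2)
    y_true = y_true.bool()
    # An example counts as correct only if all of its labels match
    return (y_pred == y_true).all(1).float().mean()
```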
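A usage sketch, assuming the model outputs raw logits (so they are passed through `torch.sigmoid` first) and that the targets are 0/1 floats. The tensors are made up for illustration, and the function is repeated so the snippet runs on its own:

```python
import torch

def multilabel_accuracy(y_pred, y_true, thresh=0.5):
    # Fraction of examples whose labels are *all* predicted correctly
    return ((y_pred > thresh) == y_true.bool()).all(dim=1).float().mean()

# Made-up logits for 3 examples x 3 labels
logits = torch.tensor([[ 2.0, -1.0,  0.5],
                       [-3.0,  1.5, -0.2],
                       [ 1.0,  1.0, -2.0]])
targets = torch.tensor([[1., 0., 1.],
                        [0., 1., 1.],   # third label is missed (logit -0.2)
                        [1., 1., 0.]])

probs = torch.sigmoid(logits)  # map logits to probabilities in (0, 1)
acc = multilabel_accuracy(probs, targets)
print(acc.item())  # 2 of 3 examples are fully correct
```

On the same tensors, the per-label accuracy would be 8/9, since only one of the nine labels is wrong; the exact-match version drops to 2/3 because that single mistake spoils the whole second example.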
If this is deemed useful, I can create a pull request for it. It would also be possible to make this an optional parameter of `accuracy_thresh`. What do you think?
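To illustrate the optional-parameter idea, here is a rough sketch. The `exact_match` flag name is hypothetical and this is not the library's actual signature; any other parameters the real `accuracy_thresh` may have are left out:

```python
import torch

def accuracy_thresh(y_pred, y_true, thresh=0.5, exact_match=False):
    # Hypothetical signature: `exact_match` only illustrates the proposal
    correct = (y_pred > thresh) == y_true.bool()
    if exact_match:
        # An example counts as correct only if every one of its labels matches
        return correct.all(dim=1).float().mean()
    # Default: average over all individual label predictions
    return correct.float().mean()

y_true = torch.tensor([[1., 0.], [0., 1.]])
y_pred = torch.tensor([[0.9, 0.8], [0.2, 0.7]])
print(accuracy_thresh(y_pred, y_true).item())                    # 0.75
print(accuracy_thresh(y_pred, y_true, exact_match=True).item())  # 0.5
```

Defaulting the flag to `False` would keep the current behaviour for existing callers.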