I propose a different variant of `accuracy_thresh` that can be used to calculate what fraction of examples are classified perfectly in multi-label classification problems.
For the following, I will assume `thresh=0.5`. The current `accuracy_thresh` function calculates, for each example, what fraction of the classes were identified correctly. For example, in a face classification problem, if the true labels are “bald, not smiling” and the predicted labels are “bald, smiling”, then 50% of the classes were identified correctly. `accuracy_thresh` then computes the mean of these fractions over all examples.
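To make the current behaviour concrete, here is a minimal sketch of the per-label metric described above, applied to the face example. It assumes `y_pred` already holds probabilities (no sigmoid is applied), and `per_label_accuracy` is a hypothetical name; it mirrors the description rather than fastai's actual source.

```python
import torch

def per_label_accuracy(y_pred, y_true, thresh=0.5):
    # Binarise the predictions at the threshold (bool tensor, shape [batch, n_labels]).
    pred = y_pred > thresh
    # Fraction of labels predicted correctly for each example, shape [batch].
    per_example = (pred == y_true.bool()).float().mean(dim=1)
    # Average those fractions over the batch.
    return per_example.mean()

# Labels: [bald, smiling]. Truth: bald, not smiling. Prediction: bald, smiling.
y_true = torch.tensor([[1., 0.]])
y_pred = torch.tensor([[0.9, 0.8]])
print(per_label_accuracy(y_pred, y_true))  # tensor(0.5000)
```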
Instead, it can be useful to count a classification as a success only if all the classes are identified correctly. Effectively, this means setting the fractions mentioned above to 0 if they are not 1, and then calculating the mean.
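The two-step recipe above (zero out every fraction below 1, then average) can be checked against an `.all(dim=1)` formulation on a small batch. This is a minimal sketch assuming probabilities rather than logits; `exact_match_two_step` and `exact_match_all` are hypothetical names used only for the comparison.

```python
import torch

def exact_match_two_step(y_pred, y_true, thresh=0.5):
    # Per-example fraction of correctly predicted labels.
    frac = ((y_pred > thresh) == y_true.bool()).float().mean(dim=1)
    # Keep 1 only for perfectly classified examples, 0 otherwise, then average.
    return (frac == 1.0).float().mean()

def exact_match_all(y_pred, y_true, thresh=0.5):
    # An example counts only if every one of its labels matches.
    return ((y_pred > thresh) == y_true.bool()).all(dim=1).float().mean()

y_true = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
y_pred = torch.tensor([[0.9, 0.8],   # one label wrong
                       [0.2, 0.7],   # both labels right
                       [0.6, 0.9]])  # both labels right
print(exact_match_two_step(y_pred, y_true))  # tensor(0.6667)
print(exact_match_all(y_pred, y_true))       # tensor(0.6667)
```

On the same batch the per-label metric would give (0.5 + 1 + 1) / 3 ≈ 0.83, so the stricter metric is noticeably lower whenever some examples are only partially correct.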
The function would look like this:
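```python
import torch

def multilabel_accuracy(y_pred, y_true, thresh=0.5):
    y_pred = y_pred > thresh  # binarise predicted probabilities
    y_true = y_true.bool()    # 0/1 targets -> booleans
    # A sample counts as correct only if all of its labels match.
    return (y_pred == y_true).all(1).float().mean()
```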
If this is deemed useful, I can create a pull request for it. It would also be possible to make this an optional parameter of `accuracy_thresh`. What do you think?
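For the optional-parameter route, a minimal sketch might look like the following. Everything here is hypothetical: the name `accuracy_thresh_sketch`, the `exact` flag, and the `sigmoid` flag (which assumes the model outputs raw logits that must be mapped to probabilities first). It is not fastai's API, just an illustration of how one switch could toggle between the two metrics.

```python
import torch
from functools import partial

def accuracy_thresh_sketch(y_pred, y_true, thresh=0.5, sigmoid=True, exact=False):
    # Hypothetical variant: `exact` switches between per-label and exact-match accuracy.
    if sigmoid:
        # Assumes y_pred holds raw logits; map them to probabilities first.
        y_pred = y_pred.sigmoid()
    correct = (y_pred > thresh) == y_true.bool()  # shape [batch, n_labels]
    if exact:
        # Exact match: an example counts only if all of its labels are right.
        return correct.all(dim=1).float().mean()
    # Default: mean over all entries, which equals the mean of per-example
    # fractions because every example has the same number of labels.
    return correct.float().mean()

# Pre-bind keyword arguments so the metric can be passed around as a single callable.
exact_acc = partial(accuracy_thresh_sketch, thresh=0.5, exact=True)

logits = torch.tensor([[2.0, 1.5], [-1.0, 0.5]])
y_true = torch.tensor([[1., 0.], [0., 1.]])
print(accuracy_thresh_sketch(logits, y_true))  # tensor(0.7500)
print(exact_acc(logits, y_true))               # tensor(0.5000)
```

Keeping `exact=False` as the default would leave existing callers unaffected.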