AUC for multiclass classification

Adapted from "Using AUC as metric in fastai" for multiclass classification:

from sklearn.metrics import roc_auc_score

def auroc_score(input, target):
    """Return the ROC AUC of scores ``input`` measured against labels ``target``.

    Note the argument order: scores first, labels second -- the reverse of
    sklearn's ``roc_auc_score(y_true, y_score)``, which this wraps.

    Both arguments may be torch tensors (on any device, with or without
    ``requires_grad``) or array-likes that ``roc_auc_score`` accepts directly.
    """
    def _as_array(x):
        # Tensor.numpy() refuses CUDA tensors and tensors that require grad,
        # so detach and move to host first; non-tensors pass through as-is.
        return x.detach().cpu().numpy() if hasattr(x, 'detach') else x

    return roc_auc_score(_as_array(target), _as_array(input))

class AUC(Callback):
    "Computes the one-vs-rest ROC AUC of the single class `clas` over the validation set."

    def __init__(self, clas):
        # Without this constructor `AUC(clas=0)` fails with
        # "TypeError: object() takes no parameters".
        super().__init__()
        self.clas = clas  # index of the class whose AUC is reported
        self.output, self.target = [], []
        self.metric = float('nan')

    def on_epoch_begin(self, **kwargs):
        "Reset the buffers of collected outputs and targets."
        self.output, self.target = [], []

    def on_batch_end(self, last_output, last_target, train, **kwargs):
        "Collect raw model outputs and targets from validation batches only."
        if not train:
            self.output.append(last_output)
            self.target.append(last_target)

    def on_epoch_end(self, last_metrics, **kwargs):
        "Compute the AUC for `self.clas` from the collected batches and append it to `last_metrics`."
        # NaN is reported when there is nothing (or nothing meaningful) to score,
        # instead of raising or silently reusing the previous epoch's value.
        self.metric = float('nan')
        if self.output:
            output = torch.cat(self.output)
            target = torch.cat(self.target)
            considered_class = self.clas

            if target.dim() > 1:
                # Multi-label case: assumes targets are multi-hot with one column
                # per class, shape (N, n_classes) -- TODO confirm for your data.
                # sigmoid is monotonic, so it preserves the ranking AUC depends on.
                target_for_roc = target[:, considered_class].float()
                scores = torch.sigmoid(output[:, considered_class])
            else:
                # Single-label multi-class case: class indices of shape (N,).
                # The comparison stays on target's device, avoiding a CPU/GPU
                # mismatch from indexing a CPU zeros tensor with a GPU mask.
                target_for_roc = (target == considered_class).float()
                scores = torch.softmax(output, dim=1)[:, considered_class]

            # ROC AUC is undefined when only one label value is present.
            if target_for_roc.min() != target_for_roc.max():
                self.metric = auroc_score(scores, target_for_roc)

        return add_metrics(last_metrics, self.metric)


# One AUC metric instance per class index 0-2, created in index order.
auc0, auc1, auc2 = (AUC(clas=idx) for idx in range(3))
learn = create_cnn(
    data,
    models.resnet34,
    metrics=[accuracy, auc0, auc1, auc2],
    callbacks=[notif_cb],
)

Leaving it here for others to adapt, improve, and find bugs in.


TypeError: object() takes no parameters

Getting the following error:

ValueError Traceback (most recent call last)
in ()
----> 1 learn.fit_one_cycle(100)

10 frames
/usr/local/lib/python3.6/dist-packages/fastai/ in fit_one_cycle(learn, cyc_len, max_lr, moms, div_factor, pct_start, final_div, wd, callbacks, tot_epochs, start_epoch)
20 callbacks.append(OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor, pct_start=pct_start,
21 final_div=final_div, tot_epochs=tot_epochs, start_epoch=start_epoch))
—> 22, max_lr, wd=wd, callbacks=callbacks)
24 def lr_find(learn:Learner, start_lr:Floats=1e-7, end_lr:Floats=10, num_it:int=100, stop_div:bool=True, wd:float=None):

/usr/local/lib/python3.6/dist-packages/fastai/ in fit(self, epochs, lr, wd, callbacks)
200 callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
201 self.cb_fns_registered = True
–> 202 fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
204 def create_opt(self, lr:Floats, wd:Floats=0.)->None:

/usr/local/lib/python3.6/dist-packages/fastai/ in fit(epochs, learn, callbacks, metrics)
106 cb_handler=cb_handler, pbar=pbar)
107 else: val_loss=None
–> 108 if cb_handler.on_epoch_end(val_loss): break
109 except Exception as e:
110 exception = e

/usr/local/lib/python3.6/dist-packages/fastai/ in on_epoch_end(self, val_loss)
315 “Epoch is done, process val_loss.”
316 self.state_dict[‘last_metrics’] = [val_loss] if val_loss is not None else [None]
–> 317 self(‘epoch_end’, call_mets = val_loss is not None)
318 self.state_dict[‘epoch’] += 1
319 return self.state_dict[‘stop_training’]

/usr/local/lib/python3.6/dist-packages/fastai/ in __call__(self, cb_name, call_mets, **kwargs)
248 “Call through to all of the CallbakHandler functions.”
249 if call_mets:
–> 250 for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)
251 for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)

/usr/local/lib/python3.6/dist-packages/fastai/ in _call_and_update(self, cb, cb_name, **kwargs)
239 def _call_and_update(self, cb, cb_name, **kwargs)->None:
240 "Call cb_name on cb and update the inner state."
----> 241 new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
242 for k,v in new.items():
243 if k not in self.state_dict:

in on_epoch_end(self, last_metrics, **kwargs)
33 pfinal = probs[:,considered_class]
—> 35 self.metric = auroc_score(pfinal, target_for_roc)
37 return add_metrics(last_metrics, self.metric)

in auroc_score(input, target)
4 def auroc_score(input, target):
5 input, target = input.cpu().numpy(), target.cpu().numpy()
----> 6 return roc_auc_score(target, input)

/usr/local/lib/python3.6/dist-packages/sklearn/metrics/ in roc_auc_score(y_true, y_score, average, sample_weight, max_fpr)
353 return _average_binary_score(
354 _binary_roc_auc_score, y_true, y_score, average,
–> 355 sample_weight=sample_weight)

/usr/local/lib/python3.6/dist-packages/sklearn/metrics/ in _average_binary_score(binary_metric, y_true, y_score, average, sample_weight)
78 check_consistent_length(y_true, y_score, sample_weight)
79 y_true = check_array(y_true)
—> 80 y_score = check_array(y_score)
82 not_average_axis = 1

/usr/local/lib/python3.6/dist-packages/sklearn/utils/ in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, warn_on_dtype, estimator)
519 "Reshape your data either using array.reshape(-1, 1) if "
520 "your data has a single feature or array.reshape(1, -1) "
–> 521 “if it contains a single sample.”.format(array))
523 # in the future np.flexible dtypes will be handled like object dtypes

ValueError: Expected 2D array, got 1D array instead:
array=[8.352953e-02 4.115525e-02 2.331579e-02 4.901069e-02 … 8.349909e-09 4.543838e-10 6.572282e-10 1.744149e-10].
Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.

FYI: I realized my model is being trained for multi-label classification, and that the shapes of input and target differ from one another.

There are 22 classes:
input shape (3161,), target shape (3161, 22)

Judging from the sklearn error, input and target should have the same shape.

# One AUC metric instance per class index 0-8, created in index order.
auc0, auc1, auc2, auc3, auc4, auc5, auc6, auc7, auc8 = (
    AUC(clas=idx) for idx in range(9)
)
learn = tabular_learner(
    # NOTE(review): `data` was missing from the original (unterminated) call;
    # assumes the notebook's DataBunch is named `data` -- confirm.
    data,
    layers=[40, 40, 40, 40],
    metrics=[accuracy, auc0, auc1, auc2, auc3, auc4, auc5, auc6, auc7, auc8],
)

When I use a tabular learner and call fit_one_cycle, the following error is thrown:

How can I add another metric which is the mean of all my auc scores?

I think I’ve figured it out :slight_smile:
Declared a new variable roc and set that inside on_epoch_end as follows:
self.roc = self.metric

Then I created a new metric which takes the AUC metrics and computes the mean of all their `roc` values, using `_order` to make sure the individual AUCs are computed first and only then their mean.

Or just use the newly added multi-class ROC AUC metric :wink: