Column-wise AUC loss function

(Feras) #1

I was wondering if anyone had tips on how to integrate the following loss function, "mean column-wise ROC AUC", into a fastai language model. Should we use something like sklearn's and pass two matrices for Y and Y_true containing all the columns the loss applies to, or is a different approach needed? This function seems to support only two vectors, not two matrices, but could we compute it per column and then average?
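A minimal sketch of the "per column, then average" idea, assuming `Y_true` is a 0/1 matrix (rows × label columns) and `Y_pred` holds scores of the same shape. `column_auc` and `mean_columnwise_auc` are hypothetical names, and the pairwise AUC is a plain NumPy stand-in rather than sklearn's implementation. For a 0/1 target matrix it should agree with `metrics.roc_auc_score(Y_true, Y_pred, average="macro")`, which also scores each column separately and takes the mean.

```python
import numpy as np

def column_auc(y_true, y_score):
    """ROC AUC of one column: fraction of (positive, negative) pairs ranked correctly."""
    y_true = np.asarray(y_true, dtype=bool)
    y_score = np.asarray(y_score, dtype=float)
    pos, neg = y_score[y_true], y_score[~y_true]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("need at least one positive and one negative example")
    # compare every positive score with every negative score; ties count as half
    diff = pos[:, None] - neg[None, :]
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size

def mean_columnwise_auc(Y_true, Y_pred):
    """Compute the AUC of each label column separately, then average them."""
    Y_true, Y_pred = np.asarray(Y_true), np.asarray(Y_pred)
    aucs = [column_auc(Y_true[:, j], Y_pred[:, j]) for j in range(Y_true.shape[1])]
    return float(np.mean(aucs))

# made-up example: 4 rows, 2 label columns
Y_true = [[1, 0], [0, 1], [1, 1], [0, 0]]
Y_pred = [[0.9, 0.2], [0.3, 0.8], [0.6, 0.4], [0.1, 0.5]]
print(mean_columnwise_auc(Y_true, Y_pred))  # 0.875 (column AUCs 1.0 and 0.75)
```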


The one you linked to works with multiple labels. I'll post how I use it with the fastai library when I get to my computer later today.

You need at least one positive example (and one negative) per class for it to work, so you might need a large batch size for the validation set.
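To see whether a validation batch is big enough, one can check which label columns actually contain both a positive and a negative. A rough sketch, assuming a 0/1 target matrix; `scorable_columns` is a hypothetical helper, not part of fastai or sklearn. A column that fails this check has no defined AUC (sklearn's `roc_auc_score` refuses to score it and raises a `ValueError`), so one option is a bigger batch, or scoring the whole validation set at once.

```python
import numpy as np

def scorable_columns(Y_true):
    """Indices of label columns that contain at least one positive and one negative."""
    Y_true = np.asarray(Y_true, dtype=bool)
    has_pos = Y_true.any(axis=0)     # column has at least one 1
    has_neg = (~Y_true).any(axis=0)  # column has at least one 0
    return np.flatnonzero(has_pos & has_neg)

# made-up batch: column 1 has no positives, so its AUC is undefined
batch = [[1, 0, 0], [0, 0, 1], [1, 0, 1]]
print(scorable_columns(batch))  # [0 2]
```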


This is what I pass in the metrics array to fit:

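    import numpy as np
    from sklearn import metrics

    def roc_auc(preds, y):
        # y: target matrix, converted to NumPy the same way as preds
        # preds: log-probabilities on the GPU -> copy to CPU, convert to NumPy, exponentiate
        return metrics.roc_auc_score(y.cpu().numpy(), np.exp(preds.cpu().numpy()))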

(metrics is just from sklearn import metrics)
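The `np.exp` call assumes the model outputs log-probabilities (for example from a log-softmax final layer). A small NumPy sketch with made-up logits and labels: exponentiating recovers probabilities that sum to 1 per row, and because exp is strictly increasing it keeps every score in the same order, so the AUC of a column is the same whether you pass log-probabilities or probabilities. `log_softmax` and `pairwise_auc` are hypothetical helpers here, not fastai or sklearn functions.

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax over the class axis (stand-in for a model's output)."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def pairwise_auc(y_true, y_score):
    """Binary ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos, neg = y_score[y_true == 1], y_score[y_true == 0]
    diff = pos[:, None] - neg[None, :]
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size

logits = np.array([[2.0, 0.5], [0.1, 1.2], [1.5, 1.4], [-0.3, 0.6]])  # made-up
log_probs = log_softmax(logits)
probs = np.exp(log_probs)       # same step as np.exp(preds...) in the snippet
y_true = np.array([1, 0, 0, 1])  # hypothetical labels for class 0

print(probs.sum(axis=1))                     # each row sums to 1
print(pairwise_auc(y_true, log_probs[:, 0]))  # 0.75
print(pairwise_auc(y_true, probs[:, 0]))      # 0.75, unchanged by exp
```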

(Feras) #4

Got it. Thanks a lot!

(Feras) #5

This is the final code I ended up with, since it seems that fastai doesn't automatically one-hot encode the Ys:

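    import numpy as np
    from sklearn import metrics
    from fastai.core import V  # fastai 0.7 helper: wraps a tensor or array as a Variable

    def roc_auc(preds, y):
        preds = np.exp(V(preds).data.cpu().numpy())  # convert from log-probabilities
        y_ids = V(y).data.cpu().numpy()  # true category ID for each row
        bs, nclass = preds.shape  # batch size x num_classes
        y = np.zeros((bs, nclass))
        y[np.arange(bs), y_ids] = 1  # one-hot encode the Ys
        # "micro" pools every (row, class) entry; average="macro" is the column-wise mean
        return metrics.roc_auc_score(y, preds, average="micro")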
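For reference, the one-hot step can also be written as `np.eye(nclass)[ids]`, and the `average` argument matters: "micro" pools every (row, class) entry into one binary problem, while "macro" scores each column separately and averages, which is closer to the "mean column-wise" AUC asked about at the top. A NumPy sketch on a made-up 4-row, 3-class batch; `one_hot`, `micro_auc` and `macro_auc` are hypothetical helpers meant to mimic, not reproduce, sklearn's averaging modes.

```python
import numpy as np

def one_hot(ids, nclass):
    """Turn a vector of class IDs into a (batch, nclass) 0/1 matrix."""
    return np.eye(nclass)[np.asarray(ids)]

def pairwise_auc(y_true, y_score):
    """Binary ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos, neg = y_score[y_true == 1], y_score[y_true == 0]
    diff = pos[:, None] - neg[None, :]
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size

def micro_auc(Y, P):
    """Pool every (row, class) entry into a single binary problem."""
    return pairwise_auc(Y.ravel(), P.ravel())

def macro_auc(Y, P):
    """Score each column on its own, then take the mean (column-wise AUC)."""
    return float(np.mean([pairwise_auc(Y[:, j], P[:, j]) for j in range(Y.shape[1])]))

Y = one_hot([0, 2, 1, 0], 3)  # made-up true class IDs
P = np.array([[0.7, 0.2, 0.1],
              [0.3, 0.3, 0.4],
              [0.5, 0.4, 0.1],
              [0.4, 0.1, 0.5]])  # made-up predicted probabilities

print(micro_auc(Y, P))            # 0.8125
print(round(macro_auc(Y, P), 4))  # 0.8056
```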

(Esteban J Guillen) #6

Why do you need to call .cpu().numpy() on preds?

(Feras) #7

Maybe PyTorch has changed quite a bit since I last used it, but sklearn's metrics are plain Python functions that expect NumPy arrays as arguments. So you need to take the GPU variable, copy it to the CPU (.cpu()) and convert it to a NumPy array (.numpy()) before it's usable by Python code running on the CPU. Not sure whether PyTorch does that transparently these days.
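In current PyTorch, `.numpy()` only works on a CPU tensor that is not tracked by autograd, so the usual chain is `.detach().cpu().numpy()`. A small sketch, assuming a recent PyTorch install; `to_numpy` is a hypothetical helper name, and the tensor lives on the GPU only if one is available (otherwise `.cpu()` is a no-op).

```python
import numpy as np
import torch

def to_numpy(t: torch.Tensor) -> np.ndarray:
    """Detach from the autograd graph, copy to host memory, then convert to NumPy."""
    return t.detach().cpu().numpy()

device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.tensor([[2.0, 0.5], [0.1, 1.2]], device=device, requires_grad=True)
log_probs = torch.log_softmax(logits, dim=1)  # still on `device`, still tracked by autograd

# calling log_probs.numpy() directly would raise an error (requires grad, possibly on GPU)
probs = np.exp(to_numpy(log_probs))  # plain NumPy array, safe to hand to sklearn
print(probs.sum(axis=1))             # each row sums to 1
```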