Wanting to plot most confused classes

I have a model that I’ve trained on the PlantClef dataset (1k classes of plants). I’ve got a 13% error rate, which is pretty neat, but I would like to further investigate the images the model gets wrong.

There are a few classes that it mistakes quite often (from 3 to 7 times) and I would like to have a function/script that does the following:

  • take an optional argument of a class, or maybe a list of classes
  • plot the images of that class that were misclassified as the 2nd class
  • plot a few examples of the second class itself

or just show all the images from one square of the confusion matrix (just like the thread “How to see images from one square of confusion matrix”)

Why do I need to display the actual images? Because sometimes two plants look fairly different in one image and almost identical in another, and I would like to know whether my model is guessing wrong or the plants really are that similar.

I’ve tried playing around with ClassificationInterpretation, but I’m not familiar enough with the fastai code to write this myself at the moment.
I know I can get the number of mistakes from the confusion matrix and the class names from interp.data.y.classes.
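
For instance, a minimal sketch of listing the worst pairs directly, assuming a fastai v1 learner named learn:

interp = ClassificationInterpretation.from_learner(learn)
# (actual, predicted, count) tuples, sorted by count, for pairs confused at least 3 times
interp.most_confused(min_val=3)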

Has anyone written something similar? How can I get the IDs of the specific images that got confused?


Your best bet would probably be to modify plot_top_losses() to take in a class argument or two, map the names back to their c2i index values, and reverse-map from there. I can try to work something very basic up shortly, but let me know if that’s enough to start with. Look at the top_losses function as well.
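
For the name-to-index mapping, a minimal sketch that builds the dict by hand (the class name here is just a placeholder):

classes = interp.data.classes                  # index -> class name
c2i = {c: i for i, c in enumerate(classes)}    # class name -> index (c2i)
ix = c2i['Centaurea jacea L.']                 # placeholder class name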

Edit: Here is some starter code. I got it working for tabular data, so this should give you a strong footing as to what is needed for image data!

# Assumes `interp` is a ClassificationInterpretation on a tabular learner,
# `label` is the ground-truth class name to filter on, and `df` is an
# empty DataFrame to collect the rows.
tl_val, tl_idx = interp.top_losses(len(interp.losses))
classes = interp.data.classes

for i, idx in enumerate(tl_idx):
    da, cl = interp.data.dl(DatasetType.Valid).dataset[idx]
    ix = int(cl)                 # index of the actual class
    cl = str(cl)                 # name of the actual class
    da = str(da).split(';')      # a tabular row renders as 'col val; col val; ...'
    if cl == label:
        arr = [classes[interp.pred_class[idx]], classes[ix],
               f'{interp.losses[idx]:.2f}', f'{interp.preds[idx][ix]:.2f}']
        for x in range(len(da) - 1):
            _, value = da[x].rsplit(' ', 1)  # keep just the value of each column
            arr.append(value)
        df.loc[i] = arr

Here, we assume ‘label’ is some ground-truth class name. If we wanted to take this further, we could also have put in

if cl == label1:
    if classes[interp.pred_class[idx]] == label2:

in order to see the examples confused between two classes!
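
Spelled out as a single condition (a sketch, with label1 the actual class name and label2 the predicted one):

if cl == label1 and classes[interp.pred_class[idx]] == label2:
    ...  # collect or plot this confused example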

If you’d like me to try to do the same for images, let me know 🙂


You can get the predictions and then manipulate them according to your needs.

def get_something(predicted_class, actual_class):
    "Iterate over valid-set items labeled `actual_class` but predicted as `predicted_class` (both class indices)."
    preds, y = learn.get_preds(DatasetType.Valid)

    y_hat = preds.argmax(dim=-1).numpy()
    y = y.numpy()

    for i, (a, b) in enumerate(zip(y_hat, y)):
        if a == predicted_class and b == actual_class:
            # You can do your manipulations here.
            # The matching item in the valid set is:
            data.valid_ds.x[i]

I just like to work with numpy arrays. You can skip that step if you want.
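
As a usage sketch (the class names are placeholders; classes is a plain list of class names, so .index() maps a name to the integer the function compares against):

classes = interp.data.classes
get_something(classes.index('Centaurea decipiens Thuill.'),
              classes.index('Centaurea jacea L.'))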


@Blanche here is something that is almost complete. I’m quite stumped on it, but perhaps you can figure out what I missed. The interesting bit is that it returns what we want, but I get an ‘index 12 is out of bounds’ error. I don’t have time to look at it any more right now; I may revisit it tonight if you can’t get it. 99% of this code was taken from the plot_top_losses() function in the source code, so if you want to make a PR, the code structure is already there.

def plot_top_losses(interp, k:int, class_1:str, class_2:str, largest=True, figsize=(12,12),
                    heatmap:bool=None, heatmap_thresh:int=16, return_fig:bool=None)->Optional[plt.Figure]:
    "Show images in `top_losses` along with their prediction, actual, loss, and probability of actual class."
    tl_val,tl_idx = interp.top_losses(len(interp.losses))
    classes = interp.data.classes
    cols = math.ceil(math.sqrt(k))
    rows = math.ceil(k/cols)
    fig,axes = plt.subplots(rows, cols, figsize=figsize)
    fig.suptitle('prediction/actual/loss/probability', weight='bold', size=14)
    x = 0
    for i,idx in enumerate(tl_idx):
        im,cl = interp.data.dl(interp.ds_type).dataset[idx]
        ix = int(cl)
        cl = str(cl)
        if cl == class_1 and classes[interp.pred_class[idx]] == class_2:
            im.show(ax=axes.flat[x],
                    title=f'{classes[interp.pred_class[idx]]}/{classes[ix]} / '
                          f'{interp.losses[idx]:.2f} / {interp.preds[idx][ix]:.2f}')
            x += 1
    if ifnone(return_fig, defaults.return_fig): return fig

Thank you @muellerzr and @kushaj very much 😃
After a few adjustments it works; I’m only missing preds, because interp doesn’t expose those.

def plot_top_losses(interp, k: int, class_1: str, class_2: str, largest=True, figsize=(12, 12),
                    heatmap: bool = None, heatmap_thresh: int = 16,
                    return_fig: bool = None) -> Optional[plt.Figure]:
    "Show images in `top_losses` along with their prediction, actual, loss, and probability of actual class."
    tl_val, tl_idx = interp.top_losses(len(interp.losses))
    classes = interp.data.classes
    cols = math.ceil(math.sqrt(k))
    rows = math.ceil(k / cols)
    fig, axes = plt.subplots(rows, cols, figsize=figsize)
    fig.suptitle('prediction/actual/loss/probability', weight='bold', size=14)
    x = 0
    for i, idx in enumerate(tl_idx):
        im, cl = interp.data.dl(interp.ds_type).dataset[idx]
        ix = int(cl)
        pred = str(classes[interp.pred_class[idx]])
        if (str(cl) == class_1 and pred == class_2) or (str(cl) == class_2 and pred == class_1):
            if x >= k:
                break
            # TODO add preds from somewhere to the label
            im.show(ax=axes.flat[x],
                    title=f'{classes[interp.pred_class[idx]]}/{classes[ix]} / {interp.losses[idx]:.2f}')
            x += 1
    if ifnone(return_fig, defaults.return_fig): return fig

Great work! I see what you mean now. Let me take a look

Edit:

def plot_top_losses(interp, k: int, class_1: str, class_2: str, largest=True, figsize=(12, 12),
                    heatmap: bool = None, heatmap_thresh: int = 16,
                    return_fig: bool = None) -> Optional[plt.Figure]:
    "Show images in `top_losses` along with their prediction, actual, loss, and probability of actual class."
    tl_val, tl_idx = interp.top_losses(len(interp.losses))
    classes = interp.data.classes
    cols = math.ceil(math.sqrt(k))
    rows = math.ceil(k / cols)
    fig, axes = plt.subplots(rows, cols, figsize=figsize)
    fig.suptitle('prediction/actual/loss/probability', weight='bold', size=14)
    x = 0
    for i, idx in enumerate(tl_idx):
        im, cl = interp.data.dl(interp.ds_type).dataset[idx]
        ix = int(cl)
        pred = str(classes[interp.pred_class[idx]])
        if (str(cl) == class_1 and pred == class_2) or (str(cl) == class_2 and pred == class_1):
            if x >= k:
                break
            im.show(ax=axes.flat[x],
                    title=f'{classes[interp.pred_class[idx]]}/{classes[ix]} / '
                          f'{interp.losses[idx]:.2f} / {interp.preds[idx][ix]:.2f}')
            x += 1
    if ifnone(return_fig, defaults.return_fig): return fig

You were missing {interp.preds[idx][ix]:.2f}
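
As a usage sketch (the class names are placeholders; class_1 and class_2 are compared against the string form of the labels):

interp = ClassificationInterpretation.from_learner(learn)
plot_top_losses(interp, 9, 'Centaurea decipiens Thuill.', 'Centaurea jacea L.')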

Perhaps @sgugger can chime in on what it should be named, as I’m unsure if plot_confused would be too close. Or perhaps an addition to plot_top_losses where we have optional one-class and two-class arguments?


I’d also add plot_top_correct for a given class name, so we can see the difference between the incorrectly classified images and the actual class. For example, for ‘Centaurea decipiens Thuill.’ vs ‘Centaurea jacea L.’ I can see that the flower is a tad different.
I didn’t have preds because I had an outdated fastai lib 😛

I’ve butchered plot_top_correct, because there is a commented-out top_scores method in the source and making that work is non-trivial for me at the moment.

def plot_top_correct(interp, k: int, actual_class: str, largest=True, figsize=(12, 12),
                     heatmap: bool = None, heatmap_thresh: int = 16,
                     return_fig: bool = None) -> Optional[plt.Figure]:
    "Show up to `k` correctly classified images of `actual_class`. Assumes the global `learn`."
    classes = interp.data.classes
    preds, y = learn.get_preds(DatasetType.Valid)

    y_hat = preds.argmax(dim=-1).numpy()
    y = y.numpy()
    cols = math.ceil(math.sqrt(k))
    rows = math.ceil(k / cols)
    fig, axes = plt.subplots(rows, cols, figsize=figsize)
    fig.suptitle('prediction/loss/probability', weight='bold', size=14)
    x = 0
    for idx, (a, b) in enumerate(zip(y_hat, y)):
        if classes[a] == actual_class and classes[b] == actual_class:
            im, cl = interp.data.dl(interp.ds_type).dataset[idx]
            ix = int(cl)
            if x >= k:
                break
            im.show(ax=axes.flat[x],
                    title=f'{classes[interp.pred_class[idx]]} / {interp.losses[idx]:.2f} / '
                          f'{interp.preds[idx][ix]:.2f}')
            x += 1
    if ifnone(return_fig, defaults.return_fig): return fig
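
Usage would then look something like this (the class name is a placeholder):

plot_top_correct(interp, 9, 'Centaurea jacea L.')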

@Blanche, give me a day or two… I have an idea brewing that may help with this for the user… will report back soon. And hopefully a surprise for Google Colab users too…


@Blanche and it’s done! I’ve made a widget for Google Colab for now, and I’m working on getting it to work on regular platforms too! (and on the image cleaner)
