What if I want to get the score for a specific class?

When I do multiclass segmentation with DiceMulti / JaccardCoeffMulti, the metrics always return an average over all classes, but I want to see the score for a specific organ. How can I modify them? Much obliged to anyone who can answer.

I’m not very familiar with segmentation, so someone with more experience may know of a built-in way to get what you need, but here’s my hacky attempt, which I think gives the right output. You can see the full code in this Colab notebook.

In it, I redefine DiceMulti: all I do is change the last line from `return np.nanmean(binary_dice_scores)` to `return binary_dice_scores`, so it returns all of the per-class scores instead of just their mean. Not sure if that’s what you are looking for.
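To make the difference concrete, here is a small NumPy-only sketch of what the edited `value` returns compared with the stock one. The `inter`/`union` counts are made up for illustration, and `per_class_dice` is a hypothetical helper name, not part of fastai.

```python
import numpy as np

# Made-up accumulated counts, keyed by class index like DiceMulti's dicts:
# inter[c] = sum(p*t) and union[c] = sum(p+t), summed over all validation batches.
inter = {0: 75.0, 1: 5.0, 2: 0.0}
union = {0: 200.0, 1: 20.0, 2: 0.0}

def per_class_dice(inter, union):
    "One Dice score per class; NaN for a class absent from both prediction and target."
    return np.array([2. * inter[c] / union[c] if union[c] > 0 else np.nan for c in inter])

scores = per_class_dice(inter, union)  # what the edited DiceMulti returns
mean_score = np.nanmean(scores)        # what the stock DiceMulti returns
organ_score = scores[1]                # the score for a single class (here class 1)
```

`np.nanmean` skips the NaN entry, so a class that never appears does not drag the average down, but it also means the average can hide a class that is consistently missed.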


Thank you so much for your help. Although the link isn’t accessible, your description inspired me =).
I wrote a CustomDiceMulti that takes class_index as a parameter:

```python
from fastai.vision.all import *  # provides Metric, flatten_check, torch and np

class CustomDiceMulti(Metric):
    "Dice metric for multiclass segmentation: one class's score if class_index is set, else the macro average (Macro F1)"
    def __init__(self, axis=1, class_index=None):
        self.axis = axis
        self.class_index = class_index
    @property
    def name(self):
        # Give each instance its own column name, so several per-class metrics can be logged side by side
        return "dice_multi" if self.class_index is None else f"dice_class_{self.class_index}"
    def reset(self): self.inter,self.union = {},{}
    def accumulate(self, learn):
        pred,targ = flatten_check(learn.pred.argmax(dim=self.axis), learn.y)
        for c in range(learn.pred.shape[self.axis]):
            p = torch.where(pred == c, 1, 0)
            t = torch.where(targ == c, 1, 0)
            c_inter = (p*t).float().sum().item()  # |P ∩ T| for class c
            c_union = (p+t).float().sum().item()  # |P| + |T| for class c
            if c in self.inter:
                self.inter[c] += c_inter
                self.union[c] += c_union
            else:
                self.inter[c] = c_inter
                self.union[c] = c_union
    @property
    def value(self):
        # Get intermediate calculations
        inter, union = self.get_inter_union()

        if self.class_index is not None:
            # Calculate and return for the specified class
            return 2 * inter[self.class_index] / union[self.class_index] if union[self.class_index] > 0 else np.nan
        else:
            # Calculate mean across classes
            binary_dice_scores = np.array([])
            for c in inter:
                binary_dice_scores = np.append(binary_dice_scores, 2 * inter[c] / union[c] if union[c] > 0 else np.nan)
            return np.nanmean(binary_dice_scores)

    def get_inter_union(self):
        """Returns the intersection and union counts for each class."""
        return self.inter, self.union
```

It may help someone, but I’d suggest double-checking the math. And thank you again; it really helped me.
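On double-checking the math: in both snippets `c_union` is `(p+t).sum()`, i.e. |P| + |T| rather than the set union, which is exactly the denominator Dice needs: Dice = 2|P ∩ T| / (|P| + |T|). The same counts also give a per-class Jaccard, because the true union is |P| + |T| - |P ∩ T|. Below is a NumPy sketch that checks both formulas against a direct set computation on a toy label array. `dice_and_jaccard` is a hypothetical helper; I believe fastai’s `JaccardCoeffMulti` uses the same `inter / (union - inter)` expression, but check the source of your installed version.

```python
import numpy as np

def dice_and_jaccard(pred, targ, c):
    "Per-class Dice and Jaccard from flattened label arrays, using the thread's inter/union counts."
    p = (pred == c).astype(np.int64)
    t = (targ == c).astype(np.int64)
    inter = float((p * t).sum())  # |P ∩ T|
    union = float((p + t).sum())  # |P| + |T| (not the set union)
    dice = 2 * inter / union if union > 0 else np.nan
    # Set union = |P| + |T| - |P ∩ T|, so Jaccard = inter / (union - inter)
    jacc = inter / (union - inter) if union > 0 else np.nan
    return dice, jacc

# Toy flattened prediction and target masks with 3 classes
pred = np.array([0, 1, 1, 2, 2, 2, 0, 1])
targ = np.array([0, 1, 2, 2, 2, 0, 0, 1])

# Cross-check against the textbook set definitions
for c in range(3):
    P = set(np.flatnonzero(pred == c))
    T = set(np.flatnonzero(targ == c))
    d, j = dice_and_jaccard(pred, targ, c)
    assert np.isclose(d, 2 * len(P & T) / (len(P) + len(T)))
    assert np.isclose(j, len(P & T) / len(P | T))
```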


Oh interesting! I wasn’t sure what inter was, but now it makes sense.

Here’s the full code of the DiceMulti I edited, though your approach is better at getting the values directly:

```python
from fastai.vision.all import *  # provides Metric, flatten_check, torch and np

class DiceMulti(Metric):
    "Dice metric (Macro F1) for multiclass segmentation, edited to return one score per class"
    def __init__(self, axis=1): self.axis = axis
    def reset(self): self.inter,self.union = {},{}
    def accumulate(self, learn):
        pred,targ = flatten_check(learn.pred.argmax(dim=self.axis), learn.y)
        for c in range(learn.pred.shape[self.axis]):
            p = torch.where(pred == c, 1, 0)
            t = torch.where(targ == c, 1, 0)
            c_inter = (p*t).float().sum().item()  # |P ∩ T| for class c
            c_union = (p+t).float().sum().item()  # |P| + |T| for class c
            if c in self.inter:
                self.inter[c] += c_inter
                self.union[c] += c_union
            else:
                self.inter[c] = c_inter
                self.union[c] = c_union

    @property
    def value(self):
        binary_dice_scores = np.array([])
        for c in self.inter:
            binary_dice_scores = np.append(binary_dice_scores, 2.*self.inter[c]/self.union[c] if self.union[c] > 0 else np.nan)
        # return np.nanmean(binary_dice_scores)  # original: macro average over classes
        return binary_dice_scores  # edited: one score per class
```

This is the output during training (messy, but at least it shows the values):

And then .recorder.log gets the array binary_dice_scores for the last epoch:
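Since the question was about reading the score for a particular organ, one option is to pair the returned array with the class names used for the masks. A minimal sketch; `codes` and `name_scores` are hypothetical, and the names must be listed in the same order as the label values in your masks.

```python
import numpy as np

def name_scores(scores, codes):
    "Map each per-class Dice score to its class name; NaN means the class never appeared."
    if len(scores) != len(codes):
        raise ValueError(f"got {len(scores)} scores for {len(codes)} class names")
    return {name: float(s) for name, s in zip(codes, scores)}

# Hypothetical class names in mask-label order (index 0 = background)
codes = ["background", "liver", "kidney"]
scores = np.array([0.98, 0.91, np.nan])  # e.g. the array the edited DiceMulti returns
by_organ = name_scores(scores, codes)
liver_dice = by_organ["liver"]
```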

And here’s my Colab link again: Google Colab

I see, and thank you again for your help. :grin:
