When I do multiclass segmentation and use DiceMulti/JaccardCoeffMulti, they always return an average value, but I want to see the value for a specific organ. How can I modify them? Much obliged to anyone who answers.
I’m not very familiar with segmentation, so someone with more experience may know a built-in way to get what you need, but here’s my hacky attempt, which I think gives the right output. You can see the full code in this Colab notebook.
In it, I redefine `DiceMulti`. All I do is change the last line from `return np.nanmean(binary_dice_scores)` to `return binary_dice_scores`, so it returns the score for every class rather than just their mean. Not sure if that’s what you are looking for.
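To make the change concrete, here is a minimal NumPy-only sketch comparing what the edited last line returns with what the original one returns. It assumes per-class `inter` and `union` dictionaries like the ones `DiceMulti` accumulates; the helper name `dice_scores` and the counts are made up for illustration.

```python
import numpy as np

def dice_scores(inter, union):
    """Per-class Dice from accumulated counts, mirroring DiceMulti.value.

    inter[c] holds |P ∩ T| and union[c] holds |P| + |T| for class c.
    A class that never occurs in predictions or targets gets NaN.
    """
    return np.array([2.0 * inter[c] / union[c] if union[c] > 0 else np.nan
                     for c in inter])

# Toy counts for three classes (class 2 never occurs, so its score is NaN)
inter = {0: 90.0, 1: 30.0, 2: 0.0}
union = {0: 200.0, 1: 100.0, 2: 0.0}

per_class = dice_scores(inter, union)  # what the edited version returns
macro = np.nanmean(per_class)          # what the original version returns
print(per_class)
print(macro)
```

A class with no pixels in either the predictions or the targets gets `nan`, which is why the original uses `np.nanmean` rather than `np.mean` for the average.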
Thank you so much for your help. Although that link isn’t accessible, your description inspired me =).
I wrote a CustomDiceMulti that takes class_index as a parameter:
```python
from fastai.vision.all import *  # provides Metric, flatten_check, torch and np

class CustomDiceMulti(Metric):
    "Dice metric for multiclass segmentation: one class if `class_index` is set, else the average over classes (Macro F1)"
    def __init__(self, axis=1, class_index=None):
        self.axis = axis
        self.class_index = class_index  # None -> average over all classes

    def reset(self): self.inter,self.union = {},{}

    def accumulate(self, learn):
        pred,targ = flatten_check(learn.pred.argmax(dim=self.axis), learn.y)
        for c in range(learn.pred.shape[self.axis]):
            p = torch.where(pred == c, 1, 0)
            t = torch.where(targ == c, 1, 0)
            c_inter = (p*t).float().sum().item()  # |P ∩ T| for class c
            c_union = (p+t).float().sum().item()  # |P| + |T| for class c (not the set union)
            if c in self.inter:
                self.inter[c] += c_inter
                self.union[c] += c_union
            else:
                self.inter[c] = c_inter
                self.union[c] = c_union

    @property
    def value(self):
        # Get intermediate calculations
        inter, union = self.get_inter_union()
        if self.class_index is not None:
            # Calculate and return Dice for the specified class only
            return 2 * inter[self.class_index] / union[self.class_index] if union[self.class_index] > 0 else np.nan
        else:
            # Calculate mean across classes
            binary_dice_scores = np.array([])
            for c in inter:
                binary_dice_scores = np.append(binary_dice_scores, 2 * inter[c] / union[c] if union[c] > 0 else np.nan)
            return np.nanmean(binary_dice_scores)

    def get_inter_union(self):
        """Returns the intersection and union counts for each class."""
        return self.inter, self.union
```
It may help someone, but I’d suggest double-checking the math. And thank you again; it really helped me.
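On the math: the value stored as `union` is really |P| + |T| (the sum of the two binary masks), not the set union, so `2 * inter / union` is exactly the Dice coefficient 2|P∩T| / (|P| + |T|). Because the set union equals |P| + |T| − |P∩T|, the Jaccard index can be computed from the same counts as `inter / (union - inter)`. As far as I know, fastai’s `JaccardCoeffMulti` builds on `DiceMulti` and uses this formula, so the same per-class trick should apply to it, but please check the fastai source. The NumPy-only sketch below uses random toy masks (not real segmentation output). It re-implements the batch-wise accumulation and checks both formulas against a direct per-class computation. The function names are hypothetical.

```python
import numpy as np

def accumulate_counts(batches, n_classes):
    """Sum per-class |P ∩ T| and |P| + |T| over batches, like accumulate()."""
    inter = {c: 0.0 for c in range(n_classes)}
    union = {c: 0.0 for c in range(n_classes)}
    for pred, targ in batches:  # pred, targ: integer class maps of equal shape
        for c in range(n_classes):
            p = (pred == c).astype(float)
            t = (targ == c).astype(float)
            inter[c] += (p * t).sum()
            union[c] += (p + t).sum()
    return inter, union

def dice_direct(pred, targ, c):
    """Dice for class c on all pixels at once: 2|P ∩ T| / (|P| + |T|)."""
    p, t = pred == c, targ == c
    denom = p.sum() + t.sum()
    return 2.0 * (p & t).sum() / denom if denom > 0 else np.nan

def jaccard_direct(pred, targ, c):
    """Jaccard for class c on all pixels at once: |P ∩ T| / |P ∪ T|."""
    p, t = pred == c, targ == c
    denom = (p | t).sum()
    return (p & t).sum() / denom if denom > 0 else np.nan

rng = np.random.default_rng(0)
n_classes = 3
# Toy data: 4 batches of random 8x8 class maps
batches = [(rng.integers(0, n_classes, (8, 8)), rng.integers(0, n_classes, (8, 8)))
           for _ in range(4)]
inter, union = accumulate_counts(batches, n_classes)

all_pred = np.concatenate([p.ravel() for p, _ in batches])
all_targ = np.concatenate([t.ravel() for _, t in batches])
for c in range(n_classes):
    dice = 2 * inter[c] / union[c]              # what DiceMulti computes
    jaccard = inter[c] / (union[c] - inter[c])  # Jaccard from the same counts
    print(c, round(float(dice), 4), round(float(jaccard), 4))
    # Accumulating over batches gives the same result as one pass over all pixels
    assert np.isclose(dice, dice_direct(all_pred, all_targ, c))
    assert np.isclose(jaccard, jaccard_direct(all_pred, all_targ, c))
```

Accumulating the raw counts across batches and dividing only at the end, as `accumulate` and `value` do, matches a single pass over the whole validation set. Averaging per-batch Dice scores would generally not give the same result.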
Oh, interesting! I wasn’t sure what `inter` was, but now it makes sense.
Here’s the full code of the `DiceMulti` I edited, though your approach is better for getting the values directly:
```python
from fastai.vision.all import *  # provides Metric, flatten_check, torch and np

class DiceMulti(Metric):
    "Averaged Dice metric (Macro F1) for multiclass target in segmentation"
    def __init__(self, axis=1): self.axis = axis
    def reset(self): self.inter,self.union = {},{}
    def accumulate(self, learn):
        pred,targ = flatten_check(learn.pred.argmax(dim=self.axis), learn.y)
        for c in range(learn.pred.shape[self.axis]):
            p = torch.where(pred == c, 1, 0)
            t = torch.where(targ == c, 1, 0)
            c_inter = (p*t).float().sum().item()
            c_union = (p+t).float().sum().item()
            if c in self.inter:
                self.inter[c] += c_inter
                self.union[c] += c_union
            else:
                self.inter[c] = c_inter
                self.union[c] = c_union

    @property
    def value(self):
        binary_dice_scores = np.array([])
        for c in self.inter:
            binary_dice_scores = np.append(binary_dice_scores, 2.*self.inter[c]/self.union[c] if self.union[c] > 0 else np.nan)
        # return np.nanmean(binary_dice_scores)  # original: mean over classes
        return binary_dice_scores  # edited: one Dice score per class
```
This is the output during training (messy, but it at least shows the values):
And then `.recorder.log` holds the `binary_dice_scores` array for the last epoch:
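To read that per-class array more easily, you can pair each score with its class name. Here is a minimal sketch. It assumes a hypothetical `codes` list in the same order as the class indices in your masks, and the scores are toy values, not results from the notebook.

```python
import numpy as np

def label_scores(scores, codes):
    """Pair each per-class Dice score with its class name."""
    if len(scores) != len(codes):
        raise ValueError("need exactly one class name per score")
    return {name: float(s) for name, s in zip(codes, scores)}

# Hypothetical class names and toy scores, for illustration only
codes = ["background", "liver", "kidney"]
scores = np.array([0.98, 0.85, np.nan])

for name, s in label_scores(scores, codes).items():
    print(f"{name:>12}: {s:.3f}")
```

A `nan` entry means that class had no pixels in either the predictions or the targets over the whole validation set.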
And here’s my Colab link again: Google Colab
I see, and thank you again for your help.