Unbalanced classes in image segmentation


(nkiruka chuka-obah) #1

I am doing background segmentation where background pixels vastly outnumber foreground pixels, and sometimes the foreground is not present at all. Is there a principled method for dealing with this in my model? Right now I am trying to use weights in the BCEWithLogitsLoss criterion, but I don't see a significant change.

Any ideas?
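One thing worth checking with `BCEWithLogitsLoss`: its `weight` argument is a manual rescaling of the loss for each batch element, whereas `pos_weight` multiplies only the loss terms where the target is 1 (foreground), which is the argument aimed at positive/negative imbalance. A minimal sketch, assuming single-channel logits and binary masks of shape (N, 1, H, W); the random tensors and the choice pos_weight = background/foreground pixel ratio are illustrative assumptions, not a recommendation from this thread:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical batch: raw logits and binary masks, shape (N, 1, H, W)
logits = torch.randn(4, 1, 64, 64)
masks = (torch.rand(4, 1, 64, 64) > 0.95).float()  # roughly 5% foreground

# Ratio of background to foreground pixels; clamp guards against
# batches where the foreground is absent (num_pos == 0).
num_pos = masks.sum()
num_neg = masks.numel() - num_pos
pos_weight = (num_neg / num_pos.clamp(min=1.0)).reshape(1)

# pos_weight multiplies only the terms where the target is 1
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = criterion(logits, masks)
print(loss.item())
```

Computing the ratio once over the whole training set, rather than per batch as in this sketch, avoids the weight jumping around between batches with and without foreground.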


(Brookie Guzder-Williams) #2

I'm also doing image segmentation with unbalanced classes, but I have been working in Keras, so I haven't used the PyTorch BCEWithLogitsLoss.

Keras's categorical crossentropy loss doesn't take a weights parameter, so I wrote my own (simply by copying the Keras source code for categorical crossentropy and adding a weights parameter).

That said, I did see improvements in my results. If BCEWithLogitsLoss isn't behaving as you expect, maybe you could try writing your own PyTorch version to make sure it's doing what you want it to.

If it helps, here's what I did in Keras:

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K


def weighted_categorical_crossentropy(weights):
    """ weighted_categorical_crossentropy

        Args:
            * weights<ktensor|nparray|list>: crossentropy weights
        Returns:
            * weighted categorical crossentropy function
    """
    if isinstance(weights, (list, np.ndarray)):
        weights = K.variable(weights)

    def loss(target, output, from_logits=False):
        if not from_logits:
            # rescale so the class scores sum to 1 along the last (channel) axis
            output /= tf.reduce_sum(output, axis=-1, keepdims=True)
            # clip to keep log() finite
            _epsilon = tf.convert_to_tensor(K.epsilon(), dtype=output.dtype.base_dtype)
            output = tf.clip_by_value(output, _epsilon, 1. - _epsilon)
            # per-class weights broadcast along the channel axis
            weighted_losses = target * tf.math.log(output) * weights
            return -tf.reduce_sum(weighted_losses, axis=-1)
        else:
            raise ValueError('WeightedCategoricalCrossentropy: not valid with logits')
    return loss

I didn’t implement the logits part since it wasn’t immediately clear how to do it and I didn’t need it.
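For the logits case, one option is to replace the normalise, clip and log steps with a log-softmax over the class axis, which is numerically stable and needs no epsilon clipping. A minimal sketch in TensorFlow, assuming channels-last tensors and one-hot targets; `weighted_cce_from_logits` is a hypothetical name, not a Keras API:

```python
import numpy as np
import tensorflow as tf


def weighted_cce_from_logits(target, logits, weights):
    """Per-pixel weighted categorical crossentropy computed from raw logits.

    target:  one-hot labels, shape (..., C), channels last
    logits:  raw network outputs, same shape as target
    weights: per-class weights, shape (C,)
    """
    weights = tf.cast(tf.convert_to_tensor(weights), logits.dtype)
    # log_softmax = logits - logsumexp(logits), computed stably
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    # weights broadcast along the last (class) axis
    return -tf.reduce_sum(target * log_probs * weights, axis=-1)


# Toy example: a 1x2x2 image with 3 classes
logits = tf.constant(np.random.RandomState(0).randn(1, 2, 2, 3), dtype=tf.float32)
target = tf.one_hot([[[0, 1], [2, 0]]], depth=3)
per_pixel = weighted_cce_from_logits(target, logits, [0.5, 2.0, 1.0])
```

Used as a Keras loss, this would pair with a final layer that has no softmax activation, since the softmax is folded into the loss.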

The other piece of the puzzle is, of course, picking the correct weights. I didn't think too deeply about this, but I found a Stack Overflow answer that took the log of the ratio (total_for_all_categories/total_for_category) because the categories were imbalanced by many orders of magnitude. I tried both, with and without the log, and got better results with the log version.
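The two weighting schemes compared above can be sketched from a label mask as follows. This assumes integer class labels; `class_weights` is a hypothetical helper, and clamping the log version to a minimum of 1 mirrors what is done later in this thread rather than anything required by the Stack Overflow answer:

```python
import numpy as np


def class_weights(labels, num_classes, use_log=True):
    """Per-class weights from the ratio total_pixels / pixels_in_class.

    labels:  integer array of class ids (any shape)
    use_log: if True, take the natural log of the ratio and clamp the
             result to a minimum of 1, damping very large ratios.
    """
    counts = np.bincount(labels.ravel(), minlength=num_classes).astype(np.float64)
    total = counts.sum()
    # Guard against classes that never occur (count 0)
    ratio = total / np.maximum(counts, 1.0)
    if use_log:
        return np.maximum(np.log(ratio), 1.0)
    return ratio


# Toy mask: 990 background pixels, 9 of class 1, 1 of class 2
labels = np.array([0] * 990 + [1] * 9 + [2] * 1)
print(class_weights(labels, 3, use_log=False))
print(class_weights(labels, 3, use_log=True))
```

Without the log the rarest class here is weighted roughly 1000 times the background; with the log the spread shrinks to under 7 times, which is the damping effect described above.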

Hopefully that rambling is a little helpful.


(nkiruka chuka-obah) #3

Thanks @brookisme. I am porting an architecture from a Keras model to PyTorch, and it uses weights obtained from the fraction of pixels in each class. I'll write the custom loss and see if that helps.


(RobG) #4

How are you constructing the weights passed to the PyTorch loss function? After much forum reading I settled on sample size / (num classes * class frequency), but didn't have much luck, at least with CrossEntropyLoss.
Also, if I understand correctly, the object to segment is small relative to the image, in which case you could run a bounding-box network first.
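The heuristic above matches the "balanced" formula that scikit-learn's `compute_class_weight` uses: n_samples / (n_classes * count_per_class). A minimal sketch computing it from segmentation masks and passing it to `nn.CrossEntropyLoss`, assuming raw logits of shape (N, C, H, W) and integer targets of shape (N, H, W); the toy tensors are placeholders:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes = 3

# Hypothetical data: logits (N, C, H, W) and integer class masks (N, H, W)
logits = torch.randn(2, num_classes, 8, 8)
targets = torch.randint(0, num_classes, (2, 8, 8))

# "balanced" weights: n_samples / (n_classes * count_per_class)
counts = torch.bincount(targets.flatten(), minlength=num_classes).float()
weights = targets.numel() / (num_classes * counts.clamp(min=1.0))

# CrossEntropyLoss accepts (N, C, H, W) inputs with (N, H, W) targets directly
criterion = nn.CrossEntropyLoss(weight=weights)
loss = criterion(logits, targets)
```

With `weight` set and the default mean reduction, PyTorch divides by the summed weights of the target pixels rather than by the pixel count, so the loss value stays on a similar scale whatever weights are chosen.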


(nkiruka chuka-obah) #5

Hi, I’m doing it by fraction. That is: total_of_class/total


(Brookie Guzder-Williams) #6

Like @nchukaobah, I'm starting with the fraction total_of_class/total, but my classes are extremely unbalanced, so I invert it, take the log, and require the minimum weight to be 1:

import math

w=total/total_of_class  # inverse of the class fraction
w=math.log(C*w)         # C: a scaling constant
w=max(w,1)

I got the idea from this Stack Overflow answer.


(Brookie Guzder-Williams) #7

In case you're still interested, I've ported my Keras code above to PyTorch.

It was mostly straightforward. The only real tricks were:

  1. Because PyTorch is bands-first (channels-first), as opposed to Keras being bands-last, I had to transpose the unweighted losses to do the multiplication.
  2. Keras automatically performs a mean reduction, so although my Keras code above returns a tensor, it still works as a loss function. PyTorch does not, so I had to call mean explicitly.
  3. Since the mean of a transposed tensor is the same, I skipped transposing back to a bands-first batch.

Here’s what it looks like:

import numpy as np
import torch
import torch.nn as nn

EPS = 1e-9  # small constant to avoid division by zero and log(0)

# Usage:
# criterion = WeightedCategoricalCrossentropy(weights, device=DEVICE)


class WeightedCategoricalCrossentropy(nn.Module):
    """ weighted_categorical_crossentropy

        mean reduction of weighted categorical crossentropy

        Args:
            * weights<tensor|nparray|list>: category weights
            * device<str|None>: device-name. if exists, send weights to specified device
    """
    def __init__(self, weights, device=None):
        super(WeightedCategoricalCrossentropy, self).__init__()
        self.weights = to_tensor(weights).float()
        if device:
            self.weights = self.weights.to(device)

    def forward(self, inpt, targ):
        return weighted_categorical_crossentropy(inpt, targ, self.weights)


def weighted_categorical_crossentropy(inpt, targ, weights):
    """ weighted_categorical_crossentropy

    Args:
        * inpt <tensor>: network prediction, shape (N, C, H, W); non-negative
          class scores (e.g. softmax output), not raw logits
        * targ <tensor>: one-hot network target, same shape as inpt
        * weights<tensor|nparray|list>: category weights, length C
    Returns:
        * mean reduction of weighted categorical crossentropy
    """
    weights = to_tensor(weights).float().to(inpt.device)
    # normalise so the C class scores at each pixel sum to 1
    inpt = inpt / (inpt.sum(1, True) + EPS)
    # keep values away from exactly 0 and 1 before taking the log
    inpt = torch.clamp(inpt, EPS, 1. - EPS)
    losses = (targ * torch.log(inpt)).float()
    # move the channel axis last so the (C,) weights broadcast against it
    weighted_losses_transpose = weights * losses.transpose(1, -1)
    # the mean is unaffected by the transpose, so no need to transpose back
    return -weighted_losses_transpose.mean()


def to_tensor(v):
    """Convert a list or NumPy array to a torch tensor; pass tensors through."""
    if isinstance(v, list):
        v = torch.tensor(v)
    elif isinstance(v, np.ndarray):
        v = torch.from_numpy(v)
    return v
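The port above expects non-negative class scores (for example softmax outputs), because it normalises by the per-pixel sum. If the network outputs raw logits, one option is to use `F.log_softmax` over the channel axis instead, and to reshape the weights to (1, C, 1, 1) so they broadcast against the channels-first tensor without a transpose. A minimal sketch under those assumptions; `weighted_cce_with_logits` is a hypothetical name, and the targets are assumed one-hot with shape (N, C, H, W):

```python
import torch
import torch.nn.functional as F


def weighted_cce_with_logits(logits, targ, weights):
    """Mean-reduced weighted categorical crossentropy from raw logits.

    logits:  raw network outputs, shape (N, C, H, W)
    targ:    one-hot targets, same shape as logits
    weights: per-class weights, shape (C,)
    """
    weights = torch.as_tensor(weights, dtype=logits.dtype, device=logits.device)
    # stable log-probabilities over the channel axis; no EPS clamping needed
    log_probs = F.log_softmax(logits, dim=1)
    # (1, C, 1, 1) broadcasts over batch and spatial dims, so no transpose
    w = weights.view(1, -1, 1, 1)
    # mean over every element, matching the .mean() in the port above
    return -(w * targ * log_probs).mean()


torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4)
labels = torch.randint(0, 3, (2, 4, 4))
targ = F.one_hot(labels, num_classes=3).permute(0, 3, 1, 2).float()
loss = weighted_cce_with_logits(logits, targ, [0.5, 2.0, 1.0])
```

Because this mean runs over all N·C·H·W entries, including the zero terms from non-target channels, the value is C times smaller than a per-pixel mean. That rescales the gradient uniformly but does not change the relative effect of the class weights.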

(Joseph Bullock) #8

Hi,

What are you defining as EPS here?


(Brookie Guzder-Williams) #9

Apologies for not including it before.

EPS is just a small number (short for “epsilon”); I used 1e-9. The idea is to avoid division by zero in this line:

inpt=inpt/(inpt.sum(1,True)+EPS)

and, in this line, to avoid values of exactly 0 or 1 before taking the log:

inpt=torch.clamp(inpt, EPS, 1. - EPS)


(Joseph Bullock) #10

Great, thank you. I also found that using

learn.loss_func = CrossEntropyFlat(axis=1, weight=weight)

where weight is a tensor of per-class weights works for segmentation tasks.
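Outside fastai, a rough plain-PyTorch counterpart is `nn.CrossEntropyLoss(weight=weight)`, which accepts (N, C, H, W) logits with (N, H, W) integer targets directly. A "flat" loss instead moves the class axis last and flattens to (N·H·W, C) logits and (N·H·W,) targets. The sketch below, with placeholder tensors and hypothetical weights, checks that both routes give the same value:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
weight = torch.tensor([0.2, 1.0, 5.0])  # hypothetical per-class weights
criterion = nn.CrossEntropyLoss(weight=weight)

logits = torch.randn(2, 3, 8, 8)           # (N, C, H, W)
targets = torch.randint(0, 3, (2, 8, 8))   # (N, H, W)

# Direct 4-D call: the class axis is dim 1
loss_4d = criterion(logits, targets)

# Flattened call: move classes last, then flatten to (N*H*W, C)
flat_logits = logits.permute(0, 2, 3, 1).reshape(-1, 3)
flat_targets = targets.reshape(-1)
loss_flat = criterion(flat_logits, flat_targets)

print(loss_4d.item(), loss_flat.item())
```

The permute before flattening matters: reshaping (N, C, H, W) straight to (-1, C) would mix pixels and classes, which is why a flattened loss for segmentation needs to know that the class axis is dim 1.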