U-Net: converting RGB mask values in fast.ai

(Antonin Sumner) #1

I’m trying to apply lesson 3’s U-Net segmentation to the ICG Semantic Drone Dataset.
https://www.tugraz.at/index.php?id=22387
The dataset comes with RGB masks, and each label is defined by an RGB value, for example:

unlabeled, 0, 0, 0
paved-area, 128, 64, 128
dirt, 130, 76, 0
grass, 0, 102, 0

The issue is that the mask’s data has only one value per pixel:

tensor([[[ 0, 0, 0, …, 0, 0, 0],
[ 0, 0, 0, …, 0, 0, 0],
[ 0, 0, 0, …, 0, 0, 0],
…,
[90, 90, 90, …, 0, 0, 0],
[90, 90, 90, …, 0, 0, 0],
[90, 90, 90, …, 0, 0, 0]]])

As I understand it, the mask is converted to greyscale values.
The problem is that I can no longer tell which greyscale value corresponds to which RGB label.
How do I map the greyscale values in mask.data back to the RGB values given for the labels?
What is the math behind the RGB-to-greyscale conversion?
What weight is applied to each channel to get a single greyscale value?
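
For reference: fast.ai v1’s open_mask() opens masks with convert_mode='L' by default, which hands the conversion to Pillow. Pillow documents its 'L' conversion as the ITU-R 601-2 luma transform, L = R * 299/1000 + G * 587/1000 + B * 114/1000. A minimal pure-Python sketch of that formula, applied to the labels above (the round-half-up integer rounding is an assumption; Pillow’s own arithmetic may differ by 1 for some colors depending on its version):

```python
# ITU-R 601-2 luma weights, as documented for Pillow's Image.convert("L")
def rgb_to_grey(r: int, g: int, b: int) -> int:
    """Collapse an RGB triple into a single 0-255 greyscale value.

    Rounds half up using integer arithmetic (an assumption: Pillow's
    exact rounding can differ by 1 for some colors across versions).
    """
    return (r * 299 + g * 587 + b * 114 + 500) // 1000

labels = {
    "unlabeled": (0, 0, 0),
    "paved-area": (128, 64, 128),
    "dirt": (130, 76, 0),
    "grass": (0, 102, 0),
}

for name, rgb in labels.items():
    print(name, rgb, "->", rgb_to_grey(*rgb))
```

With these weights paved-area (128, 64, 128) comes out as 90, which matches the 90s in the mask tensor above.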

(Antonin Sumner) #2

To find out how fast.ai converts R, G, B values to a single value, I created a PNG image in which each pixel is one of the mask colors I’ll be dealing with.
There are 24 different colors, with RGB values taken from the dataset’s documentation:
[image: the 24 label colors and their RGB values]

Here is the resulting image in Photoshop:
[image: the resulting PNG opened in Photoshop]

Then I loaded it with fast.ai as a mask, and here’s what I got:

This finally gives me the color references, but I also realized that many colors are merged, for example 164, 103, 45.
So if I train the U-Net with those values, some classes in the images will be merged.
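
The merging is expected: three channels are squeezed into one value, so different colors can land on the same greyscale level. A small sketch that groups label colors by greyscale value to spot collisions before training. It reuses the ITU-R 601-2 luma formula (Pillow’s documented 'L' conversion, rounding assumed), and the colors in the example are hypothetical, chosen only to collide, not taken from the dataset:

```python
from collections import defaultdict

def rgb_to_grey(r: int, g: int, b: int) -> int:
    # ITU-R 601-2 luma (Pillow's documented "L" conversion), rounded half up
    return (r * 299 + g * 587 + b * 114 + 500) // 1000

def find_grey_collisions(labels: dict) -> dict:
    """Return {grey_value: [label names]} for grey values shared by 2+ labels."""
    groups = defaultdict(list)
    for name, rgb in labels.items():
        groups[rgb_to_grey(*rgb)].append(name)
    return {grey: names for grey, names in groups.items() if len(names) > 1}

# Hypothetical colors; replace with the dataset's 24 label colors.
example = {
    "class-a": (10, 0, 0),
    "class-b": (0, 5, 0),
    "paved-area": (128, 64, 128),
}
print(find_grey_collisions(example))  # {3: ['class-a', 'class-b']}
```

Any non-empty result means a greyscale mask cannot distinguish those classes, so the conversion has to work on the RGB values directly.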

Is there a way to control how open_mask() converts colors?

(Dave Luo) #3

You can customize open_mask() by writing your own conversion from RGB values to integer class indices.

Something like this could work (this is my first attempt, so the code could very likely be simplified):

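from fastai.vision import *  # fastai v1: brings in torch, tensor, ImageSegment, open_image, PathOrStr, Callable

# create list of RGB values in order of the idx value to replace them with, i.e. 0: [0,0,0], 1: [255,0,0]
rgb_list = [
    [0,0,0],
    [255,0,0],
    [255,255,0],
    [0,0,255],
]

def convert_mask(old_mask, rgb_list):
    "Map each RGB color of a 3xHxW mask to its index in `rgb_list`; colors not in the list become 0."
    h, w = old_mask.shape[-2], old_mask.shape[-1]
    new_mask = torch.zeros((h, w), dtype=torch.long)
    # one row per pixel, one column per channel: shape (H*W, 3)
    pixels = old_mask.data.reshape(3, -1).permute(1, 0)
    for idx, rgb in enumerate(rgb_list):
        # boolean mask of the pixels whose three channels all equal `rgb`
        rgb_mask = (pixels == pixels.new_tensor(rgb)).all(dim=1)
        # fill in those pixels with the new idx value
        new_mask.masked_fill_(rgb_mask.view(h, w), idx)
    return ImageSegment(new_mask.unsqueeze(0))

def open_mask_converted(fn:PathOrStr, div=False, convert_mode='RGB', after_open:Callable=None, rgb_list=rgb_list)->ImageSegment:
    "Return `ImageSegment` object created from the RGB mask in file `fn`, with colors mapped to class indices."
    # keep convert_mode='RGB' (with 'L' the channels are already merged into one greyscale value)
    # and div=False (so pixel values stay in 0-255 and match rgb_list exactly)
    return convert_mask(open_image(fn, div=div, convert_mode=convert_mode, cls=ImageSegment, after_open=after_open), rgb_list)

mask = open_mask_converted(get_y_fn(img_f), convert_mode='RGB', div=False)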

which gives this (in my case, with 3 different RGB values plus a [0,0,0] background):
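
A vectorized alternative, swapped in here rather than taken from the post above: pack each pixel’s R, G and B into one integer (R*65536 + G*256 + B). Unlike greyscale, this packing is one-to-one, so no two colors can merge, and a sorted palette plus np.searchsorted maps every pixel without a separate pass over the image per color. A NumPy sketch, assuming the mask is an (H, W, 3) uint8 array (for example np.array(PIL.Image.open(path).convert('RGB'))); the function name is hypothetical:

```python
import numpy as np

def rgb_mask_to_index(mask: np.ndarray, rgb_list) -> np.ndarray:
    """Map an (H, W, 3) RGB mask to an (H, W) array of class indices.

    Colors not in `rgb_list` map to 0, like the loop-based convert_mask.
    """
    m = mask.astype(np.int64)
    # one-to-one packing of (R, G, B) into a single integer key per pixel
    keys = (m[..., 0] << 16) | (m[..., 1] << 8) | m[..., 2]
    palette = np.array([(r << 16) | (g << 8) | b for r, g, b in rgb_list], dtype=np.int64)
    # sort the palette once so each pixel key can be found by binary search
    order = np.argsort(palette, kind="stable")
    sorted_pal = palette[order]
    pos = np.clip(np.searchsorted(sorted_pal, keys), 0, len(sorted_pal) - 1)
    found = sorted_pal[pos] == keys
    # order[pos] converts a sorted position back to the index in rgb_list
    return np.where(found, order[pos], 0)
```

Sorting the palette once keeps the cost at one O(log K) lookup per pixel for K label colors.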

(Antonin Sumner) #4

Hey @daveluo,

Thanks for the help!

I tried to use your functions, but it seems the masks are too big.
Here’s what I get from the convert_mask function:

(Dave Luo) #5

You’re right that your input masks are too big (4000x6000), but that isn’t what causes the error you’re running into. It looks like a tensor shape/size mismatch: some tensor that flattens to 8M elements is being matched against another one with 24M elements (4000x6000).
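
One possible explanation for those exact numbers, assuming the mask reached convert_mask as a single-channel (greyscale) image, for example because it was opened with convert_mode='L' rather than 'RGB': a 4000x6000 single-channel mask holds 24M values, and the reshape to (3, -1) inside convert_mask splits them into 3 rows of 8M. The per-pixel mask then has 8M entries and cannot be reshaped back to 4000x6000. A small-scale NumPy reproduction with a hypothetical 4x6 mask:

```python
import numpy as np

H, W = 4, 6  # hypothetical small stand-in for 4000 x 6000

grey_mask = np.zeros((1, H, W))  # single channel: H*W = 24 values
rgb_mask = np.zeros((3, H, W))   # three channels: 3*H*W = 72 values

def per_pixel_count(mask: np.ndarray) -> int:
    # mirrors the (3, -1) reshape + permute in convert_mask: one row per "pixel"
    return mask.reshape(3, -1).T.shape[0]

print(per_pixel_count(rgb_mask))   # 24 == H*W, reshapes back to (H, W)
print(per_pixel_count(grey_mask))  # 8 == H*W/3, the "8M vs 24M" pattern

try:
    np.zeros(per_pixel_count(grey_mask)).reshape(H, W)
except ValueError as err:
    print("reshape failed:", err)
```

If printing shapes in the debugger shows a first dimension of 1 instead of 3 for the loaded mask, this is the likely cause.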

I suggest finding the problem by adding pdb.set_trace() at the beginning of the function. Then step through each line as it executes and print the shape of each input and output with p YOUR_TENSOR.shape in the debugger to locate the mismatch; that will tell you more about how to fix it. The cause may be an input of the wrong size or shape, or a tensor reshape that went wrong along the way.

As for the large masks, you should also make them much smaller by preprocessing your input images and masks into small square chips/tiles, something like 250x250.
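
A minimal NumPy sketch of that tiling step, assuming an (H, W) or (H, W, C) array and simply dropping any border that doesn’t fill a whole tile (padding would be the other option); the function name is hypothetical. Calling it on an image and its mask (same height and width) yields matching tiles at the same indices. A 4000x6000 array cut into 250x250 tiles gives 16 x 24 = 384 tiles with nothing dropped:

```python
import numpy as np

def tile_image(arr: np.ndarray, size: int = 250) -> np.ndarray:
    """Split an (H, W, ...) array into non-overlapping size x size tiles.

    Rows/columns that don't fill a whole tile are dropped.
    Returns an array of shape (n_tiles, size, size, ...), in row-major tile order.
    """
    h, w = arr.shape[0] // size, arr.shape[1] // size
    cropped = arr[: h * size, : w * size]
    rest = cropped.shape[2:]
    # split each spatial axis into (tile index, offset within tile)
    tiles = cropped.reshape(h, size, w, size, *rest)
    # bring the two tile-grid axes together: (h, w, size, size, ...)
    tiles = tiles.swapaxes(1, 2)
    return tiles.reshape(h * w, size, size, *rest)
```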