FastAi GrandparentSplitter File Path

I am trying to train a segmentation algorithm with FastAi. I have training and validation data in separate folders, so I was planning on using GrandparentSplitter(), but for some reason the validation set is empty.

My files are organized as below:

Path ---> train ---> images
                ---> masks
     ---> valid ---> images
                ---> masks
codes = np.array(['background', 'alpha', 'beta', 'jamma'])

def label_func(x): return path/'train/masks'/f'{x.stem}_mask.png'

db = DataBlock(blocks=(ImageBlock(), MaskBlock(codes)),
              splitter=GrandparentSplitter(train_name='train', valid_name='valid'),
              get_items=get_image_files,
              get_y=label_func)

dls = db.dataloaders(path/'train/images', bs=1)
dls.show_batch()

The masks directory contains 3 times as many files as the images directory.

Mask file name example: img-1-0001_s_0_alpha_mask.png
Image file name example: img-1-0001_s_0_i_0_ax.png

Each image can contain all 3 categories.
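To make the mismatch concrete, here is a tiny sketch (just using the example file names above) of what the original label_func builds versus what actually exists on disk:

from pathlib import Path

img = Path('img-1-0001_s_0_i_0_ax.png')
# label_func looks for '<image stem>_mask.png':
print(f'{img.stem}_mask.png')   # -> 'img-1-0001_s_0_i_0_ax_mask.png'
# but the masks are named per class, e.g. 'img-1-0001_s_0_alpha_mask.png',
# so the constructed path never points at an existing file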

Update

I wrote this get_mask function to be passed to get_y:

def get_mask(x):
    """
    :param x: filename of scan from FastAi
    :return: path of corresponding mask
    """
    x_mask = ''
    x = str(x)

    if 'train' in x:
        p_id = x.split('/')[-1].split('_idx_')[0]
        main_dir = x.split('/images/')[0]
        for fn in os.listdir(os.path.join(main_dir, 'masks')):

            if p_id in fn:
                x_mask = os.path.join(os.path.join(main_dir, 'masks'), fn)
                print(x_mask)

    return Path(x_mask)

Running dls.train_ds[0] returns the following:

/mnt/Datasets/pngs_datasets/train/masks/img-01-0001_slice_11_alpha_masks.png
/mnt/Datasets/pngs_datasets/train/masks/img-01-0001_slice_11_beta_masks.png
/mnt/Datasets/pngs_datasets/train/masks/img-01-0001_slice_11_jamma_masks.png
Out[1]: (PILImage mode=RGB size=256x256, PILMask mode=L size=256x256)

but if I run

dls.show_batch(max_n=4, vmin=1, vmax=30, figsize=(14,10))

I get an error:

<ipython-input-4-93800379df65> in get_mask(x)
     10         p_id = x.split('/')[-1].split('_idx_')[0]
     11         main_dir = x.split('/images/')[0]
---> 12         for fn in os.listdir(os.path.join(main_dir, 'masks')):
     13 
     14             if p_id in fn:

NotADirectoryError: [Errno 20] Not a directory: '/mnt/Datasets/pngs_datasets/train/masks/img-0109_slice_14_alpha_masks.png/masks'
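
Mechanically, the traceback shows that get_mask is being handed a mask path rather than an image path; a quick sketch of what happens in that case:

import os

x = '/mnt/Datasets/pngs_datasets/train/masks/img-0109_slice_14_alpha_masks.png'
main_dir = x.split('/images/')[0]              # no '/images/' in a mask path, so this is still the full file path
os.listdir(os.path.join(main_dir, 'masks'))    # tries to list '<file>.png/masks' -> NotADirectoryError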

Hi Fabio,

I think that, the way you are loading the data into the dataloaders, you may only be loading the train data, because of this: path/'train/images'
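
A minimal sketch of what I mean (assuming path is the top-level folder that contains both train and valid): whatever you pass to dataloaders is the source that get_items receives, so if you only pass train/images, GrandparentSplitter never sees any file whose grandparent folder is 'valid':

# only train images are collected here, so the 'valid' split stays empty
dls = db.dataloaders(path/'train/images', bs=1)

# passing the parent folder lets get_image_files recurse into both train/ and valid/
# (note: it would then also pick up the files under masks/, which would need filtering out)
dls = db.dataloaders(path, bs=1)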

In the examples that I have seen in the fastai courses they use another system to load the data:

  1. On the one hand, there are only two folders, images and masks, with all the images together.
  2. And on the other hand, they load a text file with the names of the images that go into the validation set: valid.txt (a rough sketch of this follows the list).
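
A rough sketch of that second approach (note this is not your current folder layout; FileSplitter and the valid.txt name are just the pattern used in the course notebooks):

db = DataBlock(blocks=(ImageBlock(), MaskBlock(codes)),
               get_items=get_image_files,
               splitter=FileSplitter(path/'valid.txt'),   # names listed in valid.txt go to the validation set
               get_y=label_func)
dls = db.dataloaders(path/'images', bs=1)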

I do not know if you have seen this course, but maybe you will find it interesting:

I hope you find it useful :slight_smile:
Saioa

Thank you Saioa,

I saw the link, but I don't get why I have to pass a txt file when I am using the GrandparentSplitter.

Hi Fabio,

If you run this command, what do you get?

for fn in os.listdir(os.path.join(main_dir, 'masks')): print(fn)

Regards, Conwyn

I get something similar to:

/mnt/Datasets/pngs_datasets/train/masks/img-01-0001_slice_0_alpha_masks.png

In any case, there is a fundamental error in the generated dataset: the number of mask images is 3 times the number of images. I corrected that.
My question now is: do I have to expand the get_mask function to include the validation set?

Did you find a fix for this issue? I am facing the same issue; I asked about it a while back but got no response from anyone, and I couldn't find this topic when troubleshooting on Google.

In my case it was an error in the number of files, meaning I had 3 times as many files in the val set, so

def label_func(x): return path/'train/masks'/f'{x.stem}_mask.png'

was not working.

Actually it didn't; I get an error: IndexError: list index out of range

Still trying to get it to work

def label_func(x):
    if 'train/images' in str(x):
        x = str(x).replace('/images/', '/masks/').replace('.png', '_masks.png')
    if 'val/images' in str(x):
        x = str(x).replace('/images/', '/masks/').replace('.png', '_masks.png')
    return Path(x)

db = DataBlock(blocks=(ImageBlock(), MaskBlock(codes)),
               batch_tfms=[
                   *aug_transforms(),  # from 0.75
                   Normalize.from_stats(*imagenet_stats)],
               splitter=GrandparentSplitter(train_name='train', valid_name='val'),
               item_tfms=[Resize(im_size)],
               get_items=partial(get_image_files, folders=['train', 'val']),
               get_y=label_func
               )

ds = db.datasets(source=data_dir)

with output:

/mnt/Datasets/img_sets/SV_masks/train/images/0085_slice_0_idx_43_SV.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0085_slice_0_idx_43_SV_masks.png

dls = db.dataloaders(data_dir, bs = bs)

dls.show_batch()

with output:

/mnt/Datasets/img_sets/SV_masks/train/masks/0313_slice_24_idx_186_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0313_slice_24_idx_186_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/images/0055_slice_5_idx_26_SV.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0055_slice_5_idx_26_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0284_slice_10_idx_162_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0284_slice_10_idx_162_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0059_slice_10_idx_30_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0059_slice_10_idx_30_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0308_slice_2_idx_181_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0308_slice_2_idx_181_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/images/0264_slice_14_idx_148_SV.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0264_slice_14_idx_148_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0026_slice_18_idx_13_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0026_slice_18_idx_13_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/images/0026_slice_15_idx_13_SV.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0026_slice_15_idx_13_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0055_slice_12_idx_26_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0055_slice_12_idx_26_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/images/0284_slice_10_idx_162_SV.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0284_slice_10_idx_162_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0323_slice_11_idx_194_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0323_slice_11_idx_194_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/images/0130_slice_6_idx_72_SV.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0130_slice_6_idx_72_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/images/0129_slice_5_idx_71_SV.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0129_slice_5_idx_71_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0310_slice_10_idx_183_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0310_slice_10_idx_183_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0311_slice_8_idx_184_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0311_slice_8_idx_184_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0086_slice_19_idx_44_SV_masks.png
/mnt/Datasets/img_sets/SV_masks/train/masks/0086_slice_19_idx_44_SV_masks.png

which is of course wrong: each image should be paired with its corresponding mask, but here mask files are also being picked up as items.

Well, after much researching and trying, I finally got it to work. I must say that

get_image_files(path, recurse=True, folders=None)

either does not work or (very plausibly) I do not understand its usage.
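
For what it is worth, my reading of get_image_files (worth double-checking against the fastai source) is that folders only filters the first directory level under the path you pass in, so a call like

files = get_image_files(data_dir, folders=['train', 'val'])

keeps train/ and val/ at the top level but then recurses into everything below them, including the masks/ folders, which would explain the mask files showing up as items in the output above. Pointing it directly at the image subfolders, for example

files = get_image_files(data_dir/'train'/'images') + get_image_files(data_dir/'val'/'images')

is another way to collect only the images, similar in spirit to the get_my_images helper below.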

My solution; I hope it will help others.

get_my_images simply creates a list of my images:

def get_my_images(data_dir=data_dir, folders=['train/images', 'val/images']):
    res = []
    for d in folders:
        for f in os.listdir(data_dir/f'{d}'):
            res.append(Path(os.path.join(data_dir, d, f)))
    return L(res)

assigns labels (masks):

def label_func(x):
    if '/images/' in str(x):
        x = str(x).replace('/images/', '/masks/').replace('.png', '_masks.png')
    return Path(x)

The datablock does not need a splitter:

db = DataBlock(blocks=(ImageBlock(), MaskBlock(codes)),
               get_items=partial(get_my_images),
               get_y=label_func,
               item_tfms=[Resize(im_size)],
               batch_tfms=[
                   *aug_transforms(),
                   Normalize.from_stats(*imagenet_stats)],
               )

ds = db.datasets(source=data_dir)
dls = db.dataloaders(data_dir, bs = bs)
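
A couple of quick sanity checks I would run on the result (just a sketch, reusing the names from the post above):

x, y = ds.train[0]
print(type(x), type(y))                        # expect (PILImage, PILMask)
print(len(dls.train_ds), len(dls.valid_ds))    # with no splitter, fastai falls back to a random split (20% validation by default, if I remember correctly)
dls.show_batch(max_n=4, figsize=(14, 10))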

Look at my answer:

https://forums.fast.ai/t/fastai-grandparentsplitter-file-path/91477/9?u=fabio.geraci