ImageDataBunch from 500 megapixel images as tiles

:stuck_out_tongue: @sgugger
This is all I got to but I’m sure @neuronq can take it from here :slight_smile:

class SegmentationTileItemList(SegmentationItemList):
    def __init__(self, segments_per_image, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.segments_per_image = segments_per_image

    def get_image_segment(self, full_image, segment_idx):
        pass # implement this function

    def get(self, i):
        segment_idx = index % self.segments_per_image
        image_idx = i // self.segments_per_image

        fn = super().get(image_idx)
        full_image = self.open(fn)

        res = self.get_image_segment(full_image, segment_idx)
        self.sizes[i] = res.size
        return res

class SegmentationTileLabelList(SegmentationLabelList):
    def __init__(self, segments_per_label, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.segments_per_label = segments_per_label

    def get_label_segment(self, full_label, segment_idx):
        pass # implement this function

    def get(self, i):
        segment_idx = index % self.segments_per_label
        label_idx = i // self.segments_per_label

        fn = super().get(label_idx)
        full_label = self.open(fn)

        res = self.get_label_segment(full_label, segment_idx)
        self.sizes[i] = res.size
        return res

Nice job! You should just add _label_cls=SegmentationTileLabelList as a class variable of your first class, so that it knows to label with this automatically.

Thanks a looot to all! I’ll refactor (current code ended up with the “write tiles to disk” solution that took 15min to code :stuck_out_tongue:…), and post sometime ~Monday-ish the solution that I picked…

(spoiler alert: it’ll likely be @Hadus 's SegmentedDataset solution…)

Chiming in: for breaking an image down into patches, the tensor.unfold function makes it really easy to turn a tensor into tiles.

You could load the large image as a tensor in RAM, break it into patches, then feed batches of patches to the GPU. If you have your segmentation ground truth as a large array, you can break it into patches the same way.

(link discusses 3D images but it works fine for 2D as well)
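To make the unfold approach concrete, here is a minimal sketch (sizes are illustrative: a single-channel 500x600 tensor split into 100x100 tiles):

```python
import torch

# Toy single-channel "image"; unfold dim 0 (height) then dim 1 (width),
# with window size == step so the tiles don't overlap.
img = torch.arange(500 * 600, dtype=torch.float32).reshape(500, 600)
th, tw = 100, 100  # tile height / width

tiles = img.unfold(0, th, th).unfold(1, tw, tw)  # shape (5, 6, 100, 100)
tiles = tiles.contiguous().view(-1, th, tw)      # shape (30, 100, 100)
```

From here you can feed `tiles` to the GPU in batches; the same two `unfold` calls work on a mask tensor for the ground truth.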


Are you sure about this get function? First, I don’t see where index is coming from; index = i, I assume?

When using from_folder or so, it will load all the filenames it can find (let’s say 100). When this thing is sampled in the end, that means the i in get will always be between 0 and 100, not 100 * n_segments.

That is the problem that I was trying to refer to. Since the whole WhateverList hierarchy in fastai ultimately extends ItemList, it will be the filenames (.items) that determine how big this index is going to be.
And while you can easily fix that in the __len__ in a pytorch dataset, it won’t be that easy for this deep class hierarchy.
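To show what the easy __len__ fix looks like in a plain dataset, a minimal sketch of the SegmentedDataset idea (class and names illustrative, not actual fastai code):

```python
class SegmentedDataset:
    """Sketch: map a flat index onto (filename, segment) pairs."""

    def __init__(self, filenames, segments_per_image):
        self.filenames = filenames
        self.segments_per_image = segments_per_image

    def __len__(self):
        # The easy fix: report one item per (image, segment) pair,
        # not one item per file.
        return len(self.filenames) * self.segments_per_image

    def __getitem__(self, i):
        image_idx = i // self.segments_per_image
        segment_idx = i % self.segments_per_image
        return self.filenames[image_idx], segment_idx
```

E.g. with 2 files and 3 segments per image, `len(ds)` is 6 and `ds[4]` maps to the second file, segment 1. In fastai the length is driven by `.items` instead, which is why it is harder to change.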

You can of course do data = FancyList.from_wherever(), then modify items, and then continue with the data block pipeline. However, I would not assume (without carefully checking the data block API code) that this change has no undesired behavior compared to having the correct items array set in the __init__. The same goes for the LabelList that is created internally. Maybe modifying items is the only thing left to do; I can’t tell for sure right now.

Maybe there’s also something that I don’t see right now (should definitely go to bed :D). Anyways, hope this helps :slight_smile:

Good point!
I totally missed the fact that now we have a lot more data…

But anyhow @neuronq said he is going to most likely use the SegmentedDataset solution instead.

self.get is “how to create item i from self.items”.

The easiest solution would be to generate a random segment_idx.
The only change is:

    # In SegmentationTileItemList:
    def get(self, i):
        segment_idx = np.random.randint(0, self.segments_per_image)
        image_idx = i
        ...

    # In SegmentationTileLabelList:
    def get(self, i):
        segment_idx = np.random.randint(0, self.segments_per_label)
        label_idx = i
        ...

This would work pretty well. It would act as some kind of data augmentation.
(This kind of behaviour could probably be achieved more easily with the regular data augmentation machinery…)
One epoch would go through all the full images, but only one segment of each.




If we want to do it properly then I think we do have to do more complicated stuff with self.items.

self.items[i] (which is a filename)

So in self.items there are all the file names we use; my idea is to just make segments_per_image number of copies of self.items and store it back in self.items:

That way we have the right amount of data. When we go through an epoch it will visit the same filename multiple times, so we also need something to map each occurrence of a filename to a specific segment.

class SegmentationTileItemList(SegmentationItemList):
    def __init__(self, segments_per_image, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.segments_per_image = segments_per_image

        self.segment_idxs = np.repeat(np.arange(self.segments_per_image), len(self.items))
        # [0, 0, 0, 1, 1, 1, 2, 2, 2]
        self.items = np.repeat([self.items], self.segments_per_image, axis=0).flatten()
        # ["img0", "img1", "img2", "img0", "img1", "img2", "img0", "img1", "img2"]

    def get_image_segment(self, full_image, segment_idx):
        pass # implement this function

    def get(self, i):
        segment_idx = self.segment_idxs[i]

        fn = super().get(i)
        full_image = self.open(fn)

        res = self.get_image_segment(full_image, segment_idx)

        self.sizes[i] = res.size
        return res

It is pretty much the same for the label one…
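As a standalone check of that index bookkeeping, a small numpy sketch (toy filenames) showing the two repeat arrays line up so that every (filename, segment) pair occurs exactly once:

```python
import numpy as np

items = np.array(["img0", "img1", "img2"])
segments_per_image = 3

# Segment index for each position in the expanded list:
segment_idxs = np.repeat(np.arange(segments_per_image), len(items))
# [0, 0, 0, 1, 1, 1, 2, 2, 2]

# Filenames tiled to match:
tiled_items = np.repeat([items], segments_per_image, axis=0).flatten()
# ["img0", "img1", "img2", "img0", "img1", "img2", "img0", "img1", "img2"]

# Each (filename, segment) pair appears exactly once:
pairs = set(zip(tiled_items.tolist(), segment_idxs.tolist()))
assert len(pairs) == len(items) * segments_per_image
```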

When picking segments at random, one should also ensure that the same random segment is picked for the label. It might be enough to set a random seed, though.

One whole other aspect of this that my head can’t stop thinking about is: how tiny would you want to tile the image? More tiles means more individual samples that you can augment etc., which means a lot of images.
On the other hand, more tiles means you put in more ‘crappy’ information through convolution padding and data augmentation padding where you would often have actual data… (ok, for the augmentation part you can get around this if you wanted)
I wonder where the sweet spot for this is. If you find out @neuronq, I would be really interested to hear about it :slight_smile:

This is the step that prevented me completing a solution to a similar need a few months ago. I’d be interested in any update if someone gets it working well.

@digitalspecialists Here is one way to do it:

import contextlib
import numpy as np

@contextlib.contextmanager
def temp_seed(seed):
    """Temporarily seed numpy's global RNG, restoring the previous state afterwards."""
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)

and then in the two get methods:

    # In SegmentationTileItemList:
    def get(self, i):
        with temp_seed(i):
            segment_idx = np.random.randint(0, self.segments_per_image)
        ...

    # In SegmentationTileLabelList:
    def get(self, i):
        with temp_seed(i):
            segment_idx = np.random.randint(0, self.segments_per_label)
        ...

This would still mean that every epoch we would get the same segments, because the seed repeats every epoch. To fix this we can use callbacks, I think.
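A quick self-contained check of why seeding by i keeps item and label in sync, and also why it repeats across epochs (temp_seed is restated here so the snippet runs on its own):

```python
import contextlib
import numpy as np

@contextlib.contextmanager
def temp_seed(seed):
    # Temporarily seed numpy's global RNG, restoring the previous state after.
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)

segments_per_image = 10

def pick_segment(i):
    with temp_seed(i):
        return np.random.randint(0, segments_per_image)

# The item list and the label list draw independently, but for the same i
# they always agree...
assert pick_segment(3) == pick_segment(3)
# ...which also means epoch 2 repeats epoch 1's choice for item 3.
```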


@digitalspecialists
To make item and label have the same random segment index use callbacks.

The code below doesn’t work! I am not sure how to write this callback… it should be something like:

class SegmentIdxGenCallback(Callback):
    def on_batch_begin(self, **kwargs):
        random_idx = np.random.randint(0, self.data.items.segments_per_image)
        self.data.labels.segment_idx = random_idx
        self.data.items.segment_idx = random_idx

fit(1, learner, cb=CallbackHandler([SegmentIdxGenCallback()]))

then in SegmentationTileItemList change the beginning of get:

    def get(self, i):
        fn = super().get(i)
        segment_idx = self.segment_idx  # set by the callback
        ...

Hi, have you solved this yet?

@sgugger I have been messing around taking a stab at trying to fit this into the fastai framework. I’m building off of @Hadus’s solution above. Do you think the best way is to use a random seed for matching up the SegmentationTileLabelList and SegmentationTileItemList? Or if we return all segments by adding an axis and then handle downstream further? What about for the test dataset where we just want to scan across and then recombine and show the full stitched together image?

I’m fairly close on having my pipeline working by just preslicing the images and saving to disk so you load in that way initially. Then for test doing the slicing and then stitch back together outside of the fastai framework, but I think it would be really slick and help a lot of people if we try to get it integrated within.

I used this solution. I created a custom open() method that takes a tuple of file path and tile index and returns a fastai Image of the selected tile. Then I customized also from_folder() method.

ImageTile = namedtuple('ImageTile', 'path idx rows cols')


def calc_n_tiles(size, tile_max_size):
    """Compute a tile grid so each tile is at most tile_max_size pixels per side."""
    x, y = size
    n_cols = x // tile_max_size + 1
    n_rows = y // tile_max_size + 1
    return n_rows, n_cols, (x // n_cols, y // n_rows)


def get_labels_tiles(fn):
    path, *tile = fn
    path = path_lbl / path.name  # path_lbl: the mask folder, defined elsewhere
    return ImageTile(path, *tile)


def get_tiles(images: PathOrStr, rows: int, cols: int) -> Collection[ImageTile]:
    images_tiles = []
    for img in images:
        for i in range(rows * cols):
            images_tiles.append(ImageTile(img, i, rows, cols))
    return images_tiles


def open_image_tile(img_t: ImageTile, mask=False, **kwargs) -> Image:
    """Given an ImageTile, return an Image cropped to that tile;
    set mask=True for masks."""
    path, idx, rows, cols = img_t
    img = open_image(path, **kwargs) if not mask else open_mask(path, **kwargs)
    row = idx // cols
    col = idx % cols
    tile_x = img.size[0] // cols
    tile_y = img.size[1] // rows
    return Image(img.data[:, col * tile_x:(col + 1) * tile_x, row * tile_y:(row + 1) * tile_y])


class SegmentationTileLabelList(SegmentationLabelList):

    def open(self, fn: ImageTile):
        return open_image_tile(fn, div=True, mask=True)


class SegmentationTileItemList(ImageList):
    _label_cls, _square_show_res = SegmentationTileLabelList, False

    def open(self, fn: ImageTile) -> Image:
        return open_image_tile(fn, convert_mode=self.convert_mode, after_open=self.after_open)

    @classmethod
    def from_folder(cls, path: PathOrStr = '.', rows=1, cols=1, extensions: Collection[str] = None, **kwargs) -> ItemList:
        """patchs the from_folder method, generating list of ImageTile with all the possible tiles for all the images in folder"""
        files = get_files(path, extensions, recurse=True)
        files_tiled = get_tiles(files, rows, cols)
        return SegmentationTileItemList(files_tiled, **kwargs)
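The index arithmetic in open_image_tile can be checked in isolation; a tiny sketch (toy grid, plain Python) showing how a flat tile index maps onto (row, col) cells:

```python
# How open_image_tile maps a flat tile index onto the grid (toy numbers):
rows, cols = 2, 3
grid = [(idx // cols, idx % cols) for idx in range(rows * cols)]

# Every (row, col) cell of the 2x3 grid is hit exactly once:
assert sorted(grid) == [(r, c) for r in range(rows) for c in range(cols)]
```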

Hope this is not a silly approach and it can be useful.


Hi! I am working in a similar project.

I would like to use very big images (histological images, Whole Slide Images, WSI) and use the DataBunch to learn from them. Currently I am using an approach similar to the one @massaros is using, but with a dictionary instead of a tuple, and with openslide-python, a library for managing this kind of image.

So, the code I would like to implement comes from this link, where they implement this in PyTorch. But I am kind of stuck on getting the labels right. Any ideas on how I could implement this in fastai?

Thanks!


@Joan in my code I am overloading the open function for SegmentationLabelList to return the correct mask (and set the custom LabelList class)

I’m considering trying this with fastai2, but for bounding box labels instead of segmentation, and I would appreciate any updates as to what worked or didn’t with the solutions that were suggested here.

If I understand it right, it looks to me like @massaros has a workable solution, except it is perhaps missing a way to randomize tile selection. Did anyone get a solution going for picking tiles and labels at random from an image?

Thank you so much for laying this out. This is the best approach I have seen. I have a quick question though. I was able to get the original images loaded in, but in my labeling step of the data loader I am having trouble seeing how to pass in the appropriate indices. For example I have the code below:

get_y_fn = lambda x: mask_path + f'/{x[0].stem}.png'

data = SegmentationTileItemList.from_folder(img_path, rows=8, cols=20) \
        .split_by_rand_pct(valid_pct=0.2, seed=5) \
        .label_from_func(get_y_fn, classes={'background':0, 'other':1})

and it looks like it is loading the tiles perfectly, but it errors in open_image_tile on the line path, idx, rows, cols = img_t: it is trying to unpack <path> instead of [<path> 0 8 20]

How were you able to label appropriately while injecting in this Tiling class?

Hi, I have made some refactoring to the code to improve it. You can find the latest version here https://github.com/mone27/fruit-detection/blob/master/semantic_segmentation_tile.py (some code needs improvements to be more general). I added the code to also change the background on the tile which is probably adding some complexity not needed for you.

@mark-hoffmann for me labelling works. I would suggest you try the latest version of my code (hoping it is clear enough), then tell me if you still have issues; I would be more than happy to try to help you.
(Or you can send the full code you are using, because I did several revisions of the code and I don’t remember what it did.)
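That unpacking error usually means the label function returns a bare path instead of an ImageTile. A hedged sketch of a label function that rebuilds the tile info (mirroring get_labels_tiles above; mask_path is a hypothetical mask directory):

```python
from collections import namedtuple
from pathlib import Path

ImageTile = namedtuple('ImageTile', 'path idx rows cols')
mask_path = Path('masks')  # hypothetical: wherever your mask PNGs live

def get_y_fn(x: ImageTile) -> ImageTile:
    # Return an ImageTile, not a bare path, so open_image_tile can
    # unpack path, idx, rows, cols from it.
    path, *tile = x
    return ImageTile(mask_path / f'{path.stem}.png', *tile)
```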

Let me know if you have any other question, doubt or suggestion.


Hi, the tiles are randomized by the dataloader, so you won’t get all the tiles from an image one after another. If you want only some random tiles from each image, then in the from_folder method, instead of taking all the possible tiles, take a subset of them by customizing the get_tiles function:

from itertools import product

def get_tiles(images, rows: int, cols: int, tile_info: Collection) -> Collection[ImageTile]:
    images_tiles = []
    for img in images:
        for row, col in product(range(rows), range(cols)):
            images_tiles.append(ImageTile(img, (row, col), *tile_info))
    return images_tiles
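One way to take such a random subset per image, as a self-contained sketch (get_random_tiles and its parameters are hypothetical names, using random.sample over the (row, col) grid):

```python
import random
from itertools import product
from collections import namedtuple

ImageTile = namedtuple('ImageTile', 'path idx rows cols')

def get_random_tiles(images, rows, cols, n_tiles, seed=None):
    """Sketch: sample n_tiles random (row, col) positions per image
    instead of enumerating the full rows x cols grid."""
    rng = random.Random(seed)
    tiles = []
    for img in images:
        for row, col in rng.sample(list(product(range(rows), range(cols))), n_tiles):
            tiles.append(ImageTile(img, (row, col), rows, cols))
    return tiles
```

E.g. get_random_tiles(files, 8, 20, 10) would keep 10 of the 160 possible tiles per image, with no duplicates within one image.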

Ahh it was a very silly mistake on my end where I just had to reformat the tuple. Just had to sleep on it and I knew how to fix it instantly. Thank you so much though for the quick response!

I’ll check out the new version of your code as well!