Equivalent of open_image in fastai2

In fastai1 it was possible to set vision.data.open_image to a new function in order to customize the deserialization of images.
I can’t find how to do that in fastai2. Any ideas?

To clarify: I am aware that one way to do this is through the mid-level API. However, I was wondering whether there is a higher-level approach at the DataBlock level, equivalent to the former open_image.

load_image is the new version.
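For reference, fastai2's load_image(fn, mode=None) opens the file with Pillow, loads the pixel data and converts the result to mode when one is given. Below is a minimal stand-alone sketch of a replacement with that same signature, assuming only Pillow is installed; my_load_image is a hypothetical name, not part of fastai.

```python
import os
import tempfile

from PIL import Image


def my_load_image(fn, mode=None):
    "Hypothetical replacement for load_image: open `fn` with Pillow and optionally convert to `mode`."
    im = Image.open(fn)
    im.load()  # read the pixel data now instead of lazily on first access
    return im.convert(mode) if mode else im


# Usage: write a small RGB test image, then reload it as greyscale ("L")
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "test.png")
Image.new("RGB", (4, 3), color=(255, 0, 0)).save(path)
img = my_load_image(path, mode="L")
print(img.mode, img.size)  # L (4, 3)
```

Any function with this signature can do its own decoding (OpenCV, a medical-imaging reader, etc.) as long as it returns a PIL image.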


Thank you very much, Zachary. I had seen that function in vision.core, but its description is very confusing.

@muellerzr In fastai1 we used to override the get function to open images in a custom way.
Which method of ImageDataLoaders should we override in fastai2?

I researched this topic, and the best way seems to be to pass ImageBlock, through its cls parameter, a class that implements the appropriate methods, in particular the "create" class method. You can use PILBase as a template, or even inherit from it and override "create".
You can find the source code for PILBase at
https://github.com/fastai/fastai/blob/master/fastai/vision/core.py#L98

and an example of passing a class to ImageBlock in the MNIST section of the datablock tutorial:
https://docs.fast.ai/tutorial.datablock.html
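The reason this works is that ImageBlock only needs a class whose create classmethod turns each item (usually a file path) into an image; in fastai2, ImageBlock(cls) uses cls.create as its type transform. Below is a deliberately simplified, library-free Python model of that pattern. BaseImage, CustomImage and toy_image_block are hypothetical stand-ins, not fastai code.

```python
class BaseImage:
    "Hypothetical stand-in for PILBase: wraps whatever the loader produced."

    def __init__(self, data, source):
        self.data, self.source = data, source

    @classmethod
    def create(cls, fn):
        # Default deserialisation; `cls` makes subclasses get instances of themselves
        return cls(f"default-decoded:{fn}", source=str(fn))


class CustomImage(BaseImage):
    "Overrides only `create`, just as one would subclass PILBase/PILImage."

    @classmethod
    def create(cls, fn):
        return cls(f"custom-decoded:{fn}", source=str(fn))


def toy_image_block(cls=BaseImage):
    "Hypothetical stand-in for ImageBlock: returns the per-item type transform."
    return cls.create


items = ["a.png", "b.png"]
tfm = toy_image_block(CustomImage)
imgs = [tfm(fn) for fn in items]
print(type(imgs[0]).__name__, imgs[0].data)  # CustomImage custom-decoded:a.png
```

The block never needs to know how decoding happens; swapping the class is enough to change it.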

This is how I did it, reading the images with OpenCV:

import cv2
from pathlib import Path
from PIL import Image
from fastai.vision.all import *

def load_image_cv2(fn, mode=None):
    "Read an image with OpenCV and return it as a PIL `Image`, converted to `mode` if given"
    im = cv2.imread(str(fn))  # older OpenCV builds reject pathlib.Path, so pass a str
    if im is None: raise FileNotFoundError(f"OpenCV could not read {fn}")
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # OpenCV returns BGR; PIL and fastai expect RGB
    im = Image.fromarray(im)
    return im.convert(mode) if mode else im

class PILImageCV2(PILImage):
    "Same as fastai's `PILImage`, but file paths are decoded with OpenCV"
    @classmethod
    def create(cls, fn, **kwargs):
        if isinstance(fn, (Path, str)):
            return cls(load_image_cv2(fn, **merge(cls._open_args, kwargs)))
        # tensors, arrays and bytes go through fastai's default logic
        return super().create(fn, **kwargs)

# splitter, val_ids, train_ids, get_msk, b_tfms and TRAIN are defined elsewhere (not shown)
hubmap_seg = DataBlock(blocks=(ImageBlock(PILImageCV2), MaskBlock()),
                       get_items=get_image_files,
                       splitter=splitter(val_ids,train_ids),
                       get_y=get_msk,
                       batch_tfms=b_tfms)
dsets = hubmap_seg.datasets(Path(TRAIN))
dsets.train[0][0]  # displays the first training image
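One thing to watch when decoding with OpenCV: cv2.imread returns pixels in BGR channel order, whereas PIL and fastai treat three-channel arrays as RGB, so without a conversion the red and blue channels end up swapped. cv2.cvtColor(im, cv2.COLOR_BGR2RGB) handles this; here is an equivalent NumPy-only sketch, assuming NumPy and Pillow are installed (bgr_to_pil is a hypothetical helper name).

```python
import numpy as np
from PIL import Image


def bgr_to_pil(arr, mode=None):
    "Convert an HxWx3 uint8 BGR array (as returned by cv2.imread) to a PIL image."
    if arr.ndim != 3 or arr.shape[2] != 3:
        raise ValueError(f"expected an HxWx3 array, got shape {arr.shape}")
    rgb = np.ascontiguousarray(arr[..., ::-1])  # reverse the channel axis: BGR -> RGB
    im = Image.fromarray(rgb)
    return im.convert(mode) if mode else im


# A 1x1 pure-blue pixel is stored as (255, 0, 0) in OpenCV's BGR order
bgr = np.zeros((1, 1, 3), dtype=np.uint8)
bgr[0, 0] = (255, 0, 0)
img = bgr_to_pil(bgr)
print(img.getpixel((0, 0)))  # (0, 0, 255)
```

Reversing the last axis is all COLOR_BGR2RGB does for three-channel input, so either approach gives the same pixels.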