In fastai1 it was possible to set vision.data.open_image to a new function in order to customize the deserialization of images.
I can’t find how to do that in fastai2. Any ideas?
To clarify: I know that one way to do this is through the mid-level API. However, I was wondering whether there is a higher-level approach at the DataBlock level, equivalent to the former open_image.
load_image is the new version.
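For reference, fastai2's load_image opens the file with PIL and, if mode is given, converts the image to that mode. A custom replacement only needs the same load_image(fn, mode=None) shape and must return a PIL image. The sketch below is a hypothetical example: load_image_upright is an illustrative name, not a fastai function, and the EXIF-orientation step is just one possible customization. It assumes Pillow 6.0 or newer for ImageOps.exif_transpose.

```python
from PIL import Image, ImageOps

def load_image_upright(fn, mode=None):
    "Open `fn` with PIL, undo any EXIF rotation, and optionally convert to `mode`."
    with Image.open(fn) as im:
        # exif_transpose returns a new, fully loaded image (rotated/flipped only if
        # the EXIF orientation tag asks for it), so it stays valid after the file closes
        im = ImageOps.exif_transpose(im)
    return im.convert(mode) if mode else im
```

On its own this does not change what a DataBlock does. It still has to be called from the create method of the class handed to ImageBlock.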
Thank you very much, Zachary. I had seen that function in vision core, but its description is very confusing.
@muellerzr In fastai1 we used to override the get function to open images in a custom way.
Which method of ImageDataLoaders should we override in fastai2?
I researched this topic, and the best approach seems to be passing a class to ImageBlock through its cls parameter. That class must implement the appropriate methods, in particular the "create" class method. You can use PILBase as a template, or simply inherit from it and override "create".
You can find the source code for PILBase at
https://github.com/fastai/fastai/blob/master/fastai/vision/core.py#L98
and an example of passing a class to ImageBlock in the MNIST section of the datablock tutorial:
https://docs.fast.ai/tutorial.datablock.html
This is how I did it, reading the images with OpenCV:
import io
import cv2
import numpy as np
from PIL import Image
from fastai.vision.all import *

def load_image(fn, mode=None):
    "Read an image with OpenCV and return it as a PIL image, optionally converted to `mode`."
    if isinstance(fn, io.BytesIO):
        # cv2.imread only accepts file paths, so decode in-memory bytes with cv2.imdecode
        im = cv2.imdecode(np.frombuffer(fn.getvalue(), np.uint8), cv2.IMREAD_COLOR)
    else:
        im = cv2.imread(str(fn))  # older cv2 versions reject pathlib.Path, so pass a str
    if im is None: raise OSError(f"cv2 could not read {fn}")
    # OpenCV loads channels as BGR; PIL expects RGB
    im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
    return im.convert(mode) if mode else im

class PILImage(PILBase):
    @classmethod
    def create(cls, fn:(Path,str,Tensor,ndarray,bytes), **kwargs):
        "Open an `Image` from path `fn`"
        if isinstance(fn,TensorImage): fn = fn.permute(1,2,0).type(torch.uint8)
        if isinstance(fn, TensorMask): fn = fn.type(torch.uint8)
        if isinstance(fn,Tensor): fn = fn.numpy()
        if isinstance(fn,ndarray): return cls(Image.fromarray(fn))
        if isinstance(fn,bytes): fn = io.BytesIO(fn)
        # `load_image` resolves to the OpenCV version defined above, not fastai's
        return cls(load_image(fn, **merge(cls._open_args, kwargs)))
    # show() and __repr__ are inherited from PILBase, so they need not be redefined

# splitter, val_ids, train_ids, get_msk, b_tfms and TRAIN are defined elsewhere in the notebook
hubmap_seg = DataBlock(blocks=(ImageBlock(PILImage), MaskBlock()),
                       get_items=get_image_files,
                       splitter=splitter(val_ids,train_ids),
                       get_y=get_msk,
                       batch_tfms=b_tfms)
dsets = hubmap_seg.datasets(Path(TRAIN))
dsets.train[0][0]
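OpenCV's imread returns color images in BGR channel order. PIL's Image.fromarray treats a 3-channel array as RGB, so a cv2-based loader that skips the conversion silently swaps red and blue. Here is a minimal numpy-only sketch of the conversion. bgr_to_rgb is a hypothetical helper name, and it performs the same channel swap as cv2.cvtColor(arr, cv2.COLOR_BGR2RGB).

```python
import numpy as np

def bgr_to_rgb(arr: np.ndarray) -> np.ndarray:
    "Swap an HxWx3 array from OpenCV's BGR channel order to RGB."
    # Reversing the last axis maps (B, G, R) -> (R, G, B). The reversed view has a
    # negative stride, which torch.from_numpy rejects, so return a contiguous copy.
    return np.ascontiguousarray(arr[..., ::-1])
```

Applied to the result of cv2.imread before Image.fromarray, this yields the same pixels as the cvtColor call.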