[SOLVED] Patching Failed to Alter ImageBlock's Behavior in DataBlock?

Background
I am trying to build a DataBlock for a Kaggle dataset (Human Protein Atlas 2019). It's a multi-label classification problem with metadata stored in a CSV like this:

Id 	Target
00070df0-bbc3-11e8-b2bc-ac1f6b6435d0 	16 0
000a6c98-bb9b-11e8-b2b9-ac1f6b6435d0 	7 1 2 0
000a9596-bbc4-11e8-b2bc-ac1f6b6435d0 	5

Id is the ID of a protein sample, and Target holds its labels (one or more per sample). Each protein sample is associated with 4 image paths (e.g. ./00070df0-bbc3-11e8-b2bc-ac1f6b6435d0_red.png, ./00070df0-bbc3-11e8-b2bc-ac1f6b6435d0_green.png, ./00070df0-bbc3-11e8-b2bc-ac1f6b6435d0_blue.png, ./00070df0-bbc3-11e8-b2bc-ac1f6b6435d0_yellow.png).
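
As a small illustration of the label format, the sketch below splits a Target string on spaces and turns it into a multi-hot vector. It assumes 28 classes labelled 0 to 27 (28 is the length of the TensorMultiCategory in the output shown in the Problem section). `parse_target` and `multi_hot` are made-up helper names, not fastai functions, and fastai builds its own vocab internally, so the index order there may differ.

```python
from typing import List

N_CLASSES = 28  # assumption: labels are the integers 0..27

def parse_target(target: str) -> List[int]:
    """Split a space-delimited Target string into integer labels."""
    return [int(tok) for tok in target.split()]

def multi_hot(labels: List[int], n_classes: int = N_CLASSES) -> List[float]:
    """Return a vector with 1.0 at each label index and 0.0 elsewhere."""
    vec = [0.0] * n_classes
    for lbl in labels:
        vec[lbl] = 1.0
    return vec

labels = parse_target("7 1 2 0")
encoded = multi_hot(labels)
print(labels)        # [7, 1, 2, 0]
print(sum(encoded))  # 4.0
```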

Objectives
I want to build a simple DataBlock to read the data. As a start, I aim to read 3 of the image paths and concatenate them into a 3-channel image (one per protein sample). So first I defined a function that maps an Id to a tuple of 3 image paths.

import os
from pathlib import Path
from functools import partial
from typing import Tuple, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from fastai.vision.all import *

TRAIN_DIR = Path('/kaggle/input/human-protein-atlas-image-classification')

def get_channel_paths(row: pd.Series, is_full: bool = False) -> Tuple[Path, ...]:
    "Map a metadata row to the paths of its channel images (3, or 4 if `is_full`)"
    img_id = row.Id
    colors = ['red', 'green', 'blue']
    if is_full:
        colors += ['yellow']
    channel_fns = tuple(TRAIN_DIR/'train'/f'{img_id}_{color}.png' for color in colors)
    # sanity check: the first channel file must exist on disk
    assert os.path.isfile(channel_fns[0])
    return channel_fns

Then I tried to alter the behavior of PILBase.create so that I can reuse ImageBlock in my DataBlock. I changed the method so that it also accepts a tuple of Paths as input, as follows:

@patch(cls_method = True)
def create(cls: PILBase, fn: (Path,str,Tensor,ndarray,bytes,Tuple[str, str, str]), **kwargs)->None:
    "Open an `Image` from path `fn`"
    if isinstance(fn,TensorImage): fn = fn.permute(1,2,0).type(torch.uint8)
    if isinstance(fn, TensorMask): fn = fn.type(torch.uint8)
    if isinstance(fn,Tensor): fn = fn.numpy()
    # handle tuple of image paths for HPA 2019
    if isinstance(fn,tuple):
        channel_imgs = [Image.open(img_path) for img_path in fn]
        fn = np.stack(channel_imgs, axis = -1)
    if isinstance(fn,ndarray): return cls(Image.fromarray(fn))
    if isinstance(fn,bytes): fn = io.BytesIO(fn)
    return cls(load_image(fn, **merge(cls._open_args, kwargs)))
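
The tuple branch relies on np.stack to merge the single-channel images. Here is a minimal sketch of that step, with synthetic arrays standing in for the opened PNGs (the 4x4 size and the constant pixel values are made up for illustration):

```python
import numpy as np

# Stand-ins for three single-channel (grayscale) images of the same size.
red   = np.full((4, 4), 10, dtype=np.uint8)
green = np.full((4, 4), 20, dtype=np.uint8)
blue  = np.full((4, 4), 30, dtype=np.uint8)

# Stacking on the last axis yields a height x width x channel array,
# the layout that Image.fromarray accepts for an RGB image.
rgb = np.stack([red, green, blue], axis=-1)

print(rgb.shape)   # (4, 4, 3)
print(rgb[0, 0])   # [10 20 30]
```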

Finally, I created my DataBlock as follows (batch_tfms and other args omitted for simplicity and debugging):

data_blk = DataBlock(blocks = (ImageBlock, MultiCategoryBlock),
                     get_x = partial(get_channel_paths, is_full = False), 
                     get_y = ColReader(1, label_delim = ' '))
ds = data_blk.datasets(df)
ds.train[0]

Problem
Unexpectedly, I got the following output from the above command. I expected a tuple of an image and a TensorMultiCategory, but the first entry is a tuple of Paths instead:

((Path('/kaggle/input/human-protein-atlas-image-classification/train/04e6f8f8-bb9b-11e8-b2b9-ac1f6b6435d0_red.png'),
  Path('/kaggle/input/human-protein-atlas-image-classification/train/04e6f8f8-bb9b-11e8-b2b9-ac1f6b6435d0_green.png'),
  Path('/kaggle/input/human-protein-atlas-image-classification/train/04e6f8f8-bb9b-11e8-b2b9-ac1f6b6435d0_blue.png')),
 TensorMultiCategory([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))

I suspect the cause is that my patched PILBase.create is not actually being used in this DataBlock call. Any advice on how to resolve this?

A few findings after some investigation:

  1. Somewhere after DataBlock.datasets is called, each function set in get_x (and get_y) is wrapped in a Transform.
  2. A function wrapped in a Transform gets special handling for tuple inputs: the function is applied to each element of the tuple. This may be why my get_x function leads to unexpected behavior (because it outputs a tuple).
  3. Most importantly, something strange happens when I first patch PILBase.create as above and then wrap it in a Transform. Once wrapped, it fails to work on inputs of any type and unexpectedly acts as an identity function. So the issue seemingly arises when the function is wrapped in a Transform and type dispatch fails to work.

To illustrate point 3:

@patch(cls_method = True)
def create(cls: PILBase, fn: (Path,str,Tensor,ndarray,bytes,Tuple[Path, Path, Path]), **kwargs)->None:
    "Open an `Image` from path `fn`"
    if isinstance(fn,TensorImage): fn = fn.permute(1,2,0).type(torch.uint8)
    if isinstance(fn, TensorMask): fn = fn.type(torch.uint8)
    if isinstance(fn,Tensor): fn = fn.numpy()
    # handle tuple of image paths for HPA 2019
    if isinstance(fn, tuple):
        channel_imgs = [Image.open(img_path) for img_path in fn]
        fn = np.stack(channel_imgs, axis = -1)
    if isinstance(fn,ndarray): return cls(Image.fromarray(fn))
    if isinstance(fn,bytes): fn = io.BytesIO(fn)
    return cls(load_image(fn, **merge(cls._open_args, kwargs)))

create_tfms = Transform(PILBase.create)
create_tfms(Path('/kaggle/input/human-protein-atlas-image-classification/train/000a9596-bbc4-11e8-b2bc-ac1f6b6435d0_red.png'))
# OUTPUT: Path('/kaggle/input/human-protein-atlas-image-classification/train/000a9596-bbc4-11e8-b2bc-ac1f6b6435d0_red.png')
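
The behavior described in points 2 and 3 can be mimicked with a toy wrapper. This is not fastcore's implementation, just a sketch of the two rules as I understand them: a tuple input is processed element by element, and an input whose type does not match what the function is registered for is returned unchanged. `ToyTransform` and `open_path` are hypothetical names.

```python
from pathlib import Path
from typing import Any, Callable

class ToyTransform:
    """Toy stand-in for a type-dispatched transform (not fastcore's code)."""
    def __init__(self, func: Callable, accepts: tuple):
        self.func = func
        self.accepts = accepts  # types the wrapped function claims to handle

    def __call__(self, x: Any) -> Any:
        # Rule 1: a tuple is handled element by element.
        if isinstance(x, tuple):
            return tuple(self(item) for item in x)
        # Rule 2: no matching type means the input passes through untouched.
        if not isinstance(x, self.accepts):
            return x
        return self.func(x)

def open_path(p: Path) -> str:
    """Pretend to open an image; returns a marker string instead."""
    return f"opened:{p.name}"

paths = (Path("a_red.png"), Path("a_green.png"))

matching = ToyTransform(open_path, accepts=(Path,))
print(matching(paths))              # ('opened:a_red.png', 'opened:a_green.png')

mismatched = ToyTransform(open_path, accepts=(int,))
print(mismatched(paths) == paths)   # True: acts as an identity function
```

Under these two rules, a get_x that returns a tuple never reaches create as a whole tuple, and a create whose dispatch no longer matches Path leaves every path untouched, which would be consistent with the output I got.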

I am starting to understand the underlying mechanics of the fastai2 framework, but I still need time to figure out how to resolve the problem.

Eventually, I found a workaround: manually define a substitute for PILBase and point ImageBlock to it. Not a clean way, but it works:

class AlterPILBase(Image.Image, metaclass=BypassNewMeta):
    _bypass_type=Image.Image
    _show_args = {'cmap':'viridis'}
    _open_args = {'mode': 'RGB'}
    
    @classmethod
    def create(cls, fn:(Path,str,Tensor,ndarray,bytes,list), **kwargs)->None:
        "Open an `Image` from path `fn`"
        if isinstance(fn,TensorImage): fn = fn.permute(1,2,0).type(torch.uint8)
        if isinstance(fn, TensorMask): fn = fn.type(torch.uint8)
        if isinstance(fn,Tensor): fn = fn.numpy()
        
        # handle a list of image paths for HPA 2019
        if isinstance(fn, list):
            channel_imgs = [Image.open(img_path) for img_path in fn]
            fn = np.stack(channel_imgs, axis = -1)
        
        # return PILImage object for consistency
        if isinstance(fn,ndarray): 
            return PILImage(Image.fromarray(fn))
        if isinstance(fn,bytes): fn = io.BytesIO(fn)
        return PILImage(load_image(fn, **merge(cls._open_args, kwargs)))

    def show(self, ctx=None, **kwargs):
        "Show image using `merge(self._show_args, kwargs)`"
        return show_image(self, ctx=ctx, **merge(self._show_args, kwargs))

    def __repr__(self): return f'{self.__class__.__name__} mode={self.mode} size={"x".join([str(d) for d in self.size])}'


data_blk = DataBlock(blocks = (partial(ImageBlock, cls = AlterPILBase), MultiCategoryBlock),
                    get_x = partial(get_channel_paths, is_full = False),
                    get_y = ColReader(1, label_delim=' '))
ds = data_blk.datasets(df)
ds.train[0]

Meanwhile, if anyone has a cleaner solution to this, I would love to know!