I’m trying to set up a generic wrapper around standard kornia transforms, and I’m seeing some unexpected behavior with `batch_tfms` when I set up kornia transforms as follows:
```python
import random
from functools import partial

import numpy as np
import kornia
from fastai.vision.all import *  # RandTransform, Pipeline, ToTensor, IntToFloatTensor, TensorImage, PILImage, nn, ...

class KorniaWrapper(RandTransform):
    '''
    Pass in a kornia function, module, list of modules, or nn.Sequential
    containers to `kornia_tfm`.
    If passing functions, you can pass in function arguments as keyword
    args (**kwargs), which can also be random number generators.

    Example
    =======
    * KorniaWrapper(kornia.adjust_hue, hue_factor=1.2)
    * KorniaWrapper(kornia.adjust_hue, hue_factor=np.random.random)
    * KorniaWrapper(kornia.adjust_hue, hue_factor=partial(np.random.uniform, low=1.1, high=1.5))
    * KorniaWrapper(kornia.augmentation.ColorJitter(.2,.3,.1,.2))
    * KorniaWrapper(kornia.augmentation.ColorJitter, brightness=.2, contrast=.3)
    * KorniaWrapper(nn.Sequential(*[kornia.augmentation.ColorJitter()]))
    * KorniaWrapper([
        kornia.augmentation.ColorJitter(.2),
        kornia.augmentation.RandomMotionBlur(3, 5., 1.)
    ])
    '''
    order = 10

    def __init__(self, kornia_tfm=None, p=1., **kwargs):
        super().__init__(p=p)
        self.tfm = kornia_tfm
        self.input_kwargs = kwargs                # raw kwargs; values may be callables (RNGs)
        self.call_kwargs = dict.fromkeys(kwargs)  # resolved values, refreshed on every call
        self._pipe = Pipeline([ToTensor(), IntToFloatTensor()])
        self.process_tfm()

    def before_call(self, b, split_idx, verbose=False):
        'Compute `p` of applying transform, process input kwargs if applicable'
        self.do = self.p == 1. or random.random() < self.p
        for arg, value in self.input_kwargs.items():
            # Callables (e.g. np.random.random) are sampled fresh on every call
            self.call_kwargs[arg] = value() if callable(value) else value

    def process_tfm(self):
        'Process the input `kornia_tfm` argument and make it callable'
        if hasattr(self.tfm, 'forward') and hasattr(self.tfm, '__iter__'):
            pass                                  # nn.Sequential: already callable
        elif hasattr(self.tfm, 'forward') and type(self.tfm) is not type:
            self.tfm = nn.Sequential(self.tfm)    # Kornia module (instantiated)
        elif hasattr(self.tfm, 'forward') and type(self.tfm) is type:
            #self.tfm = nn.Sequential(self.tfm)   # Kornia module (uncalled)
            pass
        elif isinstance(self.tfm, list):
            self.tfm = nn.Sequential(*self.tfm)   # list of Kornia modules

    def _encode(self, o:TensorImage): return TensorImage(self.tfm(o, **self.call_kwargs)) if self.do else o
    def encodes(self, o:torch.Tensor): return self._encode(o)
    def encodes(self, o:Image.Image): return self._encode(self._pipe(PILImage(o)))
    def encodes(self, o:TensorImage): return self._encode(o)
    def encodes(self, o:PILImage): return self._encode(self._pipe(o))
    def encodes(self, o:(str,Path)): return self._encode(self._pipe(PILImage.create(o)))
    def encodes(self, o:(TensorCategory,TensorMultiCategory)): return o
    def __repr__(self): return self.tfm.__repr__()
```
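The per-call logic in `before_call` (one roll against `p`, then fresh values for any callable kwargs) can be checked without fastai or kornia. This is a stdlib-only sketch; `should_apply` and `resolve_call_kwargs` are hypothetical helpers that mirror the class above, and `random.Random.uniform` stands in for `np.random.uniform`.

```python
import random
from functools import partial

def should_apply(p, rng=random):
    """Same rule as `self.do = self.p==1. or random.random() < self.p`."""
    return p == 1. or rng.random() < p

def resolve_call_kwargs(input_kwargs):
    """Invoke callables (fresh sample per call); pass constants through unchanged."""
    return {arg: value() if callable(value) else value
            for arg, value in input_kwargs.items()}

rng = random.Random(0)
input_kwargs = {
    "hue_factor": partial(rng.uniform, 1.1, 1.5),  # like partial(np.random.uniform, ...)
    "saturation": 0.5,                             # plain constant
}
first = resolve_call_kwargs(input_kwargs)
second = resolve_call_kwargs(input_kwargs)
print(first)
print(second)

# With p=0.3, roughly 30% of calls should apply the transform
gate_rng = random.Random(1)
hits = sum(should_apply(0.3, gate_rng) for _ in range(10_000))
print(hits / 10_000)
```

Because the kwargs are resolved in `before_call`, every batch gets its own sample, while constants are reused unchanged.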
Kornia transforms setup:
```python
## Setup Kornia Transforms
import kornia as K

Grayscale = KorniaWrapper(K.augmentation.RandomGrayscale(p=1.0), p=0.3)
StyleTfm = KorniaWrapper(K.filters.MedianBlur(kernel_size=(5,5)), p=0.2)
ColorJitter = KorniaWrapper(
    K.augmentation.ColorJitter(brightness=.1,
                               contrast=0,
                               saturation=(.1,.9),
                               hue=.2),
    p=0.6)
MotionBlur = KorniaWrapper(
    K.augmentation.RandomMotionBlur(kernel_size=(7,7),
                                    angle=(5., 15.),
                                    direction=(-1., 1.)),
    p=0.4)
```
Here’s how I’m constructing my `DataBlock`:
```python
## DataBlock
dblock = DataBlock(
    blocks = (ImageBlock, CategoryBlock),
    get_items = get_image_files,
    get_x = Pipeline([PILImage.create]),
    get_y = parent_label,
    splitter = RandomSplitter(seed=42, valid_pct=0.),
    item_tfms = [Resize(size=(400,400), method=ResizeMethod.Squish, pad_mode=PadMode.Zeros)],
    #batch_tfms = Pipeline([Grayscale,ColorJitter, StyleTfm,MotionBlur,Normalize.from_stats(*imagenet_stats)])
    batch_tfms = [ColorJitter, MotionBlur, Grayscale, StyleTfm, Normalize.from_stats(*imagenet_stats)]
)
```
If I understand correctly, `batch_tfms` is implemented as a `Pipeline` behind the scenes. When I construct the pipeline myself:

```python
Pipeline([ColorJitter, MotionBlur, Grayscale, StyleTfm, Normalize.from_stats(*imagenet_stats)])
```

its `__repr__` output is `Pipeline: KorniaWrapper -> KorniaWrapper -> KorniaWrapper -> KorniaWrapper -> Normalize`, as expected.
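As I understand it, a fastai `Pipeline` sorts its transforms by their `order` attribute (a stable sort) and then applies them in sequence, so the four wrappers, which all have `order = 10`, keep the order I gave them and `Normalize` runs last. A stdlib-only sketch of that behavior; `MiniPipeline` and the stand-in classes are hypothetical, and `Normalize` having `order = 99` is an assumption.

```python
class MiniPipeline:
    """Hypothetical stand-in: stable-sort by `order`, apply in sequence,
    and print like 'Pipeline: A -> B'."""
    def __init__(self, fs):
        self.fs = sorted(fs, key=lambda f: f.order)  # sorted() is stable

    def __call__(self, x):
        for f in self.fs:
            x = f(x)
        return x

    def __repr__(self):
        return "Pipeline: " + " -> ".join(type(f).__name__ for f in self.fs)

class KorniaWrapper:  # stand-in, not the real class
    order = 10
    def __init__(self, step): self.step = step
    def __call__(self, x): return x + [self.step]

class Normalize:      # stand-in; order value assumed
    order = 99
    def __call__(self, x): return x + ["normalize"]

pipe = MiniPipeline([Normalize(), KorniaWrapper("jitter"), KorniaWrapper("blur")])
print(pipe)        # Pipeline: KorniaWrapper -> KorniaWrapper -> Normalize
print(pipe([]))    # ['jitter', 'blur', 'normalize']
```

Building the pipeline directly involves no de-duplication step, which is consistent with all four wrappers showing up in the `__repr__`.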
However, when I try `dblock.summary('data')`, the output for `batch_tfms` suggests otherwise:

```
Setting up after_item: Pipeline: Resize -> ToTensor
Setting up before_batch: Pipeline:
Setting up after_batch: Pipeline: IntToFloatTensor -> KorniaWrapper -> Normalize
```
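One hypothesis I haven’t confirmed: if `DataBlock` merges the default batch transforms (`IntToFloatTensor`) with mine by grouping on class and keeping only the last instance of each class, then all four `KorniaWrapper` objects would collapse into the last one, `StyleTfm`. The stdlib sketch below shows that grouping rule in isolation; `merge_by_class` and the `*Stub` classes are hypothetical, and `order = 99` for `Normalize` is assumed.

```python
class Tfm:
    """Minimal stand-in for a fastai Transform: just a label and an `order`."""
    order = 0
    def __init__(self, label): self.label = label
    def __repr__(self): return self.label

class IntToFloatStub(Tfm): order = 10  # stands in for IntToFloatTensor
class WrapperStub(Tfm): order = 10     # stands in for KorniaWrapper
class NormalizeStub(Tfm): order = 99   # stands in for Normalize (order assumed)

def merge_by_class(*tfm_lists):
    """Hypothetical merge: group by type, keep first-seen class order,
    keep only the LAST instance of each class."""
    groups = {}  # dicts preserve the insertion order of the class keys
    for tfms in tfm_lists:
        for t in tfms:
            groups.setdefault(type(t), []).append(t)
    return [instances[-1] for instances in groups.values()]

defaults = [IntToFloatStub("IntToFloatTensor")]
user = [WrapperStub("ColorJitter"), WrapperStub("MotionBlur"),
        WrapperStub("Grayscale"), WrapperStub("StyleTfm"),
        NormalizeStub("Normalize")]
merged = sorted(merge_by_class(defaults, user), key=lambda t: t.order)
print(merged)  # [IntToFloatTensor, StyleTfm, Normalize]
```

Under this assumption, instances that share a class cannot coexist in the merged list, which would match the `after_batch` line above.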
When I look at `dblock.dataloaders('data').show_batch()`, I can see very clearly that only the last `KorniaWrapper`, i.e. `StyleTfm`, was applied. Oddly, it also tries to apply the transform to the y-batch, which is why I explicitly added the `def encodes(self, o:(TensorCategory,TensorMultiCategory)): return o` line to `KorniaWrapper`.
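Each `encodes` overload is chosen by the type of its input. As far as I can tell, fastai’s label tensors (`TensorCategory`, `TensorMultiCategory`) are subclasses of `torch.Tensor`, so the plain `torch.Tensor` overload matches them too unless a more specific overload exists, which would explain the y-batch behavior. The sketch below uses the stdlib `functools.singledispatch` as a stand-in for fastai’s own type dispatch; the tensor classes are hypothetical stand-ins.

```python
from functools import singledispatch

# Hypothetical stand-ins for fastai's tensor subclasses (not the real classes)
class Tensor:
    def __init__(self, data): self.data = list(data)
class TensorImage(Tensor): pass
class TensorCategory(Tensor): pass

def augment(o):
    """Stand-in 'augmentation': doubles every value, keeps the subclass."""
    return type(o)(x * 2 for x in o.data)

# Only a generic Tensor rule: labels are Tensor subclasses, so they get augmented too
@singledispatch
def encodes_generic(o): return o
encodes_generic.register(Tensor, augment)

# Same rule plus an explicit, more specific pass-through for labels
@singledispatch
def encodes(o): return o
encodes.register(Tensor, augment)

@encodes.register
def _(o: TensorCategory):
    return o  # the most specific class in the MRO wins, so labels are untouched

img, lbl = TensorImage([1, 2]), TensorCategory([3])
print(encodes_generic(lbl).data)             # [6]
print(encodes(img).data, encodes(lbl).data)  # [2, 4] [3]
```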
I also tried passing in `batch_tfms` as a `Pipeline` object, but that failed with the following error:

```
TypeError: '<' not supported between instances of 'L' and 'int'
```
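One way this exact message can arise, assuming the `Pipeline` object ends up being sorted by `order` next to ordinary transforms and that its `order` attribute resolves to a collection of its members’ orders rather than a single number (I believe fastcore’s `Pipeline` forwards unknown attributes to its transforms, but treat that as an assumption): Python cannot compare a list-like value with an int. `FakePipeline` is a hypothetical stand-in, and a plain `list` plays the role of fastcore’s `L`.

```python
class Tfm:
    def __init__(self, order): self.order = order

class FakePipeline:
    """Hypothetical: unknown attribute lookups are gathered from the wrapped transforms."""
    def __init__(self, fs): self.fs = fs
    def __getattr__(self, name):
        # Only called for attributes not found normally, e.g. `order`
        return [getattr(f, name) for f in self.fs]

tfms = [Tfm(10), FakePipeline([Tfm(10), Tfm(99)])]
error = None
try:
    sorted(tfms, key=lambda t: t.order)  # keys are 10 and [10, 99]
except TypeError as e:
    error = str(e)
print(error)  # '<' not supported between instances of 'list' and 'int'
```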
What am I missing?
Thanks for taking the time to read this. Any help is appreciated.
Related post: Wrapping kornia module in fastai2 Transform (ZCEWhitening)