Using `kornia` with V2 in a generic way

I’m trying to set up a generic wrapper around standard kornia transforms, but I’m seeing some unexpected behavior with batch_tfms when setting up some kornia transforms as follows:

class KorniaWrapper(RandTransform):
    '''
    Pass in a kornia function, module, list of modules, or nn.Sequential
    containers to `kornia_tfm`.
    If passing functions, you can pass in function arguments as keyword
    args (**kwargs), which can also be random number generators.

    * KorniaWrapper(kornia.adjust_hue, hue_factor=1.2)
    * KorniaWrapper(kornia.adjust_hue, hue_factor=np.random.random)
    * KorniaWrapper(kornia.adjust_hue, hue_factor=partial(np.random.uniform, low=1.1, high=1.5))
    * KorniaWrapper(kornia.augmentation.ColorJitter(.2,.3,.1,.2))
    * KorniaWrapper(kornia.augmentation.ColorJitter, brightness=.2, contrast=.3)
    * KorniaWrapper(nn.Sequential(*[kornia.augmentation.ColorJitter()]))
    * KorniaWrapper([
          kornia.augmentation.RandomMotionBlur(3, 5., 1.)
      ])
    '''
    order = 10
    def __init__(self, kornia_tfm=None, p=1., **kwargs):
        self.tfm = kornia_tfm
        self.p   = p
        self.input_kwargs = kwargs
        self.call_kwargs  = dict.fromkeys(kwargs)
        self._pipe = Pipeline([ToTensor(), IntToFloatTensor()])
        self.process_tfm()
    def before_call(self, b, split_idx, verbose=False):
        'Compute `p` of applying transform, process input kwargs if applicable'
        self.do = self.p==1. or random.random() < self.p
        for arg,value in self.input_kwargs.items():
            if callable(value): self.call_kwargs[arg] = value()
            else:               self.call_kwargs[arg] = value
    def process_tfm(self):
        'Process the input `kornia_tfm` argument and make it callable'
        if hasattr(self.tfm, 'forward') and hasattr(self.tfm, '__iter__'):
            pass                                    ## -- nn.Sequential
        elif hasattr(self.tfm, 'forward') and type(self.tfm) is not type:
            self.tfm = nn.Sequential(self.tfm)      ## -- Kornia module (instantiated)
        elif hasattr(self.tfm, 'forward') and type(self.tfm) is type:
            self.tfm = nn.Sequential(self.tfm())    ## -- Kornia module (uninstantiated class)
        elif isinstance(self.tfm, list):
            self.tfm = nn.Sequential(*self.tfm)     ## -- list of Kornia modules
    def _encode(self, o:TensorImage): return TensorImage(self.tfm(o, **self.call_kwargs)) if self.do else o
    def encodes(self, o:torch.Tensor): return self._encode(o)
    def encodes(self, o:Image.Image):  return self._encode(self._pipe(PILImage(o)))
    def encodes(self, o:TensorImage):  return self._encode(o)
    def encodes(self, o:PILImage):     return self._encode(self._pipe(o))
    def encodes(self, o:(str,Path)):   return self._encode(self._pipe(PILImage.create(o)))
    def encodes(self, o:(TensorCategory,TensorMultiCategory)): return o
    def __repr__(self): return self.tfm.__repr__()
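The callable-kwargs idea in `before_call` can be illustrated in isolation. A minimal stdlib-only sketch (no kornia/fastai needed; `resolve_kwargs` is a hypothetical helper name, not part of the class above):

```python
import random
from functools import partial

def resolve_kwargs(input_kwargs):
    # Mirror of before_call's kwarg handling: callables are sampled on
    # every call, plain values pass through unchanged.
    return {k: (v() if callable(v) else v) for k, v in input_kwargs.items()}

# A fixed value stays fixed; a generator produces a fresh value per batch
kwargs   = {'hue_factor': partial(random.uniform, 1.1, 1.5), 'gamma': 2.0}
resolved = resolve_kwargs(kwargs)
```

Calling `resolve_kwargs(kwargs)` again would draw a new `hue_factor`, which is what gives each batch different augmentation parameters.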

Kornia transforms setup:

## Setup Kornia Transforms
import kornia as K

Grayscale   = KorniaWrapper(K.augmentation.RandomGrayscale(p=1.0), p=0.3)
StyleTfm    = KorniaWrapper(K.filters.MedianBlur(kernel_size=(5,5)), p=0.2)
ColorJitter = KorniaWrapper(K.augmentation.ColorJitter(.2, .3, .1, .2), p=0.2)
MotionBlur  = KorniaWrapper(
    K.augmentation.RandomMotionBlur(kernel_size = (7,7),
                                    angle       = (5., 15.),
                                    direction   = (-1., 1.)),
    p = 0.25)

Here’s how I’m constructing my DataBlock:

## DataBlock 
dblock = DataBlock(
    blocks     = (ImageBlock, CategoryBlock),
    get_items  = get_image_files,
    get_x      = Pipeline([PILImage.create]),
    get_y      = parent_label,
    splitter   = RandomSplitter(seed=42, valid_pct=0.),
    item_tfms  = [Resize(size=(400,400), method=ResizeMethod.Squish, pad_mode=PadMode.Zeros)],
    #batch_tfms = Pipeline([Grayscale,ColorJitter, StyleTfm,MotionBlur,Normalize.from_stats(*imagenet_stats)])
    batch_tfms = [ColorJitter, MotionBlur, Grayscale, StyleTfm, Normalize.from_stats(*imagenet_stats)]
)

If I understand correctly, batch_tfms behind the scenes is being implemented as a Pipeline. When I construct the pipeline individually as follows:

Pipeline([ColorJitter, MotionBlur, Grayscale, StyleTfm, Normalize.from_stats(*imagenet_stats)])

the __repr__ output is Pipeline: KorniaWrapper -> KorniaWrapper -> KorniaWrapper -> KorniaWrapper -> Normalize, as is expected.

However, when trying out dblock.summary('data'), the textual output for batch_tfms suggests otherwise:

Setting up after_item: Pipeline: Resize -> ToTensor
Setting up before_batch: Pipeline: 
Setting up after_batch: Pipeline: IntToFloatTensor -> KorniaWrapper -> Normalize

When looking at dblock.dataloaders('data').show_batch(), I can see very clearly that only the last KorniaWrapper i.e. StyleTfm, was applied. It also oddly tries to apply the transform to the y-batch, which is why I explicitly added the def encodes(self, o:(TensorCategory,TensorMultiCategory)): return o line to KorniaWrapper.

I also tried passing in batch_tfms as a Pipeline object, but that failed with the following error:

TypeError: '<' not supported between instances of 'L' and 'int'

What am I missing?
Thanks for taking the time to read this. Any help is appreciated.

Related Post: Wrapping kornia module in fastai2 Transform (ZCEWhitening)


I’d like to hear your thoughts about a few things:

  1. How are you going to deal with split_idx? For instance, I’d like to apply RandomResizedCrop to the train split and ResizeCenterCrop to the validation split.
  2. In the case of object detection or any keypoint regression, how will you make sure that the same transformation is applied to the coordinates as well? Of course, we can take advantage of type-dispatch by storing the transform matrix applied to the image and applying the same one to the coordinates in y (bounding boxes/keypoints).

The main advantage of using Kornia over fastai transforms is that Kornia can apply a different transformation to each item within a batch, whereas fastai applies the same one to the whole batch. So a contrived way to deal with all of these incompatibilities would be to write a Callback instead. But again, I’m not sure we would be able to solve the split_idx issue.
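The Callback idea can be sketched without fastai at all. Below is an illustrative stand-in, not fastai's actual Callback API: a plain class with a `training` flag that gates whether the batch is augmented (fastai's Learner exposes a similar training/eval state):

```python
class BatchAugmenter:
    """Stand-in for a fastai Callback: augment only while training."""
    def __init__(self, tfms):
        self.tfms = tfms
        self.training = True          # a real Callback would read this off the Learner

    def before_batch(self, xb):
        if not self.training:         # validation split: leave the batch untouched
            return xb
        for t in self.tfms:
            xb = t(xb)
        return xb

aug = BatchAugmenter([lambda x: x + 1, lambda x: x * 2])
train_out = aug.before_batch(3)       # (3 + 1) * 2 = 8
aug.training = False
valid_out = aug.before_batch(3)       # unchanged
```

Gating on the training state is exactly how a Callback could sidestep the split_idx question for the common "no augmentation at validation" case.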


Great questions. I suppose I should rename the title of this post to “Using kornia with V2 in a generic way for image classifiers”

  1. I’m not sure. I never apply any augmentation to my validation sets, so I hadn’t even considered it until you just mentioned it
  2. I think you answered it better than I would have :slight_smile: (never trained such models yet, so never considered it)

The idea to use them as a callback is interesting. I think Learner has an in_training attribute that would totally solve the split_idx issue you brought up. The one downside is that WandbCallback would then not log all the transforms you’re really using, but that’s probably only a concern for a minority of users.

BTW, I still can’t figure out what’s causing only the last of the transforms to be called in the example posted on the very top of this post.

The workaround I’ve been using is to wrap all the transforms in an nn.Sequential object, but that removes the ability to apply individual probabilities for applying them (where kornia doesn’t inherently have a p argument)

You can use transforms.RandomApply for the transforms that aren’t inherently “random” like so:

color_jitter = K.augmentation.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
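The probability gate behind RandomApply is simple enough to write generically. A minimal stdlib sketch of the same idea (`RandomApplyLike` is a hypothetical name; this is not torchvision's implementation):

```python
import random

class RandomApplyLike:
    """Apply `fn` with probability `p`; otherwise return the input unchanged."""
    def __init__(self, fn, p=0.5):
        self.fn, self.p = fn, p

    def __call__(self, x):
        return self.fn(x) if random.random() < self.p else x

always = RandomApplyLike(lambda x: x * 10, p=1.0)  # random() < 1.0 is always True
never  = RandomApplyLike(lambda x: x * 10, p=0.0)  # random() < 0.0 is never True
```

A wrapper like this is one way to restore a per-transform `p` even after everything is packed into an nn.Sequential.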

Wrapping individual transforms would be needed when you have different behavior for different types, e.g. images resized with bilinear interpolation and masks with nearest-neighbour. If the set of transforms is only meant for a single fastai type, wrapping them in nn.Sequential and then in some fastai wrapper seems a better idea.

I think if we create separate wrappers for train_tfms and test_tfms like:

class TrainTfms(Transform):
    split_idx = 0                     # only applied on the training split
    def __init__(self, tfm): self.tfm = tfm
    def encodes(self, x:TensorImage): return self.tfm(x)
    def __repr__(self): return self.tfm.__repr__()   # list the underlying transforms

class ValidTfms(Transform):
    split_idx = 1                     # only applied on the validation split
    def __init__(self, tfm): self.tfm = tfm
    def encodes(self, x:TensorImage): return self.tfm(x)
    def __repr__(self): return self.tfm.__repr__()   # list the underlying transforms

it would be possible to include them in batch_tfms as well. Override the __repr__ method to list all the underlying transforms, so that the pipeline doesn’t just show KorniaWrapper.
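The split_idx mechanism itself can be sketched without fastai. An illustrative stand-in (not fastai's actual Transform/Pipeline code): each transform declares which split it belongs to, and the pipeline skips mismatches.

```python
class SplitTfm:
    split_idx = None                  # None means: apply on every split
    def __call__(self, x): return self.encodes(x)

class TrainOnly(SplitTfm):
    split_idx = 0                     # fastai convention: 0 = train, 1 = valid
    def encodes(self, x): return x + 100    # stands in for a train-time augmentation

class ValidOnly(SplitTfm):
    split_idx = 1
    def encodes(self, x): return x - 100    # stands in for a deterministic valid tfm

def run_pipeline(tfms, x, split_idx):
    # Skip any transform whose split_idx doesn't match the current split
    for t in tfms:
        if t.split_idx is None or t.split_idx == split_idx:
            x = t(x)
    return x

tfms = [TrainOnly(), ValidOnly()]
train_x = run_pipeline(tfms, 0, split_idx=0)   # only TrainOnly fires
valid_x = run_pipeline(tfms, 0, split_idx=1)   # only ValidOnly fires
```

This is the shape of the TrainTfms/ValidTfms split: the class attribute decides which split sees the transform, so both can live in batch_tfms together.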

I had already stumbled upon this problem and had asked for some way to pass these parameters through constructor.

This might be related to transform order; try assigning a different order to each transform.
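fastai composes batch transforms after sorting them by their `order` attribute; a minimal sketch of that sorting (illustrative only, not fastai's code):

```python
class Tfm:
    def __init__(self, name, order):
        self.name, self.order = name, order

# ColorJitter runs first, Normalize last, regardless of list position
tfms    = [Tfm('Normalize', 99), Tfm('MotionBlur', 11), Tfm('ColorJitter', 10)]
ordered = [t.name for t in sorted(tfms, key=lambda t: t.order)]
```

This is also why the `TypeError: '<' not supported between instances of 'L' and 'int'` appears when a Pipeline is passed where a list is expected: the sorting machinery ends up comparing incompatible objects.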


All of your kornia transforms inside the DataBlock are from the same class (KorniaWrapper). It’s possible that the code used inside the DataBlock is removing all but the last transform. From its docs:

"Group the `tfms` in a single list, removing duplicates (from the same class) and instantiating"
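The effect of that deduplication is easy to reproduce in isolation. A stdlib sketch of "keep one instance per class" (illustrative of the symptom, not fastai's exact merging code; `Wrapper` stands in for KorniaWrapper):

```python
def merge_by_class(tfms):
    # Keep only one instance per class (the last one wins), mimicking the
    # effect of deduplicating transforms of the same class.
    return list({type(t): t for t in tfms}.values())

class Wrapper:                        # stands in for KorniaWrapper
    def __init__(self, name): self.name = name

tfms   = [Wrapper('ColorJitter'), Wrapper('MotionBlur'),
          Wrapper('Grayscale'), Wrapper('StyleTfm')]
merged = merge_by_class(tfms)         # all four collapse to a single Wrapper
```

Because all four wrappers share one class, only the last survives; which matches the observation that only StyleTfm was applied. Distinct subclasses (one class per transform) sidestep this entirely.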

Hey, apologies for the delayed response, I was working under a tight deadline.

I like your idea of separate wrappers for train and valid transforms.


  1. Where does transforms.RandomApply come from? torchvision? Perhaps I missed it, but I couldn’t find it in the kornia library. I tried the version from torchvision; that isn’t compatible with this pipeline (I ran a test which I now forget how to reproduce -_-).

  2. Assuming I’m right and (1) doesn’t work, wrapping them in an nn.Sequential object would mean foregoing the option of setting a distinct p per transform.

  3. I concur that the transforms must be listed. def __repr__(self): return self.tfm.__repr__() defined above already does so

Tested Observations

Great catch, this helped a lot!

I’ve been using this solution and can confirm that it works.

The idea is to use KorniaWrapper (perhaps KorniaBase would be a more apt name in this scenario) as a base class to wrap individual transforms. In my use case, I’m using ColorJitter and RandomMotionBlur:

class MotionBlur(KorniaBase):
    def __init__(self, p=.2, kernel_size=(3,21), angle=(15., 15.), direction=(-1., 1.)):
        kornia_tfm = K.augmentation.RandomMotionBlur(kernel_size = kernel_size,
                                                     angle       = angle,
                                                     direction   = direction)
        super().__init__(kornia_tfm, p=p)

class ColorJitter(KorniaBase):
    def __init__(self, p=.2, jitter_brightness=.1, jitter_contrast=.1, jitter_saturation=(.1, .9), jitter_hue=.2):
        kornia_tfm = K.augmentation.ColorJitter(brightness = jitter_brightness,
                                                contrast   = jitter_contrast,
                                                saturation = jitter_saturation,
                                                hue        = jitter_hue)
        super().__init__(kornia_tfm, p=p)

NOTE: If MotionBlur's order is not after Normalize (order=100), the loss shoots to NaN.

dls = DataBlock(
    batch_tfms = [MotionBlur(p=0.25, kernel_size=(3,21))],
)


Cons:

  • Verbose
  • If you wanted to use them in the validation set, we’d need to add a split_idx argument, and possibly some code redundancy

Pros:

  • Works
  • Clarity

Interesting. I went down this exact same path of first trying to use Albumentations, then moving to Kornia as albu was too slow.

Yes. transforms.RandomApply is from torchvision

Exactly! If we had the privilege of assigning split_idx and order through constructor parameters, a handy function could have done the job for us. Maybe we need to learn more about metaprogramming and fastcore’s metaclasses.

Are you running into the issue where Kornia expects a plain torch.Tensor, but fastai uses its TensorImage type? I’m also getting compatibility errors because fastai keeps image tensors as uint8?