Wrapping kornia module in fastai2 `Transform`

I’m working with ZCAWhitening, which was recently added to kornia. You can find the implementation here

This transform has three main methods:

  1. fit: ZCAWhitening must first be fitted on the data, ideally on the whole dataset, though a single batch also works.
  2. forward: applies the whitening transform to the data.
  3. inverse_transform: applies the inverse transform to whitened data.
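To make the three methods concrete, here is a minimal NumPy sketch of textbook ZCA whitening, written without kornia so it runs anywhere. The helper names `zca_fit`, `zca_transform` and `zca_inverse` are hypothetical, and kornia's own implementation may differ in details such as where `eps` is added and how the covariance is normalized.

```python
import numpy as np

def zca_fit(x, eps=1e-6):
    """Estimate the mean, the ZCA whitening matrix and its inverse
    from a batch x of shape (N, C, H, W) or (N, D)."""
    flat = x.reshape(x.shape[0], -1)              # one row per sample
    mean = flat.mean(axis=0, keepdims=True)
    centered = flat - mean
    cov = centered.T @ centered / (flat.shape[0] - 1)
    evals, evecs = np.linalg.eigh(cov)            # cov is symmetric
    scale = np.sqrt(np.clip(evals, 0.0, None) + eps)
    W = evecs @ np.diag(1.0 / scale) @ evecs.T    # whitening matrix (symmetric)
    W_inv = evecs @ np.diag(scale) @ evecs.T      # its inverse
    return mean, W, W_inv

def zca_transform(x, mean, W):
    """Whiten x (the 'forward' step); W is symmetric, so row @ W == W @ column."""
    flat = x.reshape(x.shape[0], -1)
    return ((flat - mean) @ W).reshape(x.shape)

def zca_inverse(z, mean, W_inv):
    """Map whitened data back to the original space (the 'inverse_transform' step)."""
    flat = z.reshape(z.shape[0], -1)
    return (flat @ W_inv + mean).reshape(z.shape)

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 3, 4, 4))       # a fake batch of 256 tiny RGB images
mean, W, W_inv = zca_fit(x)               # like fit()
z = zca_transform(x, mean, W)             # like forward()
x_back = zca_inverse(z, mean, W_inv)      # like inverse_transform()
```

Fitting on one batch, as in the wrapper, simply means `x` is a single batch rather than the whole dataset; the estimated mean and covariance are then noisier.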

Here’s my implementation:

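```python
from fastai2.vision.all import *
from kornia.color import ZCAWhitening

class ZCAWhitenWrapper(Transform,GetAttr):
  "Wrapping kornia implementation"
  def __init__(self,**kwargs):
    super().__init__()
    self._zca = ZCAWhitening(**kwargs)

  def setups(self,dl:DataLoader):
    # Fit the whitening statistics once, on a single batch, unless already fitted
    if not self._zca.fitted:
      x,*_ = dl.one_batch()
      self._zca = self._zca.fit(x)

  def encodes(self,x:TensorImage): return self._zca(x)

  def decodes(self,x:TensorImage):
    # Undo the whitening (only possible when compute_inv=True), then rescale for display
    if self._zca.compute_inv:
      x = self._zca.inverse_transform(x)
    return min_max_scale(x)  # min_max_scale: the author's own helper, not shown here
```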

What could I have done better? In particular, I’d like to get rid of self._zca.
cc: @sgugger


The CIFAR-10 preprocessing involves GlobalContrastNormalization followed by ZCAWhitening. Here are the results of these steps:

Before preprocessing: [image not available]

After preprocessing: [image not available]

I’ve seen a 2% improvement in error_rate with GlobalContrastNormalization. I’ll post an update on ZCAWhitening soon.
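The thread doesn’t show the author’s GlobalContrastNorm transform. For reference, here is a minimal NumPy sketch of the common textbook definition of global contrast normalization, applied per image: subtract the image’s mean over all pixels and channels, then divide by its overall contrast. The function name and the parameters `s` (output scale), `lmda` (regularizer) and `eps` (floor), along with their defaults, are assumptions for illustration, not the author’s settings.

```python
import numpy as np

def global_contrast_normalize(x, s=1.0, lmda=0.0, eps=1e-8):
    """Per-image global contrast normalization of a batch x of shape (N, ...).
    Hypothetical helper: s, lmda and eps defaults are illustrative assumptions."""
    axes = tuple(range(1, x.ndim))                      # every axis except the batch axis
    centered = x - x.mean(axis=axes, keepdims=True)     # zero mean per image
    contrast = np.sqrt(lmda + (centered ** 2).mean(axis=axes, keepdims=True))
    return s * centered / np.maximum(contrast, eps)     # eps avoids dividing by ~0

rng = np.random.default_rng(0)
batch = rng.uniform(0, 255, size=(8, 3, 32, 32))        # fake CIFAR-sized batch
out = global_contrast_normalize(batch)
```

With `lmda=0` every image ends up with zero mean and unit standard deviation (times `s`); a positive `lmda` damps the rescaling of nearly flat, low-contrast images.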


This looks good to me. I don’t see how you could remove the _zca attribute.

I tried inheriting from both kornia.color.ZCAWhitening and Transform, but this broke the forward method of the kornia Module (unknown parameter split_idx). Is it possible to extend any PyTorch Module alongside fastcore’s Transform?

I don’t see how.
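One plausible reason for the split_idx error, assuming fastai’s Pipeline invokes each transform as `tfm(x, split_idx=...)` and the kornia Module comes first among the base classes: Python’s method resolution order then picks `nn.Module.__call__`, which forwards every keyword argument straight to `forward()`. The plain-Python stand-ins below (`FakeModule`, `FakeTransform` and `Whitener` are hypothetical, not torch or fastai classes) reproduce the effect without either library.

```python
class FakeModule:
    """Stand-in for torch.nn.Module: __call__ passes all arguments on to forward()."""
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

class FakeTransform:
    """Stand-in for a fastai Transform: __call__ accepts split_idx and drops it."""
    def __call__(self, x, split_idx=None, **kwargs):
        return self.encodes(x)

class Whitener(FakeModule, FakeTransform):
    def forward(self, x):
        return x * 2
    def encodes(self, x):
        return self.forward(x)

# FakeModule is listed first, so its __call__ is the one found in the MRO
print([cls.__name__ for cls in Whitener.__mro__])
try:
    Whitener()(3, split_idx=0)      # how a pipeline would call the transform
except TypeError as err:
    print("TypeError:", err)        # forward() rejects the unexpected split_idx keyword
```

Reordering the bases changes which `__call__` wins, but the real classes also differ in their `__init__` requirements, so reordering alone is not guaranteed to make the combination work.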

It seems it’s not possible to inherit from both Transform and GetAttr and then use that transform with DataBlock: something in the pipeline trips over getattr().

```
TypeError                                 Traceback (most recent call last)

<ipython-input-3-3b9d35837c1e> in <module>()
      6                                 GlobalContrastNorm,
      7                                 ZCAWhitenWrapper()])
----> 8 dblock.summary(path,device='cuda')

/usr/local/lib/python3.6/dist-packages/fastai2/data/block.py in summary(self, source, bs, show_batch, **kwargs)
    158     print(f"\nFinal sample: {dsets.train[0]}\n\n")
--> 160     dls = self.dataloaders(source, bs=bs, verbose=True)
    161     print("\nBuilding one batch")
    162     if len([f for f in dls.train.after_item.fs if f.name != 'noop'])!=0:

/usr/local/lib/python3.6/dist-packages/fastai2/data/block.py in dataloaders(self, source, path, verbose, **kwargs)
    107         dsets = self.datasets(source)
    108         kwargs = {**self.dls_kwargs, **kwargs, 'verbose': verbose}
--> 109         return dsets.dataloaders(path=path, after_item=self.item_tfms, after_batch=self.batch_tfms, **kwargs)
    111     _docs = dict(new="Create a new `DataBlock` with other `item_tfms` and `batch_tfms`",

/usr/local/lib/python3.6/dist-packages/fastai2/data/core.py in dataloaders(self, bs, val_bs, shuffle_train, n, path, dl_type, dl_kwargs, device, **kwargs)
    202         dls = [dl] + [dl.new(self.subset(i), bs=(bs if val_bs is None else val_bs), shuffle=False, drop_last=False,
    203                              n=None, **dl_kwargs[i]) for i in range(1, self.n_subsets)]
--> 204         return self._dbunch_type(*dls, path=path, device=device)
    206 FilteredBase.train,FilteredBase.valid = add_props(lambda i,x: x.subset(i))

/usr/local/lib/python3.6/dist-packages/fastai2/data/core.py in __init__(self, path, device, *loaders)
    127     def __init__(self, *loaders, path='.', device=None):
    128         self.loaders,self.path = list(loaders),Path(path)
--> 129         self.device = device
    131     def __getitem__(self, i): return self.loaders[i]

/usr/local/lib/python3.6/dist-packages/fastai2/data/core.py in device(self, d)
    143     @device.setter
    144     def device(self, d):
--> 145         for dl in self.loaders: dl.to(d)
    146         self._device = d

/usr/local/lib/python3.6/dist-packages/fastai2/data/core.py in to(self, device)
    117         self.device = device
    118         for tfm in self.after_batch.fs:
--> 119             for a in L(getattr(tfm, 'parameters', None)): setattr(tfm, a, getattr(tfm, a).to(device))
    120         return self

TypeError: getattr(): attribute name must be string
```

Removing GetAttr solved the issue. Is there any workaround?

I’m not 100% sure you need GetAttr here. Looking at the docstring for GetAttr: "Inherit from this to have all attr accesses in self._xtra passed down to self.default"
And I don’t see you using self.default (unless that’s _default). Others can certainly chime in to correct me, though :slight_smile:

So should I use _xtra to specify exactly which attributes are passed down to _default?
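A likely explanation for the TypeError, assuming the wrapper delegates to `self._zca` (i.e. `_default='_zca'`): the last traceback frame shows fastai2 treating a transform’s `parameters` attribute as a list of attribute *names*. With unrestricted delegation, `tfm.parameters` resolves to the kornia module’s `nn.Module.parameters` method, which `L()` wraps in a one-element list and which is then passed to `getattr` as a name. The plain-Python sketch below reproduces this; `FakeModule`, `Delegating`, `Restricted` and `move_to_device` are hypothetical stand-ins that imitate GetAttr and the fastai loop, not the real classes. Restricting delegation to an explicit whitelist, which is what `_xtra` is documented to do, avoids the clash.

```python
class FakeModule:
    """Stand-in for a kornia nn.Module: `parameters` is a method, not a list of names."""
    fitted = True
    def parameters(self):
        return iter([])

class Delegating:
    """Imitates fastcore's GetAttr: unknown attributes are looked up on the object
    named by `_default`, optionally restricted to the names listed in `_xtra`."""
    _default = '_zca'
    _xtra = None
    def __init__(self):
        self._zca = FakeModule()
    def __getattr__(self, k):
        if k.startswith('__') or k in ('_xtra', self._default):
            raise AttributeError(k)
        if self._xtra is not None and k not in self._xtra:
            raise AttributeError(k)
        return getattr(self._zca, k)

def move_to_device(tfm):
    """Imitates the loop in the last traceback frame: `parameters` should hold names."""
    names = getattr(tfm, 'parameters', None)
    if names is None:
        names = []
    elif not isinstance(names, (list, tuple)):
        names = [names]              # like fastcore's L(), wrap a single object in a list
    for a in names:
        getattr(tfm, a)              # TypeError when `a` is a bound method, not a string

class Restricted(Delegating):
    _xtra = ['fitted']               # only forward the attributes you actually need

try:
    move_to_device(Delegating())     # delegation leaks nn.Module.parameters
except TypeError as err:
    print("TypeError:", err)
move_to_device(Restricted())         # `parameters` is no longer forwarded, so no error
```

In the real wrapper this would correspond to something like `_default = '_zca'` together with `_xtra = ['fitted', 'compute_inv']`; that combination is untested here.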

How can I reverse a transform applied with kornia.augmentation?

I found that if you set return_transform=True, the forward method returns the transformation matrix along with the output tensor. But I’m not sure how to use it. I tried this:

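```python
import torch
from kornia.augmentation import RandomHorizontalFlip

# xb: a batch of 4 RGB images of size 320x320, shape (4, 3, 320, 320)
tfm_xb,tfm_mat = RandomHorizontalFlip(p=1.,return_transform=True)(xb)
flat_xb = xb.view(4,3,-1)
xb_rev = torch.einsum('abc,abd->abd',tfm_mat,flat_xb).reshape(4,3,320,320)
```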

But something strange happens to the image channels in the output.

@ducha-aiki could you help me with this ?


Sorry, I am not familiar with einsum. The returned transformation is the perspective transformation matrix that was applied. You just need to invert the matrix, that’s it:

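```python
import kornia

# Invert the 3x3 matrix and warp the augmented batch back;
# warp_affine takes the top two rows of the matrix and dsize=(H, W)
xb_rev = kornia.warp_affine(tfm_xb, tfm_mat.inverse()[:,:2,:], (tfm_xb.size(2), tfm_xb.size(3)))
```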

Thanks! It worked like a charm :+1:
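Why the einsum attempt misbehaves: the returned matrix acts on homogeneous pixel coordinates (x, y, 1), not on the three colour values of a pixel. The pattern `'abc,abd->abd'` pairs the matrix’s three rows with the three channels and sums out `c`, so each channel is merely multiplied by one row sum of the matrix and no pixels move. The NumPy sketch below uses a hypothetical nearest-neighbour `warp_nearest` helper in place of `warp_affine` to show that warping with the inverse matrix restores the image. It assumes the flip maps column x to W-1-x; kornia’s own matrix may differ in that detail.

```python
import numpy as np

def warp_nearest(img, M):
    """Warp a (C, H, W) image by a 3x3 matrix M that maps source pixel coordinates
    (x, y, 1) to destination coordinates. Each output pixel is pulled from the
    source location given by inv(M), using nearest-neighbour sampling."""
    C, H, W = img.shape
    Minv = np.linalg.inv(M)
    ys, xs = np.mgrid[0:H, 0:W]
    dst = np.stack([xs.ravel(), ys.ravel(), np.ones(H * W)])  # (3, H*W) homogeneous coords
    src = Minv @ dst
    sx = np.rint(src[0] / src[2]).astype(int)
    sy = np.rint(src[1] / src[2]).astype(int)
    valid = (sx >= 0) & (sx < W) & (sy >= 0) & (sy < H)
    out = np.zeros((C, H * W), dtype=img.dtype)
    out[:, valid] = img[:, sy[valid], sx[valid]]   # same pixel move for every channel
    return out.reshape(C, H, W)

H, W = 2, 4
img = np.arange(3 * H * W, dtype=float).reshape(3, H, W)
# Horizontal flip expressed on coordinates: x -> (W - 1) - x, y -> y
flip = np.array([[-1.0, 0.0, W - 1], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
flipped = warp_nearest(img, flip)
restored = warp_nearest(flipped, np.linalg.inv(flip))  # undo with the inverse matrix
```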

This looks great @kshitijpatil09.

Would you mind sharing a more elaborate code example that shows how to use this transform with the datablocks API?
Is it just added to batch_tfms?

Thanks :slight_smile:

Yes, you can use this transform by simply adding it to batch_tfms. However, kornia hasn’t included ZCAWhitening in a stable release yet, so my repo would have needed kornia installed from source, which is a breaking dependency. For that reason I haven’t exported ZCAWhitenWrapper yet. You can learn more about these transforms in the following notebooks:




Thanks Kshitij, that’s helpful. I’m having some issues with wrapping multiple transforms in batch_tfms. I’ve created a separate post here. I’d love to hear your thoughts.