Fastai v2 transforms / pipeline / data blocks

Understanding TypeDispatch

So I’ve spent a bit of time trying to understand TypeDispatch, and it’s really powerful! Basically, it’s a dictionary between types and functions.

You can refer to the type hierarchy here

Let’s dig deeper and you’ll see how powerful it is!

def __init__(self, *funcs):
    self.funcs,self.cache = {},{}
    for f in funcs: self.add(f)
    self.inst = None

__init__ takes any number of functions and adds each one to the self.funcs dictionary as a type:func mapping. Inside TypeDispatch, the type is determined by the annotation of the first parameter of each function f.
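To see what "the annotation of the first parameter" means in practice, here is a minimal sketch using the standard library's inspect module. first_param_anno is a hypothetical stand-in for fastai's internal _p1_anno, whose code isn't shown in this post. As an assumption, it skips unannotated parameters such as self, so it also works on methods.

```python
import inspect
import numbers

def first_param_anno(f):
    "Hypothetical stand-in for fastai's `_p1_anno`: annotation of the first annotated parameter, else None"
    for p in inspect.signature(f).parameters.values():
        if p.annotation is not inspect.Parameter.empty:
            return p.annotation
    return None

def some_func(a: numbers.Integral, b: bool): pass   # first parameter is annotated
def plain(a, b): pass                                # no annotations at all

class Tfm:
    def encodes(self, x: int): pass                  # `self` is unannotated, so `x` is used

print(first_param_anno(some_func))     # <class 'numbers.Integral'>
print(first_param_anno(plain))         # None
print(first_param_anno(Tfm.encodes))   # <class 'int'>
```

When this returns None, add falls back to object, so an unannotated function ends up as a catch-all.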

Too confusing? Let’s put it together.

#export
class TypeDispatch:
    "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
    def __init__(self, *funcs):
        self.funcs,self.cache = {},{}
        for f in funcs: self.add(f)
        self.inst = None

    def _reset(self):
        self.funcs = {k:self.funcs[k] for k in sorted(self.funcs, key=cmp_instance, reverse=True)}
        self.cache = {**self.funcs}

    def add(self, f):
        "Add type `t` and function `f`"
        self.funcs[_p1_anno(f) or object] = f
        self._reset()

    def __repr__(self): return str({getattr(k,'__name__',str(k)):v.__name__ for k,v in self.funcs.items()})

The above is a simpler version of TypeDispatch (the full class also has __getitem__ and __call__, which come up below).
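The excerpt above depends on two fastai helpers that aren't shown, _p1_anno and cmp_instance. Here is a self-contained sketch you can run in a plain Python session. Both helpers are replaced by hypothetical stand-ins (_first_anno, and a cmp_instance built with functools.cmp_to_key), so treat them as assumptions rather than fastai's actual code.

```python
import functools
import inspect
import numbers

def _first_anno(f):
    "Hypothetical stand-in for `_p1_anno`: first annotated parameter's type, or None"
    for p in inspect.signature(f).parameters.values():
        if p.annotation is not inspect.Parameter.empty: return p.annotation
    return None

# Hypothetical stand-in for `cmp_instance`: a subclass compares as "greater" than its parent,
# so sorting with reverse=True puts the most specific types first.
cmp_instance = functools.cmp_to_key(lambda a, b: 0 if a == b else 1 if issubclass(a, b) else -1)

class TypeDispatch:
    "Dictionary-like object mapping types to functions"
    def __init__(self, *funcs):
        self.funcs, self.cache = {}, {}
        for f in funcs: self.add(f)
        self.inst = None

    def _reset(self):
        # Re-sort so subclasses come before their parents, then refresh the lookup cache
        self.funcs = {k: self.funcs[k] for k in sorted(self.funcs, key=cmp_instance, reverse=True)}
        self.cache = {**self.funcs}

    def add(self, f):
        "Register `f` under the type of its first parameter (object if unannotated)"
        self.funcs[_first_anno(f) or object] = f
        self._reset()

    def __repr__(self):
        return str({getattr(k, '__name__', str(k)): v.__name__ for k, v in self.funcs.items()})

def some_func(a: numbers.Integral, b: bool): pass
def catch_all(a): pass

t = TypeDispatch(some_func)
print(t)   # {'Integral': 'some_func'}
t.add(catch_all)
print(t)   # {'Integral': 'some_func', 'object': 'catch_all'}
```

The ->TensorImage return annotation is dropped from some_func here only because TensorImage isn't defined outside fastai; it doesn't affect dispatch, which only looks at the first parameter.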

Now, let’s create a function:

def some_func(a:numbers.Integral, b:bool)->TensorImage: pass

and pass it to TypeDispatch

t = TypeDispatch(some_func); t

>>>{'Integral': 'some_func'}

Voilà! TypeDispatch works…! BUT how?

Step-1: __init__ takes one or more functions. To start with, self.funcs and self.cache are both empty dicts, set by self.funcs,self.cache = {},{}

Step-2: for f in funcs: self.add(f) loops through each function passed and adds it to the dictionary self.funcs using add.
Inside add, we take the annotation of the first parameter of f (or object if there is none) and use it as the key for f in self.funcs.
This builds a mapping inside self.funcs from the type of f's first parameter to f itself.

Step-3: Still inside add, _reset reorders the self.funcs dictionary, sorting with key=cmp_instance and reverse=True so that the order follows Python’s type hierarchy from most specific to least specific. Thus if you pass int and bool, the first item inside this dict will be bool, because bool is a subclass of int.
Finally, self.cache is set to a copy of self.funcs. We use the cache to look up the mapping later; since looking up a key in a dict is O(1), it’s much faster.
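The reordering in Step-3 can be checked on its own. This sketch reuses the same hypothetical cmp_to_key stand-in for cmp_instance (an assumption, not fastai's code) and sorts a few related types; reverse=True puts the most specific type first.

```python
import functools
import numbers

# Hypothetical stand-in for fastai's `cmp_instance`: subclass > parent
cmp_instance = functools.cmp_to_key(lambda a, b: 0 if a == b else 1 if issubclass(a, b) else -1)

types_ = [object, numbers.Integral, int, bool]
ordered = sorted(types_, key=cmp_instance, reverse=True)
print([getattr(t, '__name__', str(t)) for t in ordered])   # ['bool', 'int', 'Integral', 'object']
```

Note that this stand-in only orders types related by inheritance: for two unrelated types, such as str and int, each compares as "less" than the other, so their relative position is not meaningful. That's harmless for dispatch, since an unrelated type can never be the match for the other.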

And finally we have __repr__, which returns the self.funcs mapping but shows each type’s name and each function’s name instead of the objects themselves.
I think the reason for the getattr in getattr(k,'__name__',str(k)) is that a type might not have a __name__ attribute, for example when we use metaclasses.

Hopefully, this helps everyone! Please feel free to correct me if I understood something wrong.

We reorder, as Jeremy said in walk-thru 5, because we want to find the closest match when looking up a type in a Transform. Thus, for an integer, the closest match is int and not numbers.Integral.

Also, the docstring of __getitem__ says: "Find first matching type that is a super-class of `k`"
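Putting that docstring together with the cache from Step-3, a plausible __getitem__ looks like the sketch below. This is my reconstruction under assumptions, not fastai's code, and Lookup is a hypothetical name: it checks the cache first, otherwise scans the already-sorted funcs for the first key that k is a subclass of, and caches the answer.

```python
import numbers

class Lookup:
    "Hypothetical sketch of TypeDispatch.__getitem__; `funcs` must already be sorted most-specific-first"
    def __init__(self, funcs):
        self.funcs, self.cache = dict(funcs), dict(funcs)

    def __getitem__(self, k):
        "Find first matching type that is a super-class of `k`"
        if k in self.cache: return self.cache[k]               # O(1) hit: exact or previously seen type
        matches = [t for t in self.funcs if issubclass(k, t)]  # scan from most to least specific
        res = self.funcs[matches[0]] if matches else None
        self.cache[k] = res                                     # remember the answer (even a miss)
        return res

def f_bool(x): return 'bool'
def f_int(x): return 'int'
def f_integral(x): return 'Integral'
class MyInt(int): pass   # hypothetical subclass with no function of its own

d = Lookup({bool: f_bool, int: f_int, numbers.Integral: f_integral})
print(d[int].__name__)     # f_int
print(d[MyInt].__name__)   # f_int  (closest registered parent)
print(d[float])            # None   (float is not an Integral)
```

Caching the result means the scan runs at most once per type. The cache then has to be rebuilt whenever a function is added, which is what self.cache = {**self.funcs} in _reset takes care of.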

Understanding TypeDispatch - Part 2

Here’s an insight!

So now that we know TypeDispatch is nothing but a pretty cool dict that looks something like:

{
bool: some_func1,
int: some_func2,
numbers.Integral: some_func3
}

i.e., it is a mapping between a type and the function that needs to be called on objects of that type.

The dispatching itself is done through __call__ inside TypeDispatch, of course!

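    def __call__(self, x, *args, **kwargs):
        f = self[type(x)]                                   # closest match for the type of x
        if not f: return x                                  # no match: return x unchanged
        if self.inst: f = types.MethodType(f, self.inst)    # bind f to the owning object, if any
        return f(x, *args, **kwargs)                        # call the type-specific function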

f = self[type(x)] takes the type of the argument, looks it up in the TypeDispatch dict, and gets back the function to call.
For example, for foo(2), type(2) is int, so we look up int using __getitem__, which simply returns the function for the first matching type that is a super-class of int.

So we look up int inside self.cache, which is also a mapping like

{
bool: some_func1,
int: some_func2,
numbers.Integral: some_func3
}

and we find the function some_func2 for int. Thus, __getitem__ returns some_func2 as f.

So, f = self[type(x)] sets f to some_func2.
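Here is a compact, runnable sketch of that lookup-then-call path. To keep it short, MiniDispatch (a hypothetical name) takes the type:function pairs directly, already in most-specific-first order, instead of reading annotations and sorting, and it leaves out self.inst. foo and some_func1..3 match the dict above.

```python
import numbers

class MiniDispatch:
    "Hypothetical cut-down TypeDispatch: explicit type->function pairs, most specific first"
    def __init__(self, funcs): self.funcs = dict(funcs)

    def __getitem__(self, k):
        # First registered type that `k` is a subclass of, else None
        return next((f for t, f in self.funcs.items() if issubclass(k, t)), None)

    def __call__(self, x, *args, **kwargs):
        f = self[type(x)]                  # look up by the argument's type
        if not f: return x                 # no match: pass x through unchanged
        return f(x, *args, **kwargs)

def some_func1(x): return f'bool: {x}'
def some_func2(x): return f'int: {x}'
def some_func3(x): return f'Integral: {x}'

foo = MiniDispatch({bool: some_func1, int: some_func2, numbers.Integral: some_func3})
print(foo(2))      # int: 2
print(foo(True))   # bool: True
print(foo(2.5))    # 2.5  (float has no match, so it is returned as-is)
```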

This is the magic! __call__ calls the function specific to the type of the argument being passed!!

Thus, when we pass a TensorImage, it finds the function that corresponds to TensorImage inside the dict and calls it, which is as simple as return f(x, *args, **kwargs)!
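The same idea with a TensorImage stand-in. The classes below are hypothetical placeholders (plain Python classes, not torch.Tensor or fastai's TensorImage), and dispatch is a hypothetical helper; they only show that a subclass with no function of its own falls back to its closest parent's function, because the lookup uses issubclass.

```python
class Tensor: pass                        # hypothetical placeholder, not torch.Tensor
class TensorImage(Tensor): pass           # hypothetical placeholder, not fastai's TensorImage
class TensorImageSub(TensorImage): pass   # hypothetical subclass with no function of its own

def dispatch(funcs, x):
    "Call the function for the first registered type that type(x) subclasses; else return x"
    f = next((fn for t, fn in funcs.items() if issubclass(type(x), t)), None)
    return f(x) if f else x

# Most specific type first, as after TypeDispatch's reordering
funcs = {TensorImage: lambda x: 'image function', Tensor: lambda x: 'tensor function'}
print(dispatch(funcs, TensorImage()))      # image function
print(dispatch(funcs, TensorImageSub()))   # image function (closest parent wins)
print(dispatch(funcs, Tensor()))           # tensor function
print(dispatch(funcs, 'text'))             # text
```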

How Transforms make use of TypeDispatch

Okay, here’s another one! I couldn’t have imagined that I would ever understand this part of V2, but now that I do, it just seems surreal! This is Python at the next level! And when you think about it, you can understand why it’s built this way.

But let’s discuss the thought process a little later.

First, let’s understand encodes and decodes inside Transform!

So, let’s start with _TfmDict:

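class _TfmDict(dict):
    def __setitem__(self,k,v):
        if k=='_': k='encodes'                                   # `_` is shorthand for `encodes`
        # anything that isn't a callable encodes/decodes: normal dict behavior
        if k not in ('encodes','decodes') or not isinstance(v,Callable): return super().__setitem__(k,v)
        if k not in self: super().__setitem__(k,TypeDispatch())  # first time: create a TypeDispatch
        res = self[k]
        res.add(v)                                               # add v to the existing TypeDispatch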

As long as the name being assigned is not encodes or decodes (or the value isn’t callable), k:v is set using the normal dict behavior, so the namespace of the cls is built as usual. Note that __setitem__ is responsible for setting k:v inside a dict, so if you override it, you can get custom behavior!

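BUT when it is encodes or decodes, the value stored under k is a TypeDispatch(): it is created the first time, and every function v is then added to it with res.add(v). So defining encodes several times, each with a different type annotation, adds all of them to the same TypeDispatch instead of overwriting the earlier ones.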

And as we know - TypeDispatch is nothing but a cool dict of type:function mapping!
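You can watch _TfmDict do this by using it as an ordinary dict and assigning to it the way a class body would. The sketch below copies _TfmDict from above and pairs it with a hypothetical, cut-down TypeDispatch that just stores annotation:function pairs (no sorting or lookup). It uses collections.abc.Callable, an assumption since the post doesn't show which Callable fastai imports. enc_int, enc_str and enc_float are hypothetical names.

```python
import inspect
from collections.abc import Callable

class TypeDispatch:
    "Hypothetical cut-down TypeDispatch: just stores first-annotation -> function"
    def __init__(self): self.funcs = {}
    def add(self, f):
        anno = next((p.annotation for p in inspect.signature(f).parameters.values()
                     if p.annotation is not inspect.Parameter.empty), object)
        self.funcs[anno] = f

class _TfmDict(dict):
    def __setitem__(self, k, v):
        if k == '_': k = 'encodes'                                  # `_` is shorthand for `encodes`
        if k not in ('encodes', 'decodes') or not isinstance(v, Callable):
            return super().__setitem__(k, v)                        # everything else: normal dict
        if k not in self: super().__setitem__(k, TypeDispatch())    # first time: create dispatcher
        self[k].add(v)                                              # every time: add v to it

def enc_int(self, x: int): return x + 1
def enc_str(self, x: str): return x.upper()
def enc_float(self, x: float): return x * 2

ns = _TfmDict()
ns['size'] = 3            # ordinary attribute
ns['encodes'] = enc_int
ns['encodes'] = enc_str   # does NOT overwrite enc_int
ns['_'] = enc_float       # stored under 'encodes'
print(ns['size'])                                                        # 3
print({t.__name__: f.__name__ for t, f in ns['encodes'].funcs.items()})
# {'int': 'enc_int', 'str': 'enc_str', 'float': 'enc_float'}
```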

So theoretically speaking, the namespace of a class created with the TfmMeta metaclass will look something like

{...all the usual stuff like '__module__': '__main__', etc. AND
encodes:
    {
     bool: some_func1,
     int: some_func2,
     numbers.Integral: some_func3
    },
decodes:
    {
     bool: some_reverse_func1,
     int: some_reverse_func2,
     numbers.Integral: some_reverse_func3
    }
}

And finally! When you call encodes or decodes, you can do so with different types: the call goes to __call__ inside TypeDispatch, which then calls the function corresponding to the argument’s type!
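Here is a minimal end-to-end sketch of how the pieces might connect. The key assumption is that the metaclass defines __prepare__ to return a _TfmDict. This is the standard Python hook for customizing the namespace a class body runs in, and I believe it's what TfmMeta does. With it, every def encodes / def decodes in the class body goes through _TfmDict.__setitem__. The TypeDispatch here is a hypothetical cut-down version with a __get__ that records the instance, one way the self.inst check in __call__ could be fed. AddOrUpper is a hypothetical example class.

```python
import inspect
import types
from collections.abc import Callable

def _first_anno(f):
    "First annotated parameter's type (skips unannotated `self`), else object"
    return next((p.annotation for p in inspect.signature(f).parameters.values()
                 if p.annotation is not inspect.Parameter.empty), object)

class TypeDispatch:
    "Hypothetical cut-down TypeDispatch with method binding"
    def __init__(self): self.funcs, self.inst = {}, None
    def add(self, f): self.funcs[_first_anno(f)] = f
    def __get__(self, inst, owner):
        self.inst = inst                       # remember which object `encodes` was accessed on
        return self
    def __call__(self, x, *args, **kwargs):
        f = next((fn for t, fn in self.funcs.items() if issubclass(type(x), t)), None)
        if not f: return x                     # no match: return x unchanged
        if self.inst is not None: f = types.MethodType(f, self.inst)   # bind as a method
        return f(x, *args, **kwargs)

class _TfmDict(dict):
    def __setitem__(self, k, v):
        if k == '_': k = 'encodes'
        if k not in ('encodes', 'decodes') or not isinstance(v, Callable): return super().__setitem__(k, v)
        if k not in self: super().__setitem__(k, TypeDispatch())
        self[k].add(v)

class TfmMeta(type):
    @classmethod
    def __prepare__(mcs, name, bases, **kwargs): return _TfmDict()   # class body runs in a _TfmDict

class AddOrUpper(metaclass=TfmMeta):
    step = 1
    def encodes(self, x: int): return x + self.step
    def encodes(self, x: str): return x.upper()
    def decodes(self, x: int): return x - self.step
    def decodes(self, x: str): return x.lower()

tfm = AddOrUpper()
print(tfm.encodes(1), tfm.encodes('a'))   # 2 A
print(tfm.decodes(2), tfm.decodes('A'))   # 1 a
print(tfm.encodes(1.5))                   # 1.5  (no float function: returned unchanged)
```

Storing the instance on the shared TypeDispatch in __get__ keeps the sketch close to the self.inst check shown earlier, but every instance of the class shares that one object; returning a separately bound object from __get__ would avoid that.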
