Understanding TypeDispatch
So I’ve spent a bit of time trying to understand TypeDispatch, and it’s really powerful! Basically, it’s a dictionary between types and functions. You can refer to Python’s type hierarchy here.
Let’s dig deeper and you’ll see how powerful it is!
def __init__(self, *funcs):
    self.funcs,self.cache = {},{}
    for f in funcs: self.add(f)
    self.inst = None
The __init__ takes in a list of functions and adds them to the dictionary with a type:func mapping. Inside TypeDispatch, the type is determined by the annotation of the first parameter of each function f.
Too confusing? Let’s put it together.
#export
class TypeDispatch:
    "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
    def __init__(self, *funcs):
        self.funcs,self.cache = {},{}
        for f in funcs: self.add(f)
        self.inst = None

    def _reset(self):
        self.funcs = {k:self.funcs[k] for k in sorted(self.funcs, key=cmp_instance, reverse=True)}
        self.cache = {**self.funcs}

    def add(self, f):
        "Add type `t` and function `f`"
        self.funcs[_p1_anno(f) or object] = f
        self._reset()

    def __repr__(self): return str({getattr(k,'__name__',str(k)):v.__name__ for k,v in self.funcs.items()})
Let’s look at a simpler version of TypeDispatch. First, let’s create a function:
def some_func(a:numbers.Integral, b:bool)->TensorImage: pass
and pass it to TypeDispatch:
t = TypeDispatch(some_func); t
>>>{'Integral': 'some_func'}
Voila! TypeDispatch works…! BUT how?
Step-1: __init__ takes a bunch of functions or a single function. To start with, self.funcs and self.cache are empty, as defined by self.funcs,self.cache = {},{}.
Step-2: for f in funcs: self.add(f) loops through each function passed and adds it to the dictionary self.funcs using add.
Inside add, we check the annotation of the first parameter of function f; if it is None, we use the type object instead, and add it to self.funcs. Thus self.funcs holds a mapping between the type of the first param of f and f itself.
Step-3: Reorder the self.funcs dictionary based on the key cmp_instance, which sets the order using Python’s type hierarchy in reverse order. Thus if you pass int and bool, the first item inside this dict will be bool.
Finally, make self.cache the same as self.funcs. We use cache to look up the mapping later. Since key lookup inside a dict is O(1), it’s much faster.
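To see why bool ends up first, it helps to recall where these types sit in Python’s numeric tower:

```python
import numbers

# bool is a subclass of int, and int is registered as a subclass of
# numbers.Integral, so ordering "most specific first" gives:
# bool, then int, then numbers.Integral.
assert issubclass(bool, int)
assert issubclass(int, numbers.Integral)
assert not issubclass(int, bool)  # the relationship is one-way
```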
And finally we have __repr__, which just returns the mapping self.funcs, but using f's name and the type's name. The reason there is a getattr inside getattr(k,'__name__',str(k)) is, I think, that it’s possible for a type to lack a __name__ attribute when metaclasses are used.
Hopefully, this helps everyone! Please feel free to correct me if I understood something wrong.
We do reorder, as Jeremy said in walk-thru 5, because we try to find the closest match from Transforms. Thus, for an integer the closest match would first be int and not numbers.Integral.
Also, from the docstring of __getitem__: "Find first matching type that is a super-class of k".
Understanding TypeDispatch - Part 2
Here’s an insight!
So now we know that TypeDispatch is nothing but a pretty cool dict that looks something like:
{
    bool: some_func1,
    int: some_func2,
    numbers.Integral: some_func3
}
i.e., it is a mapping between a type and the function that needs to be called on that specific type.
This is done through __call__ inside TypeDispatch, of course!
def __call__(self, x, *args, **kwargs):
    f = self[type(x)]
    if not f: return x
    if self.inst: f = types.MethodType(f, self.inst)
    return f(x, *args, **kwargs)
f = self[type(x)] checks the type of the parameter being passed, looks it up in the TypeDispatch dict, and picks the matching function.
i.e., foo(2) will return type(2) as int, and then we look up int via __getitem__, which simply returns the first matching type that is a super-class of that type.
So we look up inside self.cache, which is also a mapping like
{
    bool: some_func1,
    int: some_func2,
    numbers.Integral: some_func3
}
and we will find the function some_func2 for int. Thus __getitem__ will return some_func2 as f.
So f = self[type(x)] sets f to some_func2.
This is the magic! __call__ will invoke the specific function for the specific type of the parameter being passed!! Thus when we pass a TensorImage, it will find the function that corresponds to TensorImage inside the dict and call it, which is just as simple as return f(x, *args, **kwargs)!
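Here’s a tiny hand-rolled version of that dispatch (on_bool and on_int are made-up names, and a plain dict plus a loop stands in for TypeDispatch), showing why the bool-before-int ordering matters:

```python
def on_bool(x: bool): return f"bool: {x}"
def on_int(x: int):  return f"int: {x}"

# ordered most-specific first, like TypeDispatch after its reorder
table = {bool: on_bool, int: on_int}

def dispatch(x):
    for t, f in table.items():
        if isinstance(x, t):
            return f(x)
    return x  # no match: return the input unchanged, as TypeDispatch does

dispatch(True)  # 'bool: True' -- bool wins even though True is also an int
dispatch(5)     # 'int: 5'
dispatch("hi")  # 'hi' -- no entry for str, falls through untouched
```

If the table were ordered {int: on_int, bool: on_bool} instead, True would match int first and the bool-specific function would never run; that is exactly what the _reset reorder prevents.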
How Transforms make use of TypeDispatch
Okay, here’s another one! I couldn’t have imagined that I would ever understand this part of V2, but now that I do, it just seems surreal! This is Python at the next level! And when you come to think of it, you can understand why it’s built this way.
But let’s discuss the thought process a little later.
First, let’s understand encodes and decodes inside Transform!
So, from _TfmDict:
class _TfmDict(dict):
    def __setitem__(self,k,v):
        if k=='_': k='encodes'
        if k not in ('encodes','decodes') or not isinstance(v,Callable): return super().__setitem__(k,v)
        if k not in self: super().__setitem__(k,TypeDispatch())
        res = self[k]
        res.add(v)
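To see _TfmDict in action you need the metaclass piece that injects it into the class body via __prepare__. Below is a simplified, self-contained sketch of that mechanism; all the *_Sketch names are mine, and a plain dict stands in for TypeDispatch:

```python
import inspect

def _p1_anno_sketch(f):
    "Annotation of `f`'s first non-self parameter (assumed annotated here)."
    ps = [p for p in inspect.signature(f).parameters.values() if p.name != 'self']
    return ps[0].annotation

class _TfmDictSketch(dict):
    "Class-body namespace that collects every `encodes` def into one table."
    def __setitem__(self, k, v):
        if k != 'encodes' or not callable(v):
            return super().__setitem__(k, v)           # normal dict behavior
        if k not in self:
            super().__setitem__(k, {})                 # plain dict stands in for TypeDispatch()
        self[k][_p1_anno_sketch(v)] = v                # stands in for TypeDispatch.add

class _TfmMetaSketch(type):
    @classmethod
    def __prepare__(mcls, name, bases):
        # hand our custom dict to the class body, so repeated `encodes`
        # definitions accumulate instead of overwriting each other
        return _TfmDictSketch()
    def __new__(mcls, name, bases, ns):
        return super().__new__(mcls, name, bases, dict(ns))

class MyTfm(metaclass=_TfmMetaSketch):
    def encodes(self, x: int): return x + 1
    def encodes(self, x: str): return x.upper()

sorted(t.__name__ for t in MyTfm.encodes)  # ['int', 'str'] -- both defs survived
```

With a normal dict namespace, the second `def encodes` would simply replace the first; the custom __setitem__ is what lets one method name fan out to many types.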
As long as the key is not encodes or decodes, the namespace of the cls is created using dict as per normal behavior. Note that __setitem__ is responsible for setting k:v inside a dict, so if you override it, you can get custom behavior!
So as long as something is not encodes or decodes, just use dict to set k:v.
BUT, when it is encodes or decodes, then set k:TypeDispatch(). And as we know, TypeDispatch is nothing but a cool dict of type:function mappings!
So theoretically speaking, the namespace of this special class (whose metaclass is _TfmMeta) will look something like:
{
    ...all the usual stuff like __module__: '__main__' etc., AND
    encodes: {
        bool: some_func1,
        int: some_func2,
        numbers.Integral: some_func3
    },
    decodes: {
        bool: some_reverse_func1,
        int: some_reverse_func2,
        numbers.Integral: some_reverse_func3
    }
}
And finally! When you call encodes or decodes, it can be done for different types; the call goes through __call__ inside TypeDispatch, which then calls the specific function corresponding to the type!
It is all making sense now.