Underlying `Transforms` machinery and `MetaClasses`

(Aman Arora) #1

Understanding TypeDispatch

So I’ve spent a bit of time trying to understand TypeDispatch, and it’s really powerful! Basically, it's a dictionary that maps types to functions.

You can refer to the type hierarchy here

Let’s dig deeper and you’ll see how powerful it is!

def __init__(self, *funcs):
    self.funcs,self.cache = {},{}
    for f in funcs: self.add(f)
    self.inst = None

__init__ takes any number of functions and adds each one to the self.funcs dictionary as a type:func mapping. Inside TypeDispatch, the type is determined by the annotation of the first parameter of each function f.
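To make "the annotation of the first parameter" concrete, here is a minimal sketch using only the standard library's `inspect.signature`. The helper name `first_param_anno` is hypothetical; it only mimics the behaviour described for fastai's `_p1_anno`, including the fallback to `object` that `add` applies via `_p1_anno(f) or object`.

```python
import inspect
import numbers

def first_param_anno(f):
    "Hypothetical helper: annotation of f's first parameter, or `object` if there is none."
    params = list(inspect.signature(f).parameters.values())
    if not params or params[0].annotation is inspect.Parameter.empty:
        return object
    return params[0].annotation

# Return annotation dropped here because TensorImage is not defined in this sketch
def some_func(a: numbers.Integral, b: bool): pass
def untyped(a, b): pass

print(first_param_anno(some_func))  # <class 'numbers.Integral'>
print(first_param_anno(untyped))    # <class 'object'>
```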

Too confusing? Let’s put it together.

#export
class TypeDispatch:
    "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
    def __init__(self, *funcs):
        self.funcs,self.cache = {},{}
        for f in funcs: self.add(f)
        self.inst = None

    def _reset(self):
        self.funcs = {k:self.funcs[k] for k in sorted(self.funcs, key=cmp_instance, reverse=True)}
        self.cache = {**self.funcs}

    def add(self, f):
        "Add type `t` and function `f`"
        self.funcs[_p1_anno(f) or object] = f
        self._reset()

    def __repr__(self): return str({getattr(k,'__name__',str(k)):v.__name__ for k,v in self.funcs.items()})

This is a simplified version of TypeDispatch; the full class also defines __getitem__ and __call__.

Now, let’s create a function:

def some_func(a:numbers.Integral, b:bool)->TensorImage: pass

and pass it to TypeDispatch

t = TypeDispatch(some_func); t

>>>{'Integral': 'some_func'}

Voilà! TypeDispatch works! BUT how?

Step-1: __init__ takes one or more functions. To start with, self.funcs and self.cache are empty, as set by self.funcs,self.cache = {},{}.

Step-2: for f in funcs: self.add(f) loops through each function passed and adds it to the dictionary self.funcs using add.
Inside add, we check the annotation of the first parameter of f; if there is none, we use object as the type, and then store the pair in self.funcs.
Thus self.funcs maps the type of the first parameter of f to f itself.

Step-3: add then calls _reset, which reorders the self.funcs dictionary using cmp_instance as the sort key (with reverse=True), so the order follows Python’s type hierarchy. Thus, if you pass int and bool, the first item inside this dict will be bool.
Finally, _reset makes self.cache a copy of self.funcs. We use cache to look up the mapping later; since key lookup in a dict is O(1), it’s much faster.

And finally we have __repr__, which just returns the self.funcs mapping, but with each type's name and each f's name.
The reason for getattr(k,'__name__',str(k)) is, I think, that a type might not have a __name__ attribute when we use MetaClasses.

Hopefully, this helps everyone! Please feel free to correct me if I understood something wrong.

We reorder because, as Jeremy said in walk-thru 5, we try to find the closest match for a Transform. Thus, for an integer the closest match is int first, not numbers.Integral.

Also, the docstring of __getitem__ says: "Find first matching type that is a super-class of k"
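Here is a minimal sketch of the lookup that docstring describes, assuming `funcs` is already ordered most-specific-first (as `_reset` arranges). The class name `MiniDispatch` is hypothetical and this is not fastai's actual `__getitem__`; it just shows the idea of scanning for the first superclass match and memoising the answer in `cache`, so repeated lookups of the same type are a single dict hit.

```python
import numbers

class MiniDispatch:
    "Hypothetical, simplified lookup: first key that is a superclass of the queried type."
    def __init__(self, funcs):
        self.funcs = dict(funcs)       # assumed already ordered most-specific-first
        self.cache = dict(self.funcs)  # exact-type hits are a single O(1) dict lookup

    def __getitem__(self, k):
        if k in self.cache: return self.cache[k]
        # Slow path: scan for the first registered type that k inherits from
        res = next((f for t, f in self.funcs.items() if issubclass(k, t)), None)
        self.cache[k] = res            # memoise, so the next lookup of k is O(1)
        return res

def f_bool(x): return 'bool'
def f_int(x): return 'int'
def f_integral(x): return 'Integral'

class MyInt(int): pass

d = MiniDispatch({bool: f_bool, int: f_int, numbers.Integral: f_integral})
print(d[bool].__name__)    # f_bool   (exact hit in the cache)
print(d[MyInt].__name__)   # f_int    (found by scanning, then cached)
print(d[float])            # None     (float is not an Integral)
```

Because answers are memoised, any later registration has to rebuild the cache, which is what `_reset` does with `self.cache = {**self.funcs}`.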

Understanding TypeDispatch - Part 2

Here’s an insight!

So now that we know TypeDispatch is nothing but a pretty cool dict that looks something like:

{
bool: some_func1,
int: some_func2,
numbers.Integral: some_func3
}

i.e., it is a mapping between each type and the function that needs to be called on that specific type.

The dispatch itself happens in __call__ inside TypeDispatch, of course!

    def __call__(self, x, *args, **kwargs):
        f = self[type(x)]
        if not f: return x
        if self.inst: f = types.MethodType(f, self.inst)
        return f(x, *args, **kwargs)

f = self[type(x)] checks the type of the argument being passed, looks it up in the TypeDispatch dict, and that function then gets called.
For example, for foo(2), type(2) is int, and we look up int via __getitem__, which returns the function registered for the first matching type that is a superclass of the type passed in.

So we look up inside self.cache, which is also a mapping like

{
bool: some_func1,
int: some_func2,
numbers.Integral: some_func3
}

and we find the function some_func2 for int. Thus, __getitem__ returns some_func2 as f.

So, f = self[type(x)] sets f as some_func2 .

This is the magic! __call__ calls the function registered for the type of the argument being passed!

Thus, when we pass a TensorImage, it finds the function that corresponds to TensorImage inside the dict and calls it, which is as simple as return f(x, *args, **kwargs)!
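Putting registration, lookup and `__call__` together, here is a self-contained sketch of the dispatch flow described above. `TinyDispatch` and `_cmp_types` are hypothetical names, and this is a simplification of the real class: it leaves out the `self.inst` method binding and the cache. Like the `if not f: return x` line, it hands the input back unchanged when no type matches.

```python
import functools
import inspect
import numbers

def _cmp_types(a, b):
    "Subclasses sort before their superclasses; unrelated types compare equal."
    if a is b: return 0
    if issubclass(a, b): return -1
    if issubclass(b, a): return 1
    return 0

class TinyDispatch:
    "Hypothetical simplification of TypeDispatch: maps a type to a function."
    def __init__(self, *funcs):
        self.funcs = {}
        for f in funcs: self.add(f)

    def add(self, f):
        # Key on the first parameter's annotation, defaulting to `object`
        p = next(iter(inspect.signature(f).parameters.values()))
        t = object if p.annotation is inspect.Parameter.empty else p.annotation
        self.funcs[t] = f
        # Most specific types first, so the first match is the closest one
        order = sorted(self.funcs, key=functools.cmp_to_key(_cmp_types))
        self.funcs = {k: self.funcs[k] for k in order}

    def __getitem__(self, k):
        return next((f for t, f in self.funcs.items() if issubclass(k, t)), None)

    def __call__(self, x, *args, **kwargs):
        f = self[type(x)]
        if f is None: return x          # no matching type: hand x back unchanged
        return f(x, *args, **kwargs)

def some_func1(x: bool): return f'bool: {x}'
def some_func2(x: int): return f'int: {x}'
def some_func3(x: numbers.Integral): return f'Integral: {x}'

foo = TinyDispatch(some_func3, some_func1, some_func2)  # registration order does not matter
print(foo(True))   # bool: True
print(foo(2))      # int: 2
print(foo(2.5))    # 2.5
```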

How Transforms make use of TypeDispatch

Okay, here’s another one! I never imagined I would understand this part of V2, but now that I do, it just seems surreal! This is Python at the next level! And when you think about it, you can understand why it’s built this way.

But let's discuss the thought process a little later.

First, let’s understand encodes and decodes inside Transform!

So, here is _TfmDict:

class _TfmDict(dict):
    def __setitem__(self,k,v):
        if k=='_': k='encodes'
        if k not in ('encodes','decodes') or not isinstance(v,Callable): return super().__setitem__(k,v)
        if k not in self: super().__setitem__(k,TypeDispatch())
        res = self[k]
        res.add(v)

As long as the key is not encodes or decodes (or the value is not callable), the class namespace is filled in the normal dict way, via super().__setitem__(k,v). Note that __setitem__ is what sets k:v inside a dict, so if you override it, you get custom behavior!

BUT when the key is encodes or decodes (a key of _ is treated as encodes), a TypeDispatch() is stored under that key the first time, and every function assigned to that key is then added to it with res.add(v).

And as we know - TypeDispatch is nothing but a cool dict of type:function mapping!
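Normally, defining `encodes` twice in a class body would silently overwrite the first definition. The sketch below shows how a `dict` subclass returned from `__prepare__` changes that: every callable assigned to `encodes` or `decodes` is appended to a list instead. `CollectingDict`, `CollectMeta` and `Demo` are hypothetical names, and the list is only a stand-in for `TypeDispatch`.

```python
class CollectingDict(dict):
    "Hypothetical namespace: gathers every `encodes`/`decodes` definition instead of overwriting."
    def __setitem__(self, k, v):
        if k == '_': k = 'encodes'             # mirror _TfmDict: `_` is an alias for encodes
        if k not in ('encodes', 'decodes') or not callable(v):
            return super().__setitem__(k, v)   # everything else: normal dict behaviour
        if k not in self: super().__setitem__(k, [])
        self[k].append(v)                      # self[k] is the list stored above

class CollectMeta(type):
    @classmethod
    def __prepare__(mcls, name, bases, **kwargs): return CollectingDict()

class Demo(metaclass=CollectMeta):
    def encodes(self, x: int): return x + 1
    def encodes(self, x: str): return x.upper()   # kept, not overwritten
    def decodes(self, x: int): return x - 1

print([f.__annotations__['x'] for f in Demo.encodes])  # [<class 'int'>, <class 'str'>]
```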

So, conceptually, the namespace of a class whose metaclass is _TfmMeta will look something like

{...all the usual stuff like '__module__': '__main__', etc. AND
 'encodes':
    {
     bool: some_func1,
     int: some_func2,
     numbers.Integral: some_func3
    },
 'decodes':
    {
     bool: some_reverse_func1,
     int: some_reverse_func2,
     numbers.Integral: some_reverse_func3
    }
}

And finally! When you call encodes or decodes, you can pass different types: the call goes through __call__ inside TypeDispatch, which then calls the function that corresponds to that type!
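Here is a compact, self-contained sketch of the whole pattern: `__prepare__` returns a custom namespace, which turns every `encodes`/`decodes` definition into an entry in a small type-to-function map, and calling the object dispatches on `type(x)`. All names (`Dispatch`, `_NS`, `_Meta`, `MiniTransform`, `Negate`) are hypothetical, and this omits much of what fastai's `Transform` does (filters, `as_item`, `retain_type`). As an assumption, it keys on the first *annotated* parameter so that an unannotated `self` is skipped.

```python
import functools
import inspect

def _cmp_types(a, b):
    "Subclasses sort before their superclasses; unrelated types compare equal."
    if a is b: return 0
    if issubclass(a, b): return -1
    if issubclass(b, a): return 1
    return 0

class Dispatch:
    "Hypothetical mini TypeDispatch, keyed on the first annotated parameter."
    def __init__(self): self.funcs = {}
    def add(self, f):
        anns = [p.annotation for p in inspect.signature(f).parameters.values()
                if p.annotation is not inspect.Parameter.empty]
        self.funcs[anns[0] if anns else object] = f
        order = sorted(self.funcs, key=functools.cmp_to_key(_cmp_types))
        self.funcs = {k: self.funcs[k] for k in order}
    def __call__(self, obj, x):
        f = next((f for t, f in self.funcs.items() if issubclass(type(x), t)), None)
        return x if f is None else f(obj, x)   # unmatched types pass through

class _NS(dict):
    "Class namespace that routes `encodes`/`decodes` definitions into Dispatch objects."
    def __setitem__(self, k, v):
        if k not in ('encodes', 'decodes') or not callable(v):
            return super().__setitem__(k, v)
        if k not in self: super().__setitem__(k, Dispatch())
        self[k].add(v)

class _Meta(type):
    @classmethod
    def __prepare__(mcls, name, bases, **kwargs): return _NS()

class MiniTransform(metaclass=_Meta):
    def __call__(self, x): return self.encodes(self, x)
    def decode(self, x):   return self.decodes(self, x)

class Negate(MiniTransform):
    def encodes(self, x: int): return -x
    def encodes(self, x: str): return x[::-1]
    def decodes(self, x: int): return -x

t = Negate()
print(t(3), t('abc'), t(1.5))  # -3 cba 1.5
print(t.decode(-3))            # 3
```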

It is all making sense now.


(Konstantin Dorichev) #2

I guess only admins can create wiki posts.


(Aman Arora) #3

Thanks @kdorichev, I wrote that by mistake 🙂


(Aman Arora) #4

Okay, so in my quest to find an answer to this I believe I’ve picked up something else about MetaClasses which is worth sharing.

I think the three most important pieces for understanding Transforms in V2 are:

class Transform(metaclass=_TfmMeta):
    "Delegates (`__call__`,`decode`) to (`encodes`,`decodes`) if `filt` matches"
    filt,init_enc,as_item_force,as_item,order = None,False,None,True,0
    def __init__(self, enc=None, dec=None, filt=None, as_item=False):
        self.filt,self.as_item = ifnone(filt, self.filt),as_item
        self.init_enc = enc or dec
        if not self.init_enc: return

        # Passing enc/dec, so need to remove (base) class level enc/dec
        del(self.__class__.encodes,self.__class__.decodes)
        self.encodes,self.decodes = (TypeDispatch(),TypeDispatch())
        if enc:
            self.encodes.add(enc)
            self.order = getattr(self.encodes,'order',self.order)
        if dec: self.decodes.add(dec)

    @property
    def use_as_item(self): return ifnone(self.as_item_force, self.as_item)
    def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)
    def decode  (self, x, **kwargs): return self._call('decodes', x, **kwargs)
    def __repr__(self): return f'{self.__class__.__name__}: {self.use_as_item} {self.encodes} {self.decodes}'

    def _call(self, fn, x, filt=None, **kwargs):
        if filt!=self.filt and self.filt is not None: return x
        f = getattr(self, fn)
        if self.use_as_item or not is_listy(x): return self._do_call(f, x, **kwargs)
        res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
        return retain_type(res, x)

    def _do_call(self, f, x, **kwargs):
        return x if f is None else retain_type(f(x, **kwargs), x, f.returns_none(x))

Of course, the Transform class itself.

Then, its metaclass _TfmMeta

#export
class _TfmMeta(type):
    def __new__(cls, name, bases, dict):
        print("I'm alive inside `__new__` in `_TfmMeta`")
        res = super().__new__(cls, name, bases, dict)
        res.__signature__ = inspect.signature(res.__init__)
        return res

    def __call__(cls, *args, **kwargs):
        f = args[0] if args else None
        n = getattr(f,'__name__',None)
        if not hasattr(cls,'encodes'): cls.encodes=TypeDispatch()
        if not hasattr(cls,'decodes'): cls.decodes=TypeDispatch()
        if isinstance(f,Callable) and n in ('decodes','encodes','_'):
            getattr(cls,'encodes' if n=='_' else n).add(f)
            return f
        return super().__call__(*args, **kwargs)

    @classmethod
    def __prepare__(cls, name, bases): return _TfmDict()

And finally _TfmDict

#export
class _TfmDict(dict):
    def __setitem__(self,k,v):
        if k=='_': k='encodes'
        if k not in ('encodes','decodes') or not isinstance(v,Callable): return super().__setitem__(k,v)
        if k not in self: super().__setitem__(k,TypeDispatch())
        res = self[k]
        res.add(v)

According to section 3.3.3.4 of the Python Data Model,

Once the appropriate metaclass has been identified, then the class namespace is prepared. If the metaclass has a __prepare__ attribute, it is called as namespace = metaclass.__prepare__(name, bases, **kwds) (where the additional keyword arguments, if any, come from the class definition).

Which means when we first define our class Transform like so:
class Transform(metaclass=_TfmMeta):

The __prepare__ method inside _TfmMeta gets called, and it changes what a normal __prepare__ would do: instead of a plain dict, it returns a _TfmDict, which inherits from dict, to serve as the class namespace.

Since __setitem__ is overridden in _TfmDict, if k (i.e., the attribute being set) is not encodes or decodes, the normal dict machinery is used, i.e., super().__setitem__(k,v).

From my understanding, the Transform class body itself does not define encodes or decodes (they are only attached later, by cls.encodes=TypeDispatch() and cls.decodes=TypeDispatch() in _TfmMeta.__call__), so the final dict that we get at this stage is

{'__module__': '__main__', '__qualname__': 'Transform', '__doc__': 'Delegates (`__call__`,`decode`) to (`encodes`,`decodes`) if `filt` matches', 'filt': None, 'init_enc': False, 'as_item_force': None, 'as_item': True, 'order': 0, '__init__': <function Transform.__init__ at 0x7f2f3ae80d08>, 'use_as_item': <property object at 0x7f2f3ae97818>, '__call__': <function Transform.__call__ at 0x7f2f3ae80c80>, 'decode': <function Transform.decode at 0x7f2f3ae80bf8>, '__repr__': <function Transform.__repr__ at 0x7f2f3ae88048>, '_call': <function Transform._call at 0x7f2f3ae88268>, '_do_call': <function Transform._do_call at 0x7f2f3ae881e0>, '__return__': None}
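A small sketch (hypothetical `TraceMeta`, `Base` and `Example`) that records when each hook runs: `__prepare__` receives the class name, the bases and any extra keywords from the class statement; the class body then fills the namespace it returned; and only after that does `__new__` receive that filled namespace. Implicit dunder keys such as `__module__` are filtered out of the printout for readability.

```python
events = []

class TraceMeta(type):
    @classmethod
    def __prepare__(mcls, name, bases, **kwargs):
        # Called first, before the class body runs
        events.append(('prepare', name, [b.__name__ for b in bases], kwargs))
        return {}
    def __new__(mcls, name, bases, ns, **kwargs):
        # Called after the body has filled `ns`; hide implicit dunder keys
        events.append(('new', name, [k for k in ns if not k.startswith('__')]))
        return super().__new__(mcls, name, bases, ns)
    def __init__(cls, name, bases, ns, **kwargs):
        super().__init__(name, bases, ns)

class Base: pass

class Example(Base, metaclass=TraceMeta, flavour='spicy'):
    events.append(('body',))   # runs while the namespace is being filled
    x = 1

for e in events: print(e)
# ('prepare', 'Example', ['Base'], {'flavour': 'spicy'})
# ('body',)
# ('new', 'Example', ['x'])
```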

Once the dict is set up, __new__ inside _TfmMeta gets called (I think this comes from here), and the Python Data Model says:

When using the default metaclass type , or any metaclass that ultimately calls type.__new__ , the following additional customisation steps are invoked after creating the class object:

  • first, type.__new__ collects all of the descriptors in the class namespace that define a __set_name__() method;
  • second, all of these __set_name__ methods are called with the class being defined and the assigned name of that particular descriptor;
  • finally, the __init_subclass__() hook is called on the immediate parent of the new class in its method resolution order.
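A minimal sketch of the post-creation hooks from that quote, recording the order in which they fire. `Field` (a descriptor with `__set_name__`), `Registry` (a base class with `__init_subclass__`) and `Point` are hypothetical names used only for illustration.

```python
calls = []

class Field:
    "Descriptor that learns the class and attribute name it was assigned to."
    def __set_name__(self, owner, name):
        calls.append(f'__set_name__({owner.__name__}, {name!r})')
        self.label = f'{owner.__name__}.{name}'
    def __get__(self, obj, objtype=None):
        return self.label

class Registry:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        calls.append(f'__init_subclass__({cls.__name__})')   # runs on the immediate parent

class Point(Registry):
    x = Field()
    y = Field()

for c in calls: print(c)
# __set_name__(Point, 'x')
# __set_name__(Point, 'y')
# __init_subclass__(Point)
print(Point.x)  # Point.x
```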

(Jeremy Howard (Admin)) #5

I’ve made the top post a wiki.


(Malcolm McLean) #6

A suggestion. fastai introduces hidden changes to Python’s default behavior and uses various non-standard idioms to accomplish common tasks. I’m thinking of things such as metaclasses, parameter type dispatch, return type conversion, instance variable definitions, various decorators, documentation strings, etc. The intent of course is to make the code more concise, efficient, and less prone to bugs. However, for the naive reader, these unusual constructions are confusing and obtuse. I remember spending a lot of time puzzling out the purpose and function of such constructions in fastai v1, with not a word of explanation to be found.

So how about gathering all these fastai “idioms” into one place, and for each one…

  • what it is, and its intent and purpose
  • what it does, and the “ordinary Python” equivalent if there is one
  • how it works (implementation details)
  • when and where to use it

We would also benefit from standard templates for adding classes, transforms, functions, etc. to fastai. Templates that show where and how to initialize, handle kwargs, implement show(), and place documentation strings, for example. That way, contributions would be held to standards of expected functionality.

I feel hesitant to make such a suggestion, because my own Python skills are not up to creating such a guide*. But having a “guide to the idioms” for v1 would have saved me an enormous amount of time and frustration, as well as several questions to @sgugger. So for the sake of us mid-level Python coders…

What do you think?

*I can however offer to edit docs for clarity and completeness. Perhaps having an editor who is not so familiar with the codebase is an advantage for making the docs understandable to users and contributors.


(Aman Arora) #7

To further extend the understanding of Metaclasses, here is a simple and easy example:

class _TfmDict(dict):
    def __setitem__(self,k,v): 
        return super().__setitem__(k,v)

class Meta(type):
    @classmethod
    def __prepare__(cls, name, bases, **kwargs):
        print("I am alive")
        ns = _TfmDict()          # custom mapping used as the class namespace
        ns['surname'] = 'Arora'  # goes through _TfmDict.__setitem__
        return ns

class A(metaclass=Meta): pass

print(f"\nA.__name__ : {A.__name__}\nA.__qualname__: {A.__qualname__}\n\
A.__name__ is A.__qualname__: {A.__name__ is A.__qualname__}\nA.surname: {A.surname}")

>>>
I am alive

A.__name__ : A
A.__qualname__: A
A.__name__ is A.__qualname__: True
A.surname: Arora

This example returns an instance of _TfmDict (whose __setitem__ simply defers to dict's) from __prepare__ and pre-populates it with surname, which is why A.surname is Arora.
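To see that a class body really goes through the namespace mapping's `__setitem__`, here is a variant (hypothetical `LoggingDict`, `LogMeta` and `B`) that records every key as it is assigned. The implicit `__module__` and `__qualname__` entries pass through it first, before the names defined in the body; recent Python versions may add further dunder keys, so the second check filters out names starting with an underscore.

```python
class LoggingDict(dict):
    "Namespace that records the order in which the class body assigns names."
    def __init__(self):
        super().__init__()
        self.log = []
    def __setitem__(self, k, v):
        self.log.append(k)
        super().__setitem__(k, v)

class LogMeta(type):
    @classmethod
    def __prepare__(mcls, name, bases, **kwargs):
        return LoggingDict()
    def __new__(mcls, name, bases, ns, **kwargs):
        cls = super().__new__(mcls, name, bases, dict(ns))
        cls.assigned = list(ns.log)   # keep the log around for inspection
        return cls

class B(metaclass=LogMeta):
    surname = 'Arora'
    def greet(self): return f'Hi {self.surname}'

print(B.assigned[:2])                                    # ['__module__', '__qualname__']
print([k for k in B.assigned if not k.startswith('_')])  # ['surname', 'greet']
```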

My experiments have helped me understand this:

class Meta(type): 
    @classmethod
    def __prepare__(self, name, bases, **kwargs):
        print("I am alive")
#         return {"some_object": "some_val"}
        
class A(metaclass=Meta): pass

>>>
I am alive
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-28-5cfa9b9c2ed3> in <module>
      5 #         return {"some_object": "some_val"}
      6 
----> 7 class A(metaclass=Meta): pass

TypeError: Meta.__prepare__() must return a mapping, not NoneType

So we can see that the __prepare__ method of metaclass Meta was invoked straight away.
We also now know that __prepare__() must return a mapping.

class Meta(type): 
    @classmethod
    def __prepare__(self, name, bases, **kwargs):
        print("I am alive")
        return {"some_object": "some_val"}
        
class A(metaclass=Meta): pass

>>>
I am alive

So, this works. Great!

More about __prepare__ can be found here
