MC Dropout in fastai v2

Hi,
As a prerequisite for my student project, I implemented a callback that enables MC dropout in fastai v2 for estimating uncertainty:

class McDropout(Callback):
    "`Callback` that activates Dropout layers at inference time to generate uncertainty in predictions"
    def __init__(self, with_input=False, with_loss=False, save_preds=None, save_targs=None, concat_dim=0):
        self._active = False  # dropout is only switched on while collect_samples runs
        store_attr(self, "with_input,with_loss,save_preds,save_targs,concat_dim")

    def collect_samples(self, n, *args, **kwargs):
        """
        Collect n samples using MC dropout. By studying the distribution of these
        predictions, you can estimate the epistemic uncertainty of the classifier.

        :param n: number of stochastic forward passes (samples) to collect
        :param *args: positional arguments forwarded to Learner.predict
        :param **kwargs: keyword arguments forwarded to Learner.predict
        :return: a list of n predictions
        """
        self.activate()
        try:
            # self.predict resolves to Learner.predict through the Callback's attribute delegation
            res = [self.predict(*args, **kwargs) for _ in range(n)]
        finally:
            # always switch dropout off again, even if predict raises
            self.deactivate()
        return res

    def activate(self):
        self._active = True

    def deactivate(self):
        self._active = False

    def begin_batch(self):
        # put dropout layers in training mode so they stay stochastic during inference
        if self._active:
            main_mod = next(self.model.children())
            for mod in main_mod.children():
                if 'dropout' in mod.__class__.__name__.lower():
                    mod.train()

    def after_batch(self):
        # restore dropout layers to eval mode after each batch
        if self._active:
            main_mod = next(self.model.children())
            for mod in main_mod.children():
                if 'dropout' in mod.__class__.__name__.lower():
                    mod.eval()

Usage:

mc_dropout = McDropout()
learn = Learner(dls, raw_model, metrics=error_rate, cbs=[mc_dropout])
# This collects 10 MC samples for one image file (fns being a list of file names)
mc_dropout.collect_samples(10, fns[1])
# or
mc_dropout.collect_samples(10, learn.dls.valid.dataset[2][0])
# or any other way of specifying the input supported by .predict
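Since collect_samples returns a list of raw predict outputs, a small helper can turn them into uncertainty summaries. This is a sketch under assumptions: it assumes a single-label classifier, where the third element of each tuple returned by Learner.predict holds the class probabilities (so you would pass [p[2] for p in preds]), and summarize_mc_samples is a hypothetical name. It reports the per-class mean and standard deviation across the samples, plus the entropy of the mean prediction as a single uncertainty score:

```python
import torch

def summarize_mc_samples(probs_list):
    """Summarize MC dropout samples (hypothetical helper).

    probs_list: list of 1-D tensors, one class-probability vector per
    stochastic forward pass (assumed input format).
    """
    probs = torch.stack(probs_list)            # shape (T, n_classes)
    mean = probs.mean(dim=0)                   # predictive mean per class
    var = ((probs - mean) ** 2).mean(dim=0)    # per-class variance across the T samples
    # entropy of the mean prediction; clamp_min avoids log(0)
    entropy = -(mean * mean.clamp_min(1e-12).log()).sum()
    return mean, var.sqrt(), entropy

# Usage with two made-up samples for a 2-class problem
samples = [torch.tensor([0.9, 0.1]), torch.tensor([0.7, 0.3])]
mean, std, entropy = summarize_mc_samples(samples)
print(mean, std, entropy)
```

A large standard deviation for the predicted class, or a high entropy, flags inputs on which the dropout samples disagree.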

I have one question: how can I evaluate n copies of the input as a single batch? In other words, I want to replace the loop in .collect_samples with a single forward pass over one batch made of n copies of the input image.

Also, how can I disable the progress bar when calling .predict? Showing a progress bar for a single item does not make much sense.

Any comments are appreciated! This is my first attempt at coding within the fastai ecosystem (rather than just using it 🙂)


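You should be able to do something like this:

import torch
ex = torch.randn([3, 28, 28])
n = 10
ex[None].expand(n, -1, -1, -1).shape  # -1 means "keep the size of that dimension of ex"
# -> torch.Size([10, 3, 28, 28]); expand returns a view, so no data is copied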
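Combining this expand trick with dropout toggling, one forward pass over n copies can stand in for the loop in collect_samples. This is a plain PyTorch sketch, not a fastai API: the toy model and the helper name mc_dropout_batch are hypothetical, and it assumes the model outputs raw logits so that softmax gives class probabilities. Because it bypasses Learner.predict, any fastai preprocessing of the item has to be applied beforehand (for example by taking the single-item batch from learn.dls.test_dl([item])).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy classifier with a dropout layer (hypothetical stand-in for raw_model)
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 28 * 28, 64),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(64, 10),
)

def mc_dropout_batch(model, x, n):
    """Draw n MC dropout samples for a single input x in one forward pass."""
    model.eval()                                  # keep layers such as BatchNorm in inference mode
    for m in model.modules():                     # re-enable only the dropout layers
        if 'dropout' in m.__class__.__name__.lower():
            m.train()
    try:
        batch = x[None].expand(n, *x.shape)       # n copies of x as one batch (a view, no copy)
        with torch.no_grad():
            probs = model(batch).softmax(dim=-1)  # shape (n, n_classes)
    finally:
        model.eval()                              # put dropout back into eval mode
    return probs

x = torch.randn(3, 28, 28)
probs = mc_dropout_batch(model, x, n=10)
print(probs.shape)       # torch.Size([10, 10])
print(probs.std(dim=0))  # per-class spread across the 10 samples
```

Each row gets its own dropout mask because dropout samples independently per element of the batch, so the n rows differ even though the inputs are identical.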


Thanks for posting this; it set me on the right path for implementing MC Dropout. I hadn't thought of using a Callback. A couple of things I noticed when trying to use it:

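  1. The event is before_batch, not begin_batch (the begin_* callback events were renamed to before_* in later fastai v2 releases).
  2. Looping over main_mod.children() won't necessarily find all dropout layers (e.g. in ResNets), because children() only yields direct submodules and dropout layers can be nested deeper. It is better to loop over self.model.modules(), which walks the whole module tree.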
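The difference between the two traversals can be checked in plain PyTorch. The nested model below is a hypothetical stand-in (a body followed by a head that holds the dropout layers, a layout similar to fastai's CNN models), and set_dropout is a hypothetical helper name:

```python
import torch.nn as nn

# Hypothetical nested model: a "body" followed by a "head" holding the dropout layers
model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten()),
    nn.Sequential(nn.Dropout(0.25), nn.Linear(8, 8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 2)),
)

def is_dropout(m):
    # same name-based check as the original callback
    return 'dropout' in m.__class__.__name__.lower()

# Original approach: only the direct children of the first child are inspected
found_children = [m for m in next(model.children()).children() if is_dropout(m)]
# Suggested approach: modules() walks the whole module tree recursively
found_modules = [m for m in model.modules() if is_dropout(m)]
print(len(found_children), len(found_modules))  # 0 2

def set_dropout(model, active):
    """Switch every dropout layer to train mode (active=True) or eval mode."""
    for m in model.modules():
        if is_dropout(m):
            m.train(active)

model.eval()
set_dropout(model, True)
print([m.training for m in found_modules])  # [True, True]
```

In the callback, before_batch could call set_dropout(self.model, True) and after_batch could call set_dropout(self.model, False), both guarded by self._active.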