Monte Carlo Dropout fastai v2

Hi,

In fastai v1 there was an implementation of Monte Carlo Dropout, introduced in this PR: Ability to use dropout at prediction time (Monte Carlo Dropout)

Is it available in fastai v2, or is it planned for a later release?

For those interested in learning more, it is based on a paper by Yarin Gal: https://arxiv.org/abs/1506.02142. It's a very cool topic and a very simple way to check a model's certainty (i.e. by measuring the standard deviation of its predictions).
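The core idea can be shown without any framework. Below is a toy sketch, assuming a tiny two-layer NumPy network with random weights (the function name mc_dropout_predict is hypothetical, not from fastai). It keeps dropout active at prediction time, runs the same input through many times, and uses the spread of the outputs as the uncertainty estimate.

```python
import numpy as np

def mc_dropout_predict(x, W1, W2, p=0.5, n_samples=100, seed=0):
    """Run a tiny 2-layer MLP n_samples times with dropout left ON.

    Each pass draws a fresh Bernoulli mask over the hidden units (inverted
    dropout: kept units are scaled by 1/(1-p)), so outputs vary between
    passes. Returns the per-output mean and standard deviation.
    """
    rng = np.random.default_rng(seed)
    outs = []
    for _ in range(n_samples):
        h = np.maximum(0.0, x @ W1)          # ReLU hidden layer
        mask = rng.random(h.shape) >= p      # keep each unit with prob 1-p
        h = h * mask / (1.0 - p)             # inverted-dropout scaling
        outs.append(h @ W2)
    outs = np.stack(outs)                    # (n_samples, batch, n_out)
    return outs.mean(axis=0), outs.std(axis=0)

rng = np.random.default_rng(42)
W1 = rng.normal(size=(3, 16))
W2 = rng.normal(size=(16, 1))
x = rng.normal(size=(4, 3))
mean, std = mc_dropout_predict(x, W1, W2)
print(mean.shape, std.shape)  # (4, 1) (4, 1)
```

With p=0 every pass is identical and the standard deviation collapses to zero, which is a quick sanity check that the spread really comes from the dropout masks.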

We could try proposing a PR?
It's probably just a callback that makes sure the dropout layers are set to train mode?
@muellerzr probably knows the answer…

PRs are always welcome :slight_smile: That being said, I'd say make one, but don't be surprised if it takes a few weeks to be approved. Jeremy wants the initial release of the library not to have too many extra bells and whistles, but it's certainly a welcome addition nonetheless!

Do you know if we can use callbacks with get_preds? I find this function very hard to read compared to fit.
I do see that the validation is done in this line:

self._do_epoch_validate(dl=dl)

So if I am correct, just implementing a callback:

def begin_validate(self):
    self.model.apply(apply_dropout)

Should do the trick, together with this helper (model.apply calls it on every submodule):

def apply_dropout(m):
    "If a module contains dropout in its name, it will be switched to .train() mode."
    if 'dropout' in m.__class__.__name__.lower():
        m.train()
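The effect of this kind of helper can be checked in plain PyTorch, outside fastai. In this sketch the toy model and the name enable_dropout are assumptions for illustration: after model.eval(), only the dropout layer is switched back to train mode, so BatchNorm keeps using its running statistics while repeated forward passes still differ.

```python
import torch
from torch import nn

def enable_dropout(model: nn.Module) -> nn.Module:
    """Switch only modules whose class name contains 'dropout' to train mode;
    every other module keeps its current mode."""
    for m in model.modules():
        if 'dropout' in m.__class__.__name__.lower():
            m.train()
    return model

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 32), nn.BatchNorm1d(32), nn.ReLU(),
                      nn.Dropout(0.5), nn.Linear(32, 2))
model.eval()            # normal inference: dropout off, BatchNorm uses running stats
enable_dropout(model)   # dropout back on, BatchNorm still in eval mode

x = torch.randn(4, 8)
with torch.no_grad():
    a, b = model(x), model(x)
print(model[1].training, model[3].training)  # False True
print(torch.equal(a, b))  # expected False: each pass samples a new dropout mask
```

Matching on the class name (rather than isinstance(m, nn.Dropout)) also catches custom modules whose names contain "dropout", which is presumably why the thread uses it.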

We can't currently, but I think that could be adjusted. If we follow the context managers (which is what the callbacks really are here), we'll see the following:

    def get_preds(self, ds_idx=1, dl=None, with_input=False, with_decoded=False, with_loss=False, act=None,
                  inner=False, reorder=True, **kwargs):
        if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)
        if reorder and hasattr(dl, 'get_idxs'):
            idxs = dl.get_idxs()
            dl = dl.new(get_idxs = _ConstantFunc(idxs))
        cb = GatherPredsCallback(with_input=with_input, with_loss=with_loss, **kwargs)
        ctx_mgrs = [self.no_logging(), self.added_cbs(cb), self.no_mbar()]
        if with_loss: ctx_mgrs.append(self.loss_not_reduced())

You'll notice that at no point can we pass in new callbacks. So I think the right approach would be to optionally accept cbs, add GatherPredsCallback to them, and pass the combined list into self.added_cbs. Do you follow what I mean?
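Why the extra callbacks have to go through added_cbs can be seen in a toy model of the mechanism. Everything here (ToyLearner, GatherPreds, MCDropout, run_event) is a hypothetical stand-in, not fastai's classes: callbacks only receive events while they are registered, and get_preds registers the user's callbacks together with the gathering one for the duration of the call.

```python
from contextlib import contextmanager

class GatherPreds:
    "Stand-in for GatherPredsCallback."
    def before_validate(self): pass

class MCDropout:
    "Stand-in for a user callback that would re-enable dropout."
    def before_validate(self): pass

class ToyLearner:
    def __init__(self):
        self.cbs = []      # callbacks that receive every event
        self.events = []   # (callback class name, event) pairs, for inspection

    @contextmanager
    def added_cbs(self, cbs):
        "Register cbs only for the duration of the with-block."
        self.cbs += cbs
        try:
            yield
        finally:
            for cb in cbs:
                self.cbs.remove(cb)

    def run_event(self, name):
        for cb in self.cbs:
            if hasattr(cb, name):
                getattr(cb, name)()
                self.events.append((type(cb).__name__, name))

    def get_preds(self, cbs=None):
        # The proposed change: user callbacks are combined with the gathering
        # callback and registered together; without cbs, only GatherPreds runs.
        with self.added_cbs(list(cbs or []) + [GatherPreds()]):
            self.run_event('before_validate')   # stand-in for the validation loop
        return self.events

learn = ToyLearner()
print(learn.get_preds(cbs=[MCDropout()]))
# [('MCDropout', 'before_validate'), ('GatherPreds', 'before_validate')]
print(learn.cbs)  # [] : the temporary callbacks were removed again
```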

From there, it's as simple as writing a Monte Carlo Dropout callback.

I just added a callback to the learner with:

learn.cbs += [MDropoutCB]

Well, that's easy enough! :slight_smile: Good eye. I'd absolutely put in a PR; Jeremy will comment on whether he wants to keep it open until the official release or integrate it directly now.

It does not work … =(.

Can you try with the modification?

I don't get what you are proposing, Zach. Modifying get_preds to support added callbacks?
You don't think there is a way to do it without modifying the library?

Unless you write your own get_preds, no, because there is currently no way for a callback you pass in to actually run during get_preds. GatherPredsCallback is the only callback it adds. (If you follow the ExitStack portions, none of them pass in any extra callbacks.)

@jeremy can interject if my understanding is wrong, but this is how the code reads.

Something like this:

ctx_mgrs = [self.no_logging(), self.added_cbs(L(cbs)+[cb]), self.no_mbar()]

validate already has this option.

Then I believe that's where the PR needs to go :slight_smile:

I am trying to work out why it works with validate but not with get_preds.

Because validate already supports it: we can pass cbs into validate, they are registered via added_cbs, and then _do_epoch_validate runs (which is where everything actually happens; get_preds uses it too).

get_preds just gets the predictions, which is only a small step of that. There's no opportunity to attach new callbacks before you reach the ctx_mgrs step (with self.added_cbs).

The exact two lines of code I am comparing are these in get_preds:

        cb = GatherPredsCallback(with_input=with_input, with_loss=with_loss, **kwargs)
        ctx_mgrs = [self.no_logging(), self.added_cbs(cb), self.no_mbar()]

vs here on validate:

    def validate(self, ds_idx=1, dl=None, cbs=None):
        if dl is None: dl = self.dls[ds_idx]
        with self.added_cbs(cbs), self.no_logging(), self.no_mbar():

I got it. It wasn't working for me in get_preds before, but it is now.
Thanks Zach.
I submitted the PR to modify get_preds.
Now that the PR has been merged, you can pass a callback to get_preds to predict with dropout turned on.

Do you have a working example of your fastai v2 callback for MC Dropout? I would really like to play with this. I checked the fastai2 source code and there's no MDropoutCB.

Thanks!

No,
but we could implement that easily.
There are a few things one could do:

  • A callback that activates the dropout layers and produces a single output. (simple)
  • A callback that modifies get_preds to generate multiple outputs at once (a distribution of outputs).
  • Or a new mc_get_preds function that does this.
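Outside fastai, the "multiple outputs at once" options boil down to a loop like the one below. This is a plain PyTorch sketch: mc_predict is a hypothetical name, not a fastai function, and for brevity it only matches nn.Dropout layers.

```python
import torch
from torch import nn

@torch.no_grad()
def mc_predict(model: nn.Module, x: torch.Tensor, n_samples: int = 20):
    """Run model n_samples times with dropout active.

    Returns (samples, mean, std); samples are stacked on a new first axis.
    """
    was_training = model.training
    model.eval()
    for m in model.modules():
        if isinstance(m, nn.Dropout):  # only standard nn.Dropout, for brevity
            m.train()
    try:
        samples = torch.stack([model(x) for _ in range(n_samples)])
    finally:
        model.train(was_training)  # restore a uniform train/eval mode
    return samples, samples.mean(dim=0), samples.std(dim=0)

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Dropout(0.3), nn.Linear(16, 3))
net.eval()
samples, mean, std = mc_predict(net, torch.randn(5, 4), n_samples=10)
print(samples.shape)  # torch.Size([10, 5, 3])
```

torch.std is unbiased by default, so n_samples must be at least 2 for a finite result.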

For anyone interested:

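from fastai.callback.core import Callback  # fastai >= 2.0 (the pre-release package was fastai2)

def flatten_model(el):
    "Return el followed by all of its submodules, recursively (much like list(el.modules()))."
    flattened = [flatten_model(children) for children in el.children()]
    res = [el]
    for c in flattened:
        res += c
    return res

def _dropout_layers(model):
    "All submodules whose class name contains 'dropout'."
    return [m for m in flatten_model(model) if 'dropout' in m.__class__.__name__.lower()]

class MCDropoutCallback(Callback):
    def before_validate(self):
        # the model is in eval mode during validation; switch only dropout back to train
        for m in _dropout_layers(self.model): m.train()

    def after_validate(self):
        for m in _dropout_layers(self.model): m.eval()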

Then, you have to loop through get_preds multiple times. Here is an example:

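import numpy as np

all_imgs = []
b = learn.dls[1].one_batch()
input_img, nums, targets = to_detach(b)

for i in range(10):
    # each call samples new dropout masks, so the predictions differ
    preds, targs = learn.get_preds(dl=[b], cbs=[MCDropoutCallback()])
    imgs, _, _, _ = preds   # my model returns a 4-tuple; the first element is the image
    all_imgs += [imgs]

all_imgs = np.stack(all_imgs)
all_imgs.shape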

In my case this returns a shape of [10, 128, 1, 28, 28], corresponding to [samples from the distribution, batch size, channels (one), height, width].
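To turn the stacked samples into an uncertainty estimate, reduce over the first (sample) axis. The sketch below uses random stand-in data with the same [10, 128, 1, 28, 28] layout; summarize_mc_samples is a hypothetical helper, and averaging the per-pixel standard deviation into one score per item is just one possible summary.

```python
import numpy as np

def summarize_mc_samples(samples: np.ndarray):
    """Reduce stacked MC Dropout samples (axis 0 = sample index) to a mean
    prediction and a per-element standard deviation."""
    return samples.mean(axis=0), samples.std(axis=0)

# Random stand-in data with the layout [samples, batch, channel, H, W]
all_imgs = np.random.default_rng(0).random((10, 128, 1, 28, 28))
mean_img, std_img = summarize_mc_samples(all_imgs)

# One scalar uncertainty score per item: mean of its per-pixel std
per_item_uncertainty = std_img.reshape(std_img.shape[0], -1).mean(axis=1)
print(mean_img.shape, std_img.shape, per_item_uncertainty.shape)
# (128, 1, 28, 28) (128, 1, 28, 28) (128,)
```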
