Hi,
As a prerequisite for my student project, I implemented a callback that enables Monte Carlo (MC) dropout in fastai v2 to estimate prediction uncertainty:
class McDropout(Callback):
    """`Callback` that activates Dropout layers at inference time to generate uncertainty in predictions.

    While active (between `activate()` and `deactivate()`, or inside
    `collect_samples`), every module of the model -- at any nesting depth --
    whose class name contains "dropout" (case-insensitive) is put in training
    mode before each batch and back in eval mode after it. Repeated
    predictions on the same input therefore differ (Monte Carlo dropout).
    """

    def __init__(self, with_input=False, with_loss=False, save_preds=None, save_targs=None, concat_dim=0):
        # Whether MC dropout is currently enabled; toggled by activate()/deactivate().
        self._active = False
        # NOTE(review): these attributes are stored but never read by this class;
        # they are kept only so the constructor signature stays unchanged.
        # The `store_attr(self, names)` call form matches early fastai v2 releases;
        # newer releases use `store_attr(names)` -- TODO confirm against the
        # installed fastai version.
        store_attr(self, "with_input,with_loss,save_preds,save_targs,concat_dim")

    def collect_samples(self, n, *args, **kwargs):
        """
        Collect n samples using MC dropout. Studying the distribution of these answers
        you can estimate the epistemic uncertainty of this classifier.

        :param n: number of stochastic predictions to draw (calls to `predict`)
        :param *args: positional arguments that will be forwarded to Learner.predict
        :param **kwargs: keyword arguments that will be forwarded to Learner.predict
        :return: a list with the n values returned by the predict calls
        """
        self.activate()
        try:
            # NOTE(review): `self.predict` is presumably resolved to the attached
            # Learner's `predict` through fastai's Callback attribute delegation;
            # the callback must therefore be registered on a Learner.
            return [self.predict(*args, **kwargs) for _ in range(n)]
        finally:
            # Always switch MC dropout off again, even if predict raises.
            self.deactivate()

    def activate(self):
        "Enable MC dropout for subsequent batches."
        self._active = True

    def deactivate(self):
        "Disable MC dropout for subsequent batches."
        self._active = False

    def _set_dropout_mode(self, train):
        "Put every dropout-like module of the model, at any depth, in train (True) or eval (False) mode."
        # `modules()` walks the whole module tree. The original code only looked at
        # the direct children of the model's first child, which missed nested or
        # later dropout layers and raised StopIteration on a childless model.
        for mod in self.model.modules():
            if 'dropout' in type(mod).__name__.lower():
                mod.train(train)  # train(False) is what nn.Module.eval() does

    def begin_batch(self):
        "Before each batch, switch dropout layers to train mode when MC dropout is active."
        if self._active:
            self._set_dropout_mode(True)

    def after_batch(self):
        "After each batch, switch dropout layers back to eval mode when MC dropout is active."
        if self._active:
            self._set_dropout_mode(False)
Usage:
# Register the callback on the Learner so it receives the batch events.
mc_dropout = McDropout()
learn = Learner(dls, raw_model, metrics=error_rate, cbs=[mc_dropout])

# Draw 10 MC-dropout samples for one of the provided files ...
samples_from_file = mc_dropout.collect_samples(10, fns[1])
# ... or for an item taken from the validation dataset ...
samples_from_item = mc_dropout.collect_samples(10, learn.dls.valid.dataset[2][0])
# ... or for any other input format that Learner.predict accepts.
I have one question: how can I evaluate n copies of the input as a single batch? In other words, I want to replace the loop in `.collect_samples` with a single forward pass over one batch made of n copies of the input image.
Also, how can I disable the progress bar when calling `.predict`? It makes little sense to show a progress bar for a single element.
Any comments are appreciated! This is my first attempt at writing code within the fastai ecosystem (rather than just using it).