PR: Ability to use dropout at prediction time (Monte Carlo Dropout)

PR link: https://github.com/fastai/fastai/pull/2132#issue-284247316

At Max Kelsen, we’ve been researching different methods for working out which predictions a model is ‘sure’ or ‘unsure’ of.

We found a few relevant use cases, particularly Uber’s work on measuring uncertainty in neural networks, where they describe Monte Carlo Dropout as a way to make many predictions on the same sample and use the variance of those predictions as a measure of certainty.

We’ve used this technique on a text classification problem and it has proven helpful.

The major changes are all within the Learner class in basic_train.py:

Adding apply_dropout():

def apply_dropout(self, m):
    "If a module's class name contains 'dropout', switch it to train() mode."
    if 'dropout' in m.__class__.__name__.lower():
        m.train()
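
To make the behaviour of this hook concrete, here is a standalone toy sketch (not part of the PR) of what .apply(apply_dropout) does: nn.Module.apply visits every submodule, so only the dropout layers are flipped back into train mode while the rest of the model stays in eval mode.

import torch
import torch.nn as nn

def apply_dropout(m):
    "Switch any module whose class name contains 'dropout' back to train mode."
    if 'dropout' in m.__class__.__name__.lower():
        m.train()

model = nn.Sequential(nn.Linear(10, 10), nn.Dropout(p=0.5), nn.Linear(10, 2))
model.eval()                # dropout (and batchnorm etc.) switched off
model.apply(apply_dropout)  # only the Dropout module goes back into train mode

x = torch.randn(1, 10)
print(model(x), model(x))   # different outputs: dropout is active at inference time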

Adding predict_with_mc_dropout():

def predict_with_mc_dropout(self, item:ItemBase, return_x:bool=False, batch_first:bool=True, with_dropout:bool=True, n_times:int=10, **kwargs):
    "Predict with dropout turned on for `n_times` (default 10)."
    predictions = []
    for _ in range(n_times):
        predictions.append(self.predict(item, return_x=return_x, batch_first=batch_first, with_dropout=with_dropout, **kwargs))
    return predictions

Altering pred_batch():

def pred_batch(self, ds_type:DatasetType=DatasetType.Valid, batch:Tuple=None, reconstruct:bool=False, with_dropout:bool=False) -> List[Tensor]:
    "Return output of the model on one batch from `ds_type` dataset."
    if batch is not None: xb,yb = batch
    else: xb,yb = self.data.one_batch(ds_type, detach=False, denorm=False)
    cb_handler = CallbackHandler(self.callbacks)
    xb,yb = cb_handler.on_batch_begin(xb,yb, train=False)
    
    # with_dropout clause
    if not with_dropout:
        preds = loss_batch(self.model.eval(), xb, yb, cb_handler=cb_handler)
    else:
        # Apply dropout at eval() time
        preds = loss_batch(self.model.eval().apply(self.apply_dropout), xb, yb, cb_handler=cb_handler)
        
    res = _loss_func2activ(self.loss_func)(preds[0])
    if not reconstruct: return res
    res = res.detach().cpu()
    ds = self.dl(ds_type).dataset
    norm = getattr(self.data, 'norm', False)
    if norm and norm.keywords.get('do_y',False):
        res = self.data.denorm(res, do_x=True)
    return [ds.reconstruct(o) for o in res]
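
As a side note, the new flag also works at the batch level; the snippet below is only a sketch (it assumes a trained learn object and isn't part of the PR) of stacking several stochastic passes over the same validation batch:

import torch

# Assumes `learn` is an already-trained Learner. Each call grabs the first validation
# batch (the valid DataLoader isn't shuffled) and runs it with dropout left active.
preds = torch.stack([learn.pred_batch(with_dropout=True) for _ in range(10)])
batch_uncertainty = preds.var(dim=0)  # (batch_size, n_classes) variance across the 10 passes

Looping the same idea over every batch of a DataLoader would give MC dropout predictions for a whole dataset.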

Altering predict():

def predict(self, item:ItemBase, return_x:bool=False, batch_first:bool=True, with_dropout:bool=False, **kwargs):
    "Return predicted class, label and probabilities for `item`."
    batch = self.data.one_item(item)
    
    # Added with_dropout
    res = self.pred_batch(batch=batch, with_dropout=with_dropout) 
    
    raw_pred,x = grab_idx(res,0,batch_first=batch_first),batch[0]
    norm = getattr(self.data,'norm',False)
    if norm:
        x = self.data.denorm(x)
        if norm.keywords.get('do_y',False): raw_pred = self.data.denorm(raw_pred)
    ds = self.data.single_ds
    pred = ds.y.analyze_pred(raw_pred, **kwargs)
    x = ds.x.reconstruct(grab_idx(x, 0))
    y = ds.y.reconstruct(pred, x) if has_arg(ds.y.reconstruct, 'x') else ds.y.reconstruct(pred)
    return (x, y, pred, raw_pred) if return_x else (y, pred, raw_pred)

Use:

Calling predict_with_mc_dropout(item, n_times=10) returns a list of 10 predictions, all made with dropout turned on (essentially 10 predictions from slightly different models). You can then take the variance of those predictions and use it as a measure of how certain your model is about a given prediction.

Low variance = low uncertainty
High variance = high uncertainty
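
For example, here is a rough sketch of turning that list into an uncertainty score for a classifier (learn and item are assumed to exist; the indexing relies on predict returning (y, pred, raw_pred)):

import torch

# Each element of `preds` is the usual (category, pred_idx, raw_probs) tuple from predict().
preds = learn.predict_with_mc_dropout(item, n_times=10)
probs = torch.stack([p[2] for p in preds])  # (10, n_classes) raw probabilities
mean_probs  = probs.mean(dim=0)             # averaged prediction across the passes
uncertainty = probs.var(dim=0)              # per-class variance = the uncertainty signal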

Pros:

  • Gives you more insight into what your model doesn’t know. This is particularly useful if you want to take things to production (e.g. automatically accept the model’s classification for samples with near-zero variance and send the rest to a human classifier).

  • Instead of having a completely black box neural network making predictions at X% accuracy, you can make predictions at X% accuracy and have a complementary metric of how certain each prediction is.

Cons:

  • Increases inference time by a factor of n_times.

Just made a few style changes; I’ll let you look at them and tell me if something is off before I merge.

Thanks for that! All looks good to go.

@sgugger @mrdbourke
Thanks for this great functionality! I am new to fastai and have a quick, perhaps silly question: I can’t seem to figure out how to predict with MC dropout for an entire dataset. Essentially I want to call learn.get_preds but with MC dropout. How can I do this without having to loop through the entire dataset one point at a time and call predict_with_mc_dropout on each item?

Thanks!


Thank you for this contribution, I was about to recode it when I found out that it was already done!

For users wanting to get uncertainty estimates for their models, I would recommend reading the work of Yarin Gal, who introduced the idea of using dropout to estimate uncertainty as part of his PhD thesis.

He has a blog post with very nice animations and the proper equations to convert the variance you get when you run the model several times with dropout activated into a proper uncertainty on the output.
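
For the regression setting, the conversion he gives boils down to adding the model’s inherent noise (the inverse of the model precision τ, which depends on the dropout rate, weight decay and dataset size) to the sample variance of the stochastic forward passes. A rough sketch of that idea (tau here is something you would compute following his thesis, not a value fastai provides):

import torch

def mc_dropout_moments(stochastic_preds: torch.Tensor, tau: float):
    "stochastic_preds: (T, ...) tensor of T forward passes made with dropout active."
    predictive_mean = stochastic_preds.mean(dim=0)
    # Gal's approximation: predictive variance ≈ sample variance + 1/tau (inherent noise).
    predictive_var = stochastic_preds.var(dim=0, unbiased=False) + 1.0 / tau
    return predictive_mean, predictive_var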

Thanks for doing this! If I understand correctly what you did, @mrdbourke and @sgugger, what you implemented uses a different dropout mask for each data record. Did I understand that right?

I was looking at something that @zlapp also seems to have looked at, namely BatchBALD. There it is crucial* that for one sample of predictions you use the same dropout mask for every data record (of course, to characterize uncertainty, you want different masks for each sample from the predictions). Maybe I did not dig deeply enough, but I got the impression that that’s not part of the current solution?

Or is this already easily possible? (If not, any pointers on how to do it, would be most welcome…)

* My attempt at an approximate intuition of what BatchBALD tries to achieve: if you select a single new record to label, you simply pick a record with a lot of variation in the predictions (which you get by repeatedly predicting with dropout at inference time). However, when you select a whole batch of records for labelling, you also want diversity in the records so that you learn about different things. The idea is that if the predictions for two records from different models are highly correlated, the records are not all that diverse; if you can only pick a few records, you may want just one of the two and then select a different record that varies in a less related way (even if its “univariate” uncertainty is smaller).
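
For what it’s worth, standard nn.Dropout samples an independent mask for every element (including the batch dimension), so the PR as written does give each record its own mask. One way to get the “same mask for every record within one MC sample” behaviour that BatchBALD needs would be a custom dropout module along these lines (just a sketch, not something that exists in fastai):

import torch
import torch.nn as nn

class BatchConsistentDropout(nn.Module):
    "Dropout that samples one mask per forward pass and shares it across the batch dimension."
    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        if not self.training or self.p == 0.:
            return x
        # A mask of shape (1, ...) broadcasts over the batch, so every record sees the same mask.
        mask = torch.bernoulli(torch.full_like(x[:1], 1 - self.p)) / (1 - self.p)
        return x * mask

Because the class name contains ‘dropout’, the apply_dropout hook above would still switch it on at prediction time; to keep one mask fixed across several batches (as BatchBALD strictly requires) you would additionally need to cache the mask between forward passes.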