Is it available in fastai v2, or is it planned to be implemented at a later point?
For those interested in learning more, it is based on a paper from Yarin Gal (https://arxiv.org/abs/1506.02142). It's a very cool topic and a very simple way to check a model's certainty (i.e. measuring the std of the predictions made).
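The core idea from Gal's paper can be sketched in a few lines: keep dropout active at inference time, run the same input through the model T times, and use the spread (std) of the predictions as an uncertainty estimate. The tiny linear "model" below is just a stand-in for illustration, not fastai code.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 3))          # fixed weights of a toy linear "model"

def stochastic_forward(x, p=0.5):
    """One forward pass with dropout left ON (Bernoulli mask on the input)."""
    mask = rng.random(x.shape) > p    # drop each unit with probability p
    return (x * mask / (1 - p)) @ W   # inverted-dropout scaling

x = rng.normal(size=8)
preds = np.stack([stochastic_forward(x) for _ in range(100)])  # T=100 samples
mean, std = preds.mean(0), preds.std(0)  # std = per-output uncertainty
```

Because the dropout mask changes on every pass, the 100 predictions disagree, and a large std flags inputs the model is unsure about.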
PRs are always welcome. That being said, I'd say make one, but don't be surprised if it takes a few weeks to be approved; Jeremy wants the initial release of the library to not have too many extra bells and whistles. It's certainly a welcome implementation nonetheless!
Do you know if we can use callbacks with get_preds? I find this function very hard to read compared to fit.
I do see that it is done here:
def apply_dropout(self, m):
    "If a module contains dropout in its name, it will be switched to .train() mode."
    if 'dropout' in m.__class__.__name__.lower():
        m.train()
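To see how the name check above behaves, here is a hedged, self-contained illustration: PyTorch's `Module.apply` would walk every submodule and hand each one to `apply_dropout`, so any module whose class name contains "dropout" gets flipped back to train() mode. The stub classes below stand in for torch.nn modules so the sketch runs without PyTorch.

```python
class Stub:
    """Minimal stand-in for a torch.nn module's train/eval state."""
    def __init__(self): self.training = False   # starts in eval mode
    def train(self): self.training = True

class Linear(Stub): pass
class Dropout(Stub): pass

def apply_dropout(m):
    # same test as the snippet above: match on the lowercased class name
    if 'dropout' in m.__class__.__name__.lower():
        m.train()

modules = [Linear(), Dropout(), Linear()]
for m in modules:   # torch would do this walk via model.apply(apply_dropout)
    apply_dropout(m)
```

Only the `Dropout` stub ends up with `training == True`; the `Linear` stubs stay in eval mode, which is exactly what you want for MC dropout.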
We can't currently; however, I think that could be adjusted. If we follow the context managers (which is what the callbacks truly are), we'll see the following:
def get_preds(self, ds_idx=1, dl=None, with_input=False, with_decoded=False, with_loss=False, act=None,
              inner=False, reorder=True, **kwargs):
    if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)
    if reorder and hasattr(dl, 'get_idxs'):
        idxs = dl.get_idxs()
        dl = dl.new(get_idxs = _ConstantFunc(idxs))
    cb = GatherPredsCallback(with_input=with_input, with_loss=with_loss, **kwargs)
    ctx_mgrs = [self.no_logging(), self.added_cbs(cb), self.no_mbar()]
    if with_loss: ctx_mgrs.append(self.loss_not_reduced())
You'll notice that at no point can we pass in any new callbacks. So I think the right approach would be to optionally accept cbs, add GatherPredsCallback to them, and pass the combined list into self.added_cbs. Do you follow what I mean?
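The wiring of that proposal can be sketched with a toy Learner: accept a `cbs` argument, prepend `GatherPredsCallback`, and hand the combined list to `added_cbs`. The names mirror fastai's, but this is a stand-in to show the shape of the change, not the real implementation.

```python
from contextlib import contextmanager

class GatherPredsCallback: pass
class MCDropoutCallback: pass

class ToyLearner:
    def __init__(self): self.cbs = []

    @contextmanager
    def added_cbs(self, cbs):
        """Attach callbacks for the duration of the context, then detach."""
        self.cbs += cbs
        try:
            yield self
        finally:
            for cb in cbs: self.cbs.remove(cb)

    def get_preds(self, cbs=None):
        # the proposed change: user callbacks ride along with GatherPreds
        cbs = [GatherPredsCallback()] + (cbs or [])
        with self.added_cbs(cbs):
            return [type(cb).__name__ for cb in self.cbs]

print(ToyLearner().get_preds(cbs=[MCDropoutCallback()]))
# → ['GatherPredsCallback', 'MCDropoutCallback']
```

The point is only that the user's callbacks now reach the `added_cbs` context manager, so they fire during prediction just like during validation.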
And then from there it’s as simple as making a Monte Carlo callback
Well, that's easy enough! Good eye. I'd absolutely put in a PR; Jeremy will comment more on it as to whether he wants it to stay open until the official release or be directly integrated now.
Unless you have your own get_preds separately, no, because there is zero opportunity for a callback passed into get_preds to actually run during it. GatherPredsCallback is the only callback being used. (If you try following the ExitStack portions, none of them pass any callbacks.)
@jeremy can interject if my understanding is wrong, but this is how the code reads.
Because with validate we already have them in there, i.e. we can pass cbs into validate; they're added via added_cbs, and then _do_epoch_validate runs (which is where everything happens: get_preds, etc.).
With get_preds it's just getting the predictions, which is only a mini-step in the above. There's no opportunity to attach new callbacks before you reach the ctx_mgrs step (with self.added_cbs).
I got it, but it was not working for me with get_preds; it is now.
Thanks Zach.
Submitted the PR to modify get_preds
Now that the PR has been merged, you can use a callback to predict with the Dropout turned on.
Do you have a working example of your fastai v2 callback to make MC dropout work? I would really like to play with this. I checked the fastai2 source code and there's no MDropoutCB.
def flatten_model(el):
    # recurse into children (the original called an undefined `flatten` here)
    flattened = [flatten_model(children) for children in el.children()]
    res = [el]
    for c in flattened:
        res += c
    return res
class MCDropoutCallback(Callback):
    def before_validate(self):
        for m in [m for m in flatten_model(self.model) if 'dropout' in m.__class__.__name__.lower()]:
            m.train()

    def after_validate(self):
        for m in [m for m in flatten_model(self.model) if 'dropout' in m.__class__.__name__.lower()]:
            m.eval()
Then, you have to loop through get_preds multiple times. Here is an example:
all_imgs = []
b = learn.dls[1].one_batch()
input_img, nums, targets = to_detach(b)

for i in range(10):
    preds, targs = learn.get_preds(dl=[b], cbs=[MCDropoutCallback()])
    imgs, _, _, _ = preds
    all_imgs += [imgs]

all_imgs = np.stack(all_imgs)
all_imgs.shape
Which returns a shape of [10, 128, 1, 28, 28], corresponding to [samples from the distribution, batch size, one channel, height, width] in my case.
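With the stacked samples in hand, the per-pixel mean over axis 0 gives the point estimate and the std over the same axis gives the model's uncertainty. The random array below just stands in for the real all_imgs so the sketch is self-contained.

```python
import numpy as np

# stand-in for all_imgs: [samples, batch, channel, height, width]
all_imgs = np.random.default_rng(0).normal(size=(10, 128, 1, 28, 28))

mean_pred = all_imgs.mean(axis=0)     # average over the 10 MC samples
uncertainty = all_imgs.std(axis=0)    # high std = model is unsure there
```

Both results collapse the sample axis, leaving one [128, 1, 28, 28] map each for the prediction and its uncertainty.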