Fastai2 Test Time Augmentation

Hi all,

I’m trying to understand how test time augmentation is implemented in fastai2.

In fastai1, the default tta looks like it used the 4 corner crops along with a horizontal flip, plus some other things.

In fastai2, what data augmentations does tta use? Are they set to the validation dataloader’s settings by default? I was initially looking for the basic 4 corners + horizontal flip. I saw this in the FA1 source code, but did not see it in the FA2 source code.

Daniel Lam

It’s doing the exact same transforms that were applied to the training set. We know this due to this line here:

with dl.dataset.set_split_idx(0), self.no_mbar():

It then gathers these predictions n times (whatever we pass in for n; it defaults to 4), before finally grabbing a fifth set that uses just the validation transforms (we see this via with dl.dataset.set_split_idx(1): preds,targs = self.get_preds(dl=dl, inner=True)).

(split_idx controls when each transform is done. 0 = training and 1 = validation)
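
To make that concrete, here is a small sketch (the transform name is made up, not from fastai or this thread) of how a transform keys off split_idx. RandTransform subclasses default to split_idx = 0, i.e. "training only", which is exactly the state tta temporarily puts the validation data into:

from fastai.vision.all import *

class AddNoise(RandTransform):
    "Illustrative transform that only runs while the current split_idx is 0 (training)"
    split_idx = 0
    def encodes(self, x:TensorImage): return x + 0.01*torch.randn_like(x)

A transform with split_idx = 1 would run only at validation time, and one whose split_idx is None behaves the same on both splits.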


In the learner arguments, ds_idx=1, then dl=self.dls[ds_idx]
Does this point to the validation dataloader?

def tta(self:Learner, ds_idx=1, dl=None, n=4, item_tfms=None, batch_tfms=None, beta=0.25, use_max=False):
    "Return predictions on the `ds_idx` dataset or `dl` using Test Time Augmentation"
    if dl is None: dl = self.dls[ds_idx]
    if item_tfms is not None or batch_tfms is not None: dl = dl.new(after_item=item_tfms, after_batch=batch_tfms)
    with dl.dataset.set_split_idx(0), self.no_mbar():

Yes it does. And we override the validation dataloader’s transforms to use the training dataloader’s transforms by setting that split_idx.
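
As a quick illustration (assuming a Learner called learn built the usual way, and reusing the calls from the quoted source above), the indexing and the override look like this:

dl = learn.dls[1]                    # dls[0] = training, dls[1] = validation
with dl.dataset.set_split_idx(0):    # transforms now behave as if this were the training split
    preds, targs = learn.get_preds(dl=dl, inner=True)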

Here’s my understanding of data transformation usage:

  1. train dataloader: train augmentations
  2. validation dataloader: filtered train augmentations
  3. tta dataloader: filtered train augmentations

Isn’t the dataloader in dl.dataset.set_split_idx(0) already the validation dataloader, because of dl = self.dls[ds_idx] where ds_idx = 1?
So, we’re not using the training set transformations, but a filtered version (similar to 2).

By the way, thanks for the help.

No, we are. We set the training dataloader’s transformations via that call I mentioned above, so they’re unfiltered. Validation filtering only happens when a particular transformation has special behavior tied to a split_idx of 0 (the validation run then uses a split_idx of 1, so that behavior is skipped or changed there). If no split_idx is given (i.e. it is None), then the behavior across training and validation is exactly the same.

So what we do is generate four different sets of predictions using the training augmentations (there’s some randomness involved in those transforms, which is why they differ), and finally we grab a fifth one that uses the validation transforms.

Does this make things clearer? We override those transforms with set_split_idx. The override only affects the transforms, so they know which behavior to apply. The validation set’s split_idx is set to 1 by default when you generate it.

The reason we can do this is that only the item_tfms are applied beforehand, whereas the batch_tfms are applied on the fly.
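
For completeness, here is a rough sketch of how the two sets of predictions end up combined, based on the behaviour described above and the default arguments (beta=0.25, use_max=False) in the signature quoted earlier; an approximation, not the verbatim fastai source:

import torch

def combine_tta(aug_preds, preds, beta=0.25, use_max=False):
    # aug_preds: predictions averaged over the n passes with training transforms
    # preds:     predictions from the single pass with validation transforms
    if use_max: return torch.stack([aug_preds, preds], dim=0).max(dim=0)[0]
    if beta is None: return aug_preds, preds
    return torch.lerp(aug_preds, preds, beta)  # weighted blend of the two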

Ok, I will look into the set_split_idx method. I must be misunderstanding something. Thanks.


Hi all, I’m having some trouble with the .tta method. After loading in the dataloader exactly like I did for training:

dls = ImageDataLoaders.from_folder('/content/', train='train', valid='val',
                                   item_tfms=Resize(460),
                                   batch_tfms=aug_transforms(flip_vert=True, max_rotate=45.0, size=224))

and loading the saved model as:

learn=load_learner(model_path)

I’ve verified that the learner works and is loaded correctly (e.g. learn.predict(path_to_img) works).

learn.tta(dl=dls)

Errors out giving

epoch	train_loss	valid_loss	accuracy	time
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-68-bf2efe30447d> in <module>()
----> 1 learn.tta(dl=dls)

13 frames
/usr/local/lib/python3.6/dist-packages/fastai/learner.py in tta(self, ds_idx, dl, n, item_tfms, batch_tfms, beta, use_max)
    558     if item_tfms is not None or batch_tfms is not None: dl = dl.new(after_item=item_tfms, after_batch=batch_tfms)
    559     try:
--> 560         self(_before_epoch)
    561         with dl.dataset.set_split_idx(0), self.no_mbar():
    562             if hasattr(self,'progress'): self.progress.mbar = master_bar(list(range(n)))

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in __call__(self, event_name)
    131     def ordered_cbs(self, event): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, event)]
    132 
--> 133     def __call__(self, event_name): L(event_name).map(self._call_one)
    134 
    135     def _call_one(self, event_name):

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in map(self, f, *args, **kwargs)
    381              else f.format if isinstance(f,str)
    382              else f.__getitem__)
--> 383         return self._new(map(g, self))
    384 
    385     def filter(self, f, negate=False, **kwargs):

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    331     @property
    332     def _xtra(self): return None
--> 333     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    334     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    335     def copy(self): return self._new(self.items.copy())

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
     45             return x
     46 
---> 47         res = super().__call__(*((x,) + args), **kwargs)
     48         res._newchk = 0
     49         return res

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    322         if items is None: items = []
    323         if (use_list is not None) or not _is_array(items):
--> 324             items = list(items) if use_list else _listify(items)
    325         if match is not None:
    326             if is_coll(match): match = len(match)

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in _listify(o)
    258     if isinstance(o, list): return o
    259     if isinstance(o, str) or _is_array(o): return [o]
--> 260     if is_iter(o): return list(o)
    261     return [o]
    262 

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __call__(self, *args, **kwargs)
    224             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    225         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 226         return self.fn(*fargs, **kwargs)
    227 
    228 # Cell

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _call_one(self, event_name)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name), event_name
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in <listcomp>(.0)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name), event_name
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

/usr/local/lib/python3.6/dist-packages/fastai/callback/core.py in __call__(self, event_name)
     42                (self.run_valid and not getattr(self, 'training', False)))
     43         res = None
---> 44         if self.run and _run: res = getattr(self, event_name, noop)()
     45         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     46         return res

/usr/local/lib/python3.6/dist-packages/fastai/callback/progress.py in before_epoch(self)
     21 
     22     def before_epoch(self):
---> 23         if getattr(self, 'mbar', False): self.mbar.update(self.epoch)
     24 
     25     def before_train(self):    self._launch_pbar()

/usr/local/lib/python3.6/dist-packages/fastprogress/fastprogress.py in update(self, val)
     92             yield o
     93 
---> 94     def update(self, val): self.main_bar.update(val)
     95 
     96 # Cell

/usr/local/lib/python3.6/dist-packages/fastprogress/fastprogress.py in update(self, val)
     55             self.pred_t,self.last_v,self.wait_for = 0,0,1
     56             self.update_bar(0)
---> 57         elif val <= self.first_its or val >= self.last_v + self.wait_for or val >= self.total:
     58             cur_t = time.time()
     59             avg_t = (cur_t - self.start_t) / val

AttributeError: 'NBProgressBar' object has no attribute 'wait_for'

Any idea what I’m doing wrong and how to fix it?

Just in case it’s helpful: I’m running on Colab using the course setup:

!pip install -Uqq fastbook

import fastbook

fastbook.setup_book()

!pip install utils

from utils import *

Not quite the same, but using notebook 07_sizing-and-tta from an anaconda install today I get the following error:
‘NBProgressBar’ object has no attribute ‘start_t’


[I was incorrect]


Hey Ben, test time augmentation (tta) is used for the test set, after you have trained. It’s been a while since I dug into the code, but I would guess tta(n=10) does not affect your fine_tune. Chapter/Notebook 7 should have an example.

If you’re looking for input data augmentation, the item_tfms and batch_tfms inside of the dataloaders already do that.
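
In case it helps, a minimal usage sketch (assuming learn is your trained Learner with a labelled validation set in learn.dls):

preds, targs = learn.tta()       # ds_idx defaults to 1, i.e. the validation set
print(accuracy(preds, targs))    # compare against plain learn.validate()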

[I was incorrect]

Have you been able to resolve the error: ‘NBProgressBar’ object has no attribute ‘start_t’? I’m encountering the same issue too!

No, I’m still getting this error when I try to apply it. I’m going to submit a bug report. @sgugger any thoughts? The workflow I’m trying to use is (in Google Colab):

learn = load_learner('learnfile.pkl')

test_imgs = get_image_files(img_dir)
tst_dl = learn.dls.test_dl(test_imgs)

targs, preds = learn.tta(dl=tst_dl)

This gives the following error

epoch	train_loss	valid_loss	accuracy	time
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-23-ece6d3f653d6> in <module>()
----> 1 learn.tta(dl=tst_dl)

13 frames
/usr/local/lib/python3.6/dist-packages/fastai/learner.py in tta(self, ds_idx, dl, n, item_tfms, batch_tfms, beta, use_max)
    562     if item_tfms is not None or batch_tfms is not None: dl = dl.new(after_item=item_tfms, after_batch=batch_tfms)
    563     try:
--> 564         self(_before_epoch)
    565         with dl.dataset.set_split_idx(0), self.no_mbar():
    566             if hasattr(self,'progress'): self.progress.mbar = master_bar(list(range(n)))

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in __call__(self, event_name)
    131     def ordered_cbs(self, event): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, event)]
    132 
--> 133     def __call__(self, event_name): L(event_name).map(self._call_one)
    134 
    135     def _call_one(self, event_name):

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in map(self, f, *args, **kwargs)
    381              else f.format if isinstance(f,str)
    382              else f.__getitem__)
--> 383         return self._new(map(g, self))
    384 
    385     def filter(self, f, negate=False, **kwargs):

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    331     @property
    332     def _xtra(self): return None
--> 333     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    334     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    335     def copy(self): return self._new(self.items.copy())

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
     45             return x
     46 
---> 47         res = super().__call__(*((x,) + args), **kwargs)
     48         res._newchk = 0
     49         return res

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    322         if items is None: items = []
    323         if (use_list is not None) or not _is_array(items):
--> 324             items = list(items) if use_list else _listify(items)
    325         if match is not None:
    326             if is_coll(match): match = len(match)

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in _listify(o)
    235     if isinstance(o, list): return o
    236     if isinstance(o, str) or _is_array(o): return [o]
--> 237     if is_iter(o): return list(o)
    238     return [o]
    239 

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __call__(self, *args, **kwargs)
    298             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    299         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 300         return self.fn(*fargs, **kwargs)
    301 
    302 # Cell

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _call_one(self, event_name)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name), event_name
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in <listcomp>(.0)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name), event_name
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

/usr/local/lib/python3.6/dist-packages/fastai/callback/core.py in __call__(self, event_name)
     42                (self.run_valid and not getattr(self, 'training', False)))
     43         res = None
---> 44         if self.run and _run: res = getattr(self, event_name, noop)()
     45         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     46         return res

/usr/local/lib/python3.6/dist-packages/fastai/callback/progress.py in before_epoch(self)
     21 
     22     def before_epoch(self):
---> 23         if getattr(self, 'mbar', False): self.mbar.update(self.epoch)
     24 
     25     def before_train(self):    self._launch_pbar()

/usr/local/lib/python3.6/dist-packages/fastprogress/fastprogress.py in update(self, val)
     92             yield o
     93 
---> 94     def update(self, val): self.main_bar.update(val)
     95 
     96 # Cell

/usr/local/lib/python3.6/dist-packages/fastprogress/fastprogress.py in update(self, val)
     57         elif val <= self.first_its or val >= self.last_v + self.wait_for or val >= self.total:
     58             cur_t = time.time()
---> 59             avg_t = (cur_t - self.start_t) / val
     60             self.wait_for = max(int(self.update_every / (avg_t+1e-8)),1)
     61             self.pred_t = avg_t * self.total

AttributeError: 'NBProgressBar' object has no attribute 'start_t'

Just submitted a bug report at https://github.com/fastai/fastai/issues/2764. Also made a colab notebook reproducing the error, which upon further investigation shows that the error only occurs with learners loaded via load_learner:

https://colab.research.google.com/drive/1l1tnOtboOwOhsYjwi8K8cUEPIPpDJMaU?usp=sharing

Did you attempt upgrading your fastprogress?

I thought !pip install -Uqq fastbook upgraded everything. Is that wrong? What’s the currently recommended install/import practice for Colab?

It should; however, when running into issues it’s best to rule things out one at a time. You can simply do !pip install fastprogress --upgrade to see if it just needs a newer release.
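
If upgrading doesn’t help, one thing that may be worth trying purely as a debugging step (a guess based on the traceback pointing into ProgressCallback, not a confirmed fix) is to drop the progress callback before calling tta, reusing the names from your post above:

learn = load_learner('learnfile.pkl')
tst_dl = learn.dls.test_dl(get_image_files(img_dir))
learn.remove_cb(learn.progress)      # detach the ProgressCallback instance
preds, targs = learn.tta(dl=tst_dl)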

Thank you for the suggestion, I tried it and unfortunately the error persists.


Have encountered the same error:

“AttributeError: ‘NBProgressBar’ object has no attribute ‘start_t’”