How best to have get_preds or TTA apply specified transforms?


(Malcolm McLean) #1

I would like to apply each of the eight dihedral transforms to the test set only (8 runs, square images) and average the results. What is the best way to do this with fastai? Or with PyTorch, if needed.

Thank you!

P.S. I think the dihedral group is 4 rotations × 2 flips.
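To make the P.S. concrete, here is a minimal pure-Python sketch (no fastai needed) that builds the eight symmetries of a square grid as each of the 4 quarter-turn rotations, with and without a left-right flip, and checks that they really are 8 distinct results. The helpers `rot90`, `hflip` and `dihedral_group` are illustrative names, not fastai functions.

```python
def rot90(grid):
    """Rotate a square grid (a list of rows) 90 degrees counter-clockwise."""
    # Transposing and then reversing the row order is a counter-clockwise quarter turn.
    return [list(row) for row in zip(*grid)][::-1]


def hflip(grid):
    """Mirror a grid left to right."""
    return [row[::-1] for row in grid]


def dihedral_group(grid):
    """Return the 8 symmetric versions of `grid`: 4 rotations, each with and without a flip."""
    results = []
    g = grid
    for _ in range(4):
        results.append(g)         # rotation only
        results.append(hflip(g))  # the same rotation followed by a flip
        g = rot90(g)
    return results


grid = [[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]]
versions = dihedral_group(grid)
distinct = {tuple(map(tuple, v)) for v in versions}
print(len(versions), "versions,", len(distinct), "distinct")  # prints: 8 versions, 8 distinct
```

Applying `rot90` or `hflip` to any of the eight versions lands on another one of the eight, which is what makes them a group.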


#2

fastai has this transform built in (it’s called dihedral), so you just need to call it on your images eight times, once for each k from 0 to 7: img.dihedral(k).
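Looping over k = 0..7 and averaging the class probabilities could look like the pure-Python sketch below. It is a stand-in, not fastai code: `dihedral(grid, k)` is a hypothetical helper that follows the same flip/flip/transpose-by-bits order as the `_custom_dihedral` function posted later in this thread, and `toy_predict` is a hypothetical model that returns class probabilities for one image. Because class probabilities don’t depend on the image’s orientation, the eight outputs can be averaged directly without undoing the transform.

```python
def dihedral(grid, k):
    """Apply dihedral element k (0-7) to a 2D grid (list of rows).

    Bit 0 reverses the rows, bit 1 reverses each row, bit 2 transposes,
    in the same order as `_custom_dihedral`.
    """
    if k & 1:
        grid = grid[::-1]                          # flip along the height axis
    if k & 2:
        grid = [row[::-1] for row in grid]         # flip along the width axis
    if k & 4:
        grid = [list(col) for col in zip(*grid)]   # swap height and width
    return [list(row) for row in grid]


def tta_predict(predict, image):
    """Average the class probabilities returned by `predict` over all 8 orientations."""
    totals = None
    for k in range(8):
        probs = predict(dihedral(image, k))
        if totals is None:
            totals = list(probs)
        else:
            totals = [t + p for t, p in zip(totals, probs)]
    return [t / 8 for t in totals]


# Hypothetical toy "model": its output depends on the top-left pixel, so it is
# not orientation invariant on its own, but its TTA average is.
def toy_predict(image):
    p = image[0][0] / 10
    return [p, 1 - p]


img = [[1, 2],
       [3, 4]]
print(tta_predict(toy_predict, img))
```

Each corner pixel ends up in the top-left position for exactly two of the eight k values, so the averaged prediction is the same whichever orientation the input arrives in.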


(Malcolm McLean) #3

Thanks, I will give it a try.


(Malcolm McLean) #4

Though I hate to admit it, I can’t figure out how to use img.dihedral(k) in practice.

Specifically, let’s say I start with…
testprobs,testlabels = learn.get_preds(ds_type=DatasetType.Test)

Can you please show me example code for getting back the predictions for each of the 8 dihedral transforms? Thanks again.


#5

You should manually set the transform you want on data.test_ds before that call. It’s in the attribute data.test_ds.tfms, so just add dihedral(k=..., p=1) to that list. You’ll have to loop over the eight values manually.


(RobG) #6

I also hate to admit it, but I’m sure I’m doing this suboptimally. I tend to find workarounds and stick to them. For flexibility, I often don’t create a test_ds and just pass augmented images into predict. Currently I have a hacky function that makes predictions on all 8 orientations, then reorients the results and averages them. I found that you can reorient an image by calling dihedral again with the same k value, except for k=5 and k=6, which undo each other: for those you call the opposite value (6 for 5, 5 for 6) to get back to the original orientation. Looping k over 0 to 7, it’s something like:

    k_inv = {5: 6, 6: 5}.get(k, k)  # 5 and 6 undo each other; every other k undoes itself
    img_result = learn.predict(dihedral(img_in, k))
    img_result_oriented = dihedral(Image(img_result), k_inv)

Based on sgugger’s advice above I’ll definitely revisit using test_ds and attached transforms, sequentially passing in 8 test sets.
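The reorient-then-average approach for image-shaped outputs (such as masks) can be sketched in pure Python. This is an illustration under assumptions, not RobG’s actual function: `dihedral` is a hypothetical helper that mirrors the flip/flip/transpose bit order of `_custom_dihedral` from this thread, `INVERSE_K` encodes the observation that every k undoes itself except 5 and 6, and `predict_mask` stands in for any model that returns one value per pixel.

```python
def dihedral(grid, k):
    """Apply dihedral element k (0-7) to a 2D grid, using _custom_dihedral's bit order."""
    if k & 1:
        grid = grid[::-1]                          # flip along the height axis
    if k & 2:
        grid = [row[::-1] for row in grid]         # flip along the width axis
    if k & 4:
        grid = [list(col) for col in zip(*grid)]   # swap height and width
    return [list(row) for row in grid]


# Every k undoes itself except 5 and 6, which undo each other.
INVERSE_K = {k: k for k in range(8)}
INVERSE_K[5], INVERSE_K[6] = 6, 5


def tta_mask(predict_mask, image):
    """Predict on all 8 orientations, turn each output back, and average pixel-wise."""
    total = None
    for k in range(8):
        out = predict_mask(dihedral(image, k))
        out = dihedral(out, INVERSE_K[k])  # back to the original orientation
        if total is None:
            total = out
        else:
            total = [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(total, out)]
    return [[v / 8 for v in row] for row in total]


grid = [[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]]
# With an identity "model", the reoriented average equals the input.
print(tta_mask(lambda x: x, grid))  # prints: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
```

k=5 and k=6 are the two quarter-turn rotations in opposite directions, which is why they cancel each other rather than themselves; the other six elements are either the identity, the half turn, or reflections, all of which are their own inverses.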


(Malcolm McLean) #7

Hi. Back again. Here’s what I get following the above…

data.test_ds.tfms = [dihedral(k=2,p=1)]
data.test_ds.tfms #[RandTransform(tfm=TfmPixel (dihedral), kwargs={'k': 0}, p=1, resolved={}, do_run=True, is_random=True)]
testprobs,val_labels = learn.get_preds(ds_type=DatasetType.Test)
   etc...
     54         val_losses,nums = [],[]
     55         if cb_handler: cb_handler.set_dl(dl)
---> 56         for xb,yb in progress_bar(dl, parent=pbar, leave=(pbar is not None)):
     57             if cb_handler: xb, yb = cb_handler.on_batch_begin(xb, yb, train=False)
     58             val_loss = loss_batch(model, xb, yb, loss_func, cb_handler=cb_handler)

~/anaconda3/envs/fastaiv3/lib/python3.7/site-packages/fastprogress/fastprogress.py in __iter__(self)
     64         self.update(0)
     65         try:
---> 66             for i,o in enumerate(self._gen):
     67                 yield o
     68                 if self.auto_update: self.update(i+1)

~/anaconda3/envs/fastaiv3/lib/python3.7/site-packages/fastai/basic_data.py in __iter__(self)
     73     def __iter__(self):
     74         "Process and returns items from `DataLoader`."
---> 75         for b in self.dl: yield self.proc_batch(b)
     76 
     77     @classmethod

~/anaconda3/envs/fastaiv3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    613         if self.num_workers == 0:  # same-process loading
    614             indices = next(self.sample_iter)  # may raise StopIteration
--> 615             batch = self.collate_fn([self.dataset[i] for i in indices])
    616             if self.pin_memory:
    617                 batch = pin_memory_batch(batch)

~/anaconda3/envs/fastaiv3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in <listcomp>(.0)
    613         if self.num_workers == 0:  # same-process loading
    614             indices = next(self.sample_iter)  # may raise StopIteration
--> 615             batch = self.collate_fn([self.dataset[i] for i in indices])
    616             if self.pin_memory:
    617                 batch = pin_memory_batch(batch)

~/anaconda3/envs/fastaiv3/lib/python3.7/site-packages/fastai/data_block.py in __getitem__(self, idxs)
    631             else:                 x,y = self.item   ,0
    632             if self.tfms or self.tfmargs:
--> 633                 x = x.apply_tfms(self.tfms, **self.tfmargs)
    634             if hasattr(self, 'tfms_y') and self.tfm_y and self.item is None:
    635                 y = y.apply_tfms(self.tfms_y, **{**self.tfmargs_y, 'do_resolve':False})

~/anaconda3/envs/fastaiv3/lib/python3.7/site-packages/fastai/vision/image.py in apply_tfms(self, tfms, do_resolve, xtra, size, resize_method, mult, padding_mode, mode, remove_out)
    105         if resize_method <= 2 and size is not None: tfms = self._maybe_add_crop_pad(tfms)
    106         tfms = sorted(tfms, key=lambda o: o.tfm.order)
--> 107         if do_resolve: _resolve_tfms(tfms)
    108         x = self.clone()
    109         x.set_sample(padding_mode=padding_mode, mode=mode, remove_out=remove_out)

~/anaconda3/envs/fastaiv3/lib/python3.7/site-packages/fastai/vision/image.py in _resolve_tfms(tfms)
    519 def _resolve_tfms(tfms:TfmList):
    520     "Resolve every tfm in `tfms`."
--> 521     for f in listify(tfms): f.resolve()
    522 
    523 def _grid_sample(x:TensorImage, coords:FlowField, mode:str='bilinear', padding_mode:str='reflection', remove_out:bool=True)->TensorImage:

~/anaconda3/envs/fastaiv3/lib/python3.7/site-packages/fastai/vision/image.py in resolve(self)
    498             if k in self.tfm.params:
    499                 rand_func = self.tfm.params[k]
--> 500                 self.resolved[k] = rand_func(*listify(v))
    501             # ...otherwise use the value directly
    502             else: self.resolved[k] = v

~/anaconda3/envs/fastaiv3/lib/python3.7/site-packages/fastai/torch_core.py in uniform_int(low, high, size)
    356 def uniform_int(low:int, high:int, size:Optional[List[int]]=None)->IntOrTensor:
    357     "Generate int or tensor `size` of ints between `low` and `high` (included)."
--> 358     return random.randint(low,high) if size is None else torch.randint(low,high+1,size)
    359 
    360 def one_param(m: nn.Module)->Tensor:

TypeError: randint() received an invalid combination of arguments - got (int, int, int), but expected one of:
 * (int high, tuple of ints size, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool requires_grad)
 * (int high, tuple of ints size, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool requires_grad)
 * (int low, int high, tuple of ints size, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool requires_grad)
 * (int low, int high, tuple of ints size, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool requires_grad)

To be clear, I don’t want anything random. I just want to specify the dihedral group element by an integer and have learn.get_preds() apply it.

I did spend some time with PyCharm’s debugger trying to sort this out, but without understanding the intent behind the code’s design I got lost, especially in the parameter introspection and in the meaning of instance variables set at some earlier point.

Thanks for your help and patience.


#8

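Ugh, there is some stuff in resolve that doesn’t work properly here. The built-in dihedral registers k as a random parameter, so resolve passes the k you give it as arguments to its random generator (that’s the uniform_int → torch.randint call failing in your traceback). A temporary workaround is to define your own dihedral with a plain k argument: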

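import torch
from fastai.vision.image import TfmPixel

def _custom_dihedral(x, k=5):
    "Deterministically flip and/or transpose image tensor `x` (C,H,W) based on the bits of `k` (0-7)."
    flips=[]
    if k&1: flips.append(1)        # bit 0: flip along dim 1 (height, upside down)
    if k&2: flips.append(2)        # bit 1: flip along dim 2 (width, left-right)
    if flips: x = torch.flip(x,flips)
    if k&4: x = x.transpose(1,2)   # bit 2: swap height and width (shape unchanged for square images)
    return x.contiguous()
custom_dihedral = TfmPixel(_custom_dihedral)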

then add that transform to the list (don’t forget the parentheses: custom_dihedral()), and pass the value of k you need, e.g. custom_dihedral(k=3).

Otherwise img.dihedral(5) works, but that requires you to loop over all the images.
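To see what each k does in `_custom_dihedral`, here is a small pure-Python sketch that decodes k into its three bits and checks, on a 3×3 grid, that the eight values of k give exactly the same eight symmetries as the "4 rotations × 2 flips" description from the question. `describe_k`, `apply_k`, `rotations_and_flips` and `as_key` are illustrative helpers, not fastai functions; grid rows stand in for dim 1 and columns for dim 2 of a (C, H, W) tensor.

```python
def describe_k(k):
    """List the operations _custom_dihedral applies for k (0-7), in order."""
    ops = []
    if k & 1:
        ops.append("flip dim 1 (upside down)")
    if k & 2:
        ops.append("flip dim 2 (left-right)")
    if k & 4:
        ops.append("transpose dims 1 and 2")
    return ops or ["identity"]


def apply_k(grid, k):
    """Same bit logic on a 2D grid: rows play the role of dim 1, columns of dim 2."""
    if k & 1:
        grid = grid[::-1]
    if k & 2:
        grid = [row[::-1] for row in grid]
    if k & 4:
        grid = [list(col) for col in zip(*grid)]
    return [list(row) for row in grid]


def rotations_and_flips(grid):
    """The '4 rotations x 2 flips' description of the same group."""
    out, g = [], grid
    for _ in range(4):
        out += [g, [row[::-1] for row in g]]       # rotation, then rotation + left-right flip
        g = [list(row) for row in zip(*g)][::-1]   # quarter turn counter-clockwise
    return out


def as_key(grid):
    """Hashable version of a grid, for set comparisons."""
    return tuple(map(tuple, grid))


grid = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
for k in range(8):
    print(k, " + ".join(describe_k(k)))
same = {as_key(apply_k(grid, k)) for k in range(8)} == {as_key(g) for g in rotations_and_flips(grid)}
print("same 8 symmetries:", same)
```

The bit encoding is just a compact way to index the group: any k from 0 to 7 picks one distinct element, which is why looping over range(8) covers every orientation exactly once.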


(Malcolm McLean) #9

Thanks, that worked, and it advanced me four places on the leaderboard!

For future reference:

  • To use the PyCharm debugger with data loaders, you must set
    data.test_dl.num_workers = 0
    Otherwise PyCharm does not honor any breakpoints set in the spawned worker processes.

  • Here is a working code fragment that can be adapted. It averages the test set probabilities over the eight dihedral transformations.

      # assumes `learn`, `data` and `custom_dihedral` are defined as above
      sums = 0
      for ki in range(8):
          data.test_ds.tfms = [custom_dihedral(k=ki)]   # apply the ki-th dihedral element to every test image
          testprobs,test_labels = learn.get_preds(ds_type=DatasetType.Test)
          testprobs = testprobs.sigmoid()   # not applied automatically, because get_preds does not recognize my custom loss function
          pnorm = testprobs[:,0]/(testprobs[:,0]+testprobs[:,3])   # normalize the two pertinent probabilities to sum to 1
          sums += pnorm
      average = sums/8.0