Implementing `CutMix` in FastaiV2

@init_27, @akashpalrecha, @barnacl, @DrHB and I have worked on implementing the CutMix research paper for FastaiV2.

We did this together and everyone contributed equally to the implementation.

We are sharing the experiment notebook below. In our experiments we find that, for the same number of epochs with resnet50, we get accuracy very similar to a plain cnn_learner.

Our CutMix implementation changes the batch as shown in the example images in the notebook linked below.

There's no computational overhead as such, and training times are very similar with and without CutMix.

We adapt the MixUp implementation inside the library and use the same style of loss function to compute the CutMix loss.
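For reference, the loss is the paper's interpolation of the criterion over the two label sets by the area ratio lam; a minimal standalone sketch in plain PyTorch (illustrative names, with a scalar lam for simplicity, whereas the callback below uses a per-image lam):

import torch
import torch.nn.functional as F

def cutmix_loss(pred, y, y_shuffled, lam):
    # lam is the fraction of each image kept from the original sample,
    # so the shuffled labels get weight (1 - lam)
    return lam * F.cross_entropy(pred, y) + (1 - lam) * F.cross_entropy(pred, y_shuffled)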

However, we are trying to refactor the code and need help, particularly with indexing into 4D tensors. We are trying to do a batched implementation to keep things fast, and we are not sure how to index into these 4D tensors with rectangular bounding boxes of different sizes per image. Our implementation currently looks like this:

class CutMix(Callback):
    run_after,run_valid = [Normalize],False
    def __init__(self, alpha=1.): self.distrib = Beta(tensor(alpha), tensor(alpha))
    def begin_fit(self):
        self.stack_y = getattr(self.learn.loss_func, 'y_int', False)
        if self.stack_y: self.old_lf,self.learn.loss_func = self.learn.loss_func,self.lf

    def after_fit(self):
        if self.stack_y: self.learn.loss_func = self.old_lf

    def begin_batch(self):
        W, H = self.xb[0].size(3), self.xb[0].size(2)
        
        # Sample lambda per image; keeping the max of (lam, 1-lam) ensures the
        # original image always dominates the mix
        lam = self.distrib.sample((self.y.size(0),)).squeeze().to(self.x.device)
        lam = torch.stack([lam, 1-lam], 1)
        self.lam = lam.max(1)[0]
        shuffle = torch.randperm(self.y.size(0)).to(self.x.device)
        xb1,self.yb1 = tuple(L(self.xb).itemgot(shuffle)),tuple(L(self.yb).itemgot(shuffle))
        nx_dims = len(self.x.size())

        # Sample a cut centre per image in the batch (one per sample, not hardcoded to 64)
        rx = (self.distrib.sample((self.y.size(0),))*W).type(torch.long).to(self.x.device)
        ry = (self.distrib.sample((self.y.size(0),))*H).type(torch.long).to(self.x.device)
        # Cut width/height from lambda: the patch covers (1-lam) of the image area
        rw = (torch.sqrt(1-self.lam)*W).to(self.x.device)
        rh = (torch.sqrt(1-self.lam)*H).to(self.x.device)

        x1 = torch.round(torch.clamp(rx-rw//2, min=0, max=W)).to(self.x.device).type(torch.long)
        x2 = torch.round(torch.clamp(rx+rw//2, min=0, max=W)).to(self.x.device).type(torch.long)
        y1 = torch.round(torch.clamp(ry-rh//2, min=0, max=H)).to(self.x.device).type(torch.long)
        y2 = torch.round(torch.clamp(ry+rh//2, min=0, max=H)).to(self.x.device).type(torch.long)
        

        # Paste the shuffled images' patches into the batch, one image at a time
        for i in range(len(x1)):
            self.learn.xb[0][i, :, x1[i]:x2[i], y1[i]:y2[i]] = xb1[0][i, :, x1[i]:x2[i], y1[i]:y2[i]]
        
        # Recompute lambda from the clipped boxes so it matches the true pasted area
        self.lam = (1 - ((x2-x1)*(y2-y1))/(W*H)).type(torch.float)
        
        if not self.stack_y:
            ny_dims = len(self.y.size())
            self.learn.yb = tuple(L(self.yb1,self.yb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=ny_dims-1)))

    def lf(self, pred, *yb):
        if not self.training: return self.old_lf(pred, *yb)
        with NoneReduce(self.old_lf) as lf:
            loss = torch.lerp(lf(pred,*self.yb1), lf(pred,*yb), self.lam)
        return reduce_loss(loss, getattr(self.old_lf, 'reduction', 'mean'))
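
To try it out, the callback attaches like any other (a sketch mirroring the cnn_learner example later in this thread):

cutmix = CutMix(alpha=1.)
learn = cnn_learner(dls, resnet50, loss_func=CrossEntropyLossFlat(),
                    cbs=cutmix, metrics=[accuracy, error_rate])
learn.fit_one_cycle(5)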

We believe we particularly need help with this part of the code:

for i in range(len(x1)):
    self.learn.xb[0][i, :, x1[i]:x2[i], y1[i]:y2[i]] = xb1[0][i, :, x1[i]:x2[i], y1[i]:y2[i]]

We admit that there is a lot of cleanup yet to be done. @sgugger, @jeremy, if you could please guide us on how to do this indexing properly and give us feedback :slight_smile:


For the indexing, so far we have tried:

import torch

xb1 = torch.randn(64, 3, 128, 128)
xb2 = torch.randn(64, 3, 128, 128)

x1 = torch.randint(0, 128, (64,))
x2 = torch.randint(0, 128, (64,))
y1 = torch.randint(0, 128, (64,))
y2 = torch.randint(0, 128, (64,))

x1.shape

xb1[:, :, x1:x2, y1:y2] = xb2[:, :, x1:x2, y1:y2]

but get an error:

TypeError                                 Traceback (most recent call last)
<ipython-input-11-88c5a959f6b1> in <module>
     13 x1.shape
     14 
---> 15 xb1[:, :, x1:x2, y1:y2] = xb2[:, :, x1:x2, y1:y2]

TypeError: only integer tensors of a single element can be converted to an index
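
The error makes sense: Python slice bounds must be plain ints, so x1:x2 with 64-element tensors cannot be converted to an index. One possible loop-free workaround (just a sketch, not what our notebook currently does) is to build a per-image boolean mask with broadcasting and paste via torch.where:

import torch

bs = 64
xb1 = torch.randn(bs, 3, 128, 128)
xb2 = torch.randn(bs, 3, 128, 128)
x1 = torch.randint(0, 128, (bs,))
x2 = torch.randint(0, 128, (bs,))
y1 = torch.randint(0, 128, (bs,))
y2 = torch.randint(0, 128, (bs,))

# Index grids for dims 2 and 3, broadcast against the per-image bounds:
# rows is (1, 128, 1), cols is (1, 1, 128), each bound is (bs, 1, 1),
# giving a (bs, 128, 128) mask that is True inside each image's rectangle.
rows = torch.arange(128).view(1, -1, 1)
cols = torch.arange(128).view(1, 1, -1)
mask = ((rows >= x1.view(-1, 1, 1)) & (rows < x2.view(-1, 1, 1)) &
        (cols >= y1.view(-1, 1, 1)) & (cols < y2.view(-1, 1, 1)))
mask = mask.unsqueeze(1)           # (bs, 1, 128, 128), broadcasts over channels
xb1 = torch.where(mask, xb2, xb1)  # paste xb2's rectangles into xb1

When x1[i] >= x2[i] the mask is simply empty for that image, matching the slice semantics of the loop version.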

Here is the gist for the notebook:


I posted the same question about speeding up the indexing on the official PyTorch forums; nobody has replied so far.

I was experimenting with @torch.jit.script and found it can speed up the for loop a little bit. Below is the test:

import torch

bs=5024
xb1 = torch.randn(bs, 3, 128, 128)
xb2 = torch.randn(bs, 3, 128, 128)
x1 = torch.randint(0, 128, (bs,))
x2 = torch.randint(0, 128, (bs,))
y1 = torch.randint(0, 128, (bs,))
y2 = torch.randint(0, 128, (bs,))

For loop:

%%timeit 
for i in range(len(x1)):
    xb2[i, :, x1[i]:x2[i], y1[i]:y2[i]] = xb1[i, :, x1[i]:x2[i], y1[i]:y2[i]]

10 loops, best of 3: 169 ms per loop

With @torch.jit.script:

@torch.jit.script
def cut_mix_img(xb1, xb2, x1, x2, y1, y2):
  for i in range(len(x1)):
      xb2[i, :, x1[i]:x2[i], y1[i]:y2[i]] = xb1[i, :, x1[i]:x2[i], y1[i]:y2[i]]
  return xb1, xb2

%%timeit 
cut_mix_img(xb1, xb2, x1, x2, y1, y2)

10 loops, best of 3: 128 ms per loop

approx. 40 ms faster =)


Cool work & thanks for sharing! :smiley:

Did you try this from the PyTorch forums: https://discuss.pytorch.org/t/typeerror-only-integer-tensors-of-a-single-element-can-be-converted-to-an-index/45641/2

These resources could be interesting for you:

Keep up the good work!


Hi there

Recently, I was working on my own implementation of CutMix and managed to avoid looping over all the samples. However, my implementation assumes that you have only 1 input and N outputs, as I wrote it for the Bengali Kaggle competition (https://www.kaggle.com/c/bengaliai-cv19/overview). I think your implementation could adapt my approach to avoid looping over all samples :slight_smile:

# Based on current fastai2 MixUp (29-01-2020)
from torch.distributions.beta import Beta
import numpy as np

def rand_bbox(size, lam):
	W = size[2]
	H = size[3]
	cut_rat = np.sqrt(1. - lam)
	cut_w = np.int(W * cut_rat)
	cut_h = np.int(H * cut_rat)

	# uniform
	cx = np.random.randint(W)
	cy = np.random.randint(H)

	bbx1 = np.clip(cx - cut_w // 2, 0, W)
	bby1 = np.clip(cy - cut_h // 2, 0, H)
	bbx2 = np.clip(cx + cut_w // 2, 0, W)
	bby2 = np.clip(cy + cut_h // 2, 0, H)

	return bbx1, bby1, bbx2, bby2


class CutMix(Callback):
	"""
	https://arxiv.org/pdf/1905.04899.pdf
	
	Implementation (based on CutMix custom):
	https://forums.fast.ai/t/cutmix-mixup/47809/3
	https://www.kaggle.com/c/bengaliai-cv19/discussion/126504#725455
	"""
	run_after,run_valid = [Normalize],False
	def __init__(self, alpha=1.0): self.distrib = Beta(tensor(float(alpha)), tensor(float(alpha)))
	
	def begin_fit(self):
	    self.old_lf,self.learn.loss_func = self.learn.loss_func,self.lf

	def after_fit(self):
	    self.learn.loss_func = self.old_lf

	def begin_batch(self):
	    self.lam = self.distrib.sample((1,)).item()
	    shuffle = torch.randperm(self.y[0].size(0)).to(self.x.device)
	    
	    bbx1, bby1, bbx2, bby2 = rand_bbox(self.x.size(), self.lam)
	    
	    # TODO: I could do it in fastai2 style
	    # xb1 = tuple(L(self.xb).itemgot(shuffle))
	    # nx_dims = len(self.x.size())
	    # self.learn.xb = tuple(L(xb1,self.xb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=nx_dims-1)))
	    self.learn.xb[0][:, :, bbx1:bbx2, bby1:bby2] = self.xb[0][shuffle, :, bbx1:bbx2, bby1:bby2]
	    self.yb1 = tuple([self.yb[i][shuffle] for i in range(len(self.yb))])        

	def lf(self, pred, *yb):
	    if not self.training: return self.old_lf(pred, *yb)
	    return self.old_lf(pred, *yb) * self.lam + self.old_lf(pred, *self.yb1) * (1 - self.lam)
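
One subtlety: rand_bbox clips the box at the image borders, so the pasted area can be smaller than 1 - lam. If you want the loss weights to match the actual pasted area, you can recompute lam from the clipped box (as the implementation earlier in the thread does), something like:

# After rand_bbox, recompute lam so the label weights match the true pasted area
W, H = self.x.size(2), self.x.size(3)
self.lam = 1. - ((bbx2 - bbx1) * (bby2 - bby1)) / float(W * H)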

Attentive CutMix looks cool :smiley: Maybe this weekend I'll try to implement it.


Hi Victor =)
In your implementation, every image in the batch gets the same cut ratio. What we are trying to do is give each image in the batch a random cut ratio, very similar to fastai MixUp =)


Wow, great article! Thanks for posting. I think the implementation is pretty straightforward, and it's summarized in this paragraph.


This will be quite memory intensive though. I guess at the start of training we can save the bbox locations for all images, as those aren't going to change (since we have a fixed pretrained model). After that, during training it's a simple lookup of the bbox location for each image.
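
A minimal sketch of that caching idea (the helper name get_attentive_bbox is hypothetical, standing in for whatever extracts a box from the pretrained model):

import torch

def get_attentive_bbox(model, img):
    # Hypothetical stand-in: in Attentive CutMix this would come from the
    # most-activated cell of the pretrained model's final feature map.
    # Here it returns a fixed central box just so the sketch runs.
    H, W = img.shape[-2:]
    return (W // 4, H // 4, 3 * W // 4, 3 * H // 4)

pretrained_model = None  # placeholder for the fixed feature extractor
images = [torch.randn(3, 224, 224) for _ in range(8)]

# Precompute once before training: the boxes never change because the
# pretrained model is fixed.
bbox_cache = {i: get_attentive_bbox(pretrained_model, img)
              for i, img in enumerate(images)}

# During training, fetching a box is a plain dict lookup per image index.
x1, y1, x2, y2 = bbox_cache[0]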


I ran a quick test using a fixed bounding box. It's about 6 times faster than looping with torch.jit.script on a GTX 1050 mobile, using CUDA tensors, bs=64 and 224x224 images.

However, I think that it's a premature optimization. Assuming that:

  • the jit.script version takes about 5 ms more per batch than indexing with a fixed bounding box crop (roughly the difference I measured, see below)
  • 1 million RGB images of 224x224
  • training for 100 epochs
  • batch size = 64

the total extra training time will be (100 * 1e6 / 64) * (5 / 1000) = 7812.5 seconds ≈ 2 h 10 min more than using a fixed bounding box crop. I believe that in this scenario the extra time is insignificant compared to the total training time.

I believe that CutMix is a useful augmentation, so I think it's better to get it into the fastai library first and then try to squeeze out some extra milliseconds.

Finally, it could be worth investigating whether CutMix can be used alongside MixUp. In the Bengali Kaggle competition, some participants reported that applying MixUp 50% of the time and CutMix the other 50% improved their score more than using CutMix or MixUp alone (https://www.kaggle.com/c/bengaliai-cv19/discussion/123198). So why not make a MixUpOrCutMix transform :smiley:? A rough sketch follows the timings below.

import torch


bs=64
xb1 = torch.randn(bs, 3, 224, 224).cuda()
xb2 = torch.randn(bs, 3, 224, 224).to(xb1.device)
x1 = torch.randint(0, 224, (bs,)).to(xb1.device)
x2 = torch.randint(0, 224, (bs,)).to(xb1.device)
y1 = torch.randint(0, 224, (bs,)).to(xb1.device)
y2 = torch.randint(0, 224, (bs,)).to(xb1.device)
%%timeit 
for i in range(len(x1)):
    xb2[i, :, x1[i]:x2[i], y1[i]:y2[i]] = xb1[i, :, x1[i]:x2[i], y1[i]:y2[i]]

6.49 ms ± 33.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

@torch.jit.script
def cut_mix_img(xb1, xb2, x1, x2, y1, y2):
    for i in range(len(x1)):
        xb2[i, :, x1[i]:x2[i], y1[i]:y2[i]] = xb1[i, :, x1[i]:x2[i], y1[i]:y2[i]]
    return xb1, xb2

%%timeit 
cut_mix_img(xb1, xb2, x1, x2, y1, y2)

6.01 ms ± 188 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

shuffle=torch.randperm(xb1.shape[0]).to(xb1.device)
%%timeit
xb2[:, :, 24:156, 54:210] = xb1[shuffle, :, 24:156, 54:210]

954 µs ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Avoiding the copy to another tensor (it didn't make a difference):

shuffle=torch.randperm(xb1.shape[0]).to(xb1.device)
%%timeit
xb1[:, :, 24:156, 54:210] = xb1[shuffle, :, 24:156, 54:210]

944 µs ± 773 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
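
And here is the MixUpOrCutMix idea as a rough batch-level sketch in plain PyTorch (not fastai2 callback machinery; the per-batch coin flip is the whole trick, and the returned (y, y_shuffled, lam) would feed a mixed loss like the ones above):

import torch

def mixup_or_cutmix(x, y, alpha=1.0, p_mixup=0.5):
    """Apply MixUp or CutMix to a batch, chosen at random per batch."""
    x = x.clone()                                # don't modify the caller's batch
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    idx = torch.randperm(x.size(0), device=x.device)
    if torch.rand(1).item() < p_mixup:
        x = lam * x + (1 - lam) * x[idx]         # MixUp: blend whole images
    else:                                        # CutMix: paste a random rectangle
        H, W = x.shape[-2:]
        cut_w = int(W * (1 - lam) ** 0.5)
        cut_h = int(H * (1 - lam) ** 0.5)
        cx, cy = torch.randint(W, (1,)).item(), torch.randint(H, (1,)).item()
        x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, W)
        y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, H)
        x[:, :, y1:y2, x1:x2] = x[idx][:, :, y1:y2, x1:x2]
        lam = 1 - ((x2 - x1) * (y2 - y1)) / (W * H)  # adjust lam for border clipping
    return x, y, y[idx], lam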


I am getting the following error and am wondering if anyone else has hit it. I'm using the same code as listed in https://github.com/fastai/fastai2/blob/master/nbs/74_callback.cutmix.ipynb. There's not much info available on this error and I'm wondering if it's a setup issue. Thanks

RuntimeError: "lerp_cuda" not implemented for 'Int'

fastai2 version: 0.0.17
fastcore version: 0.1.17
Windows


RuntimeError                              Traceback (most recent call last)
<ipython-input-21-c08ef54f419b> in <module>
----> 1 learn.fit(7, 1e-2)

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastcore\utils.py in _f(*args, **kwargs)
    428         init_args.update(log)
    429         setattr(inst, 'init_args', init_args)
--> 430         return inst if to_return else f(*args, **kwargs)
    431     return _f
    432 

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastai2\learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    198                     try:
    199                         self.epoch=epoch;          self('begin_epoch')
--> 200                         self._do_epoch_train()
    201                         self._do_epoch_validate()
    202                     except CancelEpochException:   self('after_cancel_epoch')

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastai2\learner.py in _do_epoch_train(self)
    173         try:
    174             self.dl = self.dls.train;                        self('begin_train')
--> 175             self.all_batches()
    176         except CancelTrainException:                         self('after_cancel_train')
    177         finally:                                             self('after_train')

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastai2\learner.py in all_batches(self)
    151     def all_batches(self):
    152         self.n_iter = len(self.dl)
--> 153         for o in enumerate(self.dl): self.one_batch(*o)
    154 
    155     def one_batch(self, i, b):

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastai2\learner.py in one_batch(self, i, b)
    156         self.iter = i
    157         try:
--> 158             self._split(b);                                  self('begin_batch')
    159             self.pred = self.model(*self.xb);                self('after_pred')
    160             if len(self.yb) == 0: return

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastai2\learner.py in __call__(self, event_name)
    132     def ordered_cbs(self, event): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, event)]
    133 
--> 134     def __call__(self, event_name): L(event_name).map(self._call_one)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastcore\foundation.py in map(self, f, *args, **kwargs)
    373              else f.format if isinstance(f,str)
    374              else f.__getitem__)
--> 375         return self._new(map(g, self))
    376 
    377     def filter(self, f, negate=False, **kwargs):

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastcore\foundation.py in _new(self, items, *args, **kwargs)
    324     @property
    325     def _xtra(self): return None
--> 326     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    327     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    328     def copy(self): return self._new(self.items.copy())

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastcore\foundation.py in __call__(cls, x, *args, **kwargs)
     45             return x
     46 
---> 47         res = super().__call__(*((x,) + args), **kwargs)
     48         res._newchk = 0
     49         return res

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastcore\foundation.py in __init__(self, items, use_list, match, *rest)
    315         if items is None: items = []
    316         if (use_list is not None) or not _is_array(items):
--> 317             items = list(items) if use_list else _listify(items)
    318         if match is not None:
    319             if is_coll(match): match = len(match)

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastcore\foundation.py in _listify(o)
    251     if isinstance(o, list): return o
    252     if isinstance(o, str) or _is_array(o): return [o]
--> 253     if is_iter(o): return list(o)
    254     return [o]
    255 

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastcore\foundation.py in __call__(self, *args, **kwargs)
    217             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    218         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 219         return self.fn(*fargs, **kwargs)
    220 
    221 # Cell

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastai2\learner.py in _call_one(self, event_name)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastai2\learner.py in <listcomp>(.0)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastai2\callback\core.py in __call__(self, event_name)
     22         _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
     23                (self.run_valid and not getattr(self, 'training', False)))
---> 24         if self.run and _run: getattr(self, event_name, noop)()
     25         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     26 

<ipython-input-9-3e1dac87939e> in begin_batch(self)
     25         if not self.stack_y:
     26             ny_dims = len(self.y.size())
---> 27             self.learn.yb = tuple(L(self.yb1,self.yb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=ny_dims-1)))
     28 
     29     def lf(self, pred, *yb):

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastcore\foundation.py in map_zip(self, f, cycled, *args, **kwargs)
    399     def zip(self, cycled=False): return self._new((zip_cycle if cycled else zip)(*self))
    400     def zipwith(self, *rest, cycled=False): return self._new([self, *rest]).zip(cycled=cycled)
--> 401     def map_zip(self, f, *args, cycled=False, **kwargs): return self.zip(cycled=cycled).starmap(f, *args, **kwargs)
    402     def map_zipwith(self, f, *rest, cycled=False, **kwargs): return self.zipwith(*rest, cycled=cycled).starmap(f, **kwargs)
    403     def concat(self): return self._new(itertools.chain.from_iterable(self.map(L)))

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastcore\foundation.py in starmap(self, f, *args, **kwargs)
    396     def cycle(self): return cycle(self)
    397     def map_dict(self, f=noop, *args, **kwargs): return {k:f(k, *args,**kwargs) for k in self}
--> 398     def starmap(self, f, *args, **kwargs): return self._new(itertools.starmap(partial(f,*args,**kwargs), self))
    399     def zip(self, cycled=False): return self._new((zip_cycle if cycled else zip)(*self))
    400     def zipwith(self, *rest, cycled=False): return self._new([self, *rest]).zip(cycled=cycled)

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastcore\foundation.py in _new(self, items, *args, **kwargs)
    324     @property
    325     def _xtra(self): return None
--> 326     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    327     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    328     def copy(self): return self._new(self.items.copy())

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastcore\foundation.py in __call__(cls, x, *args, **kwargs)
     45             return x
     46 
---> 47         res = super().__call__(*((x,) + args), **kwargs)
     48         res._newchk = 0
     49         return res

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastcore\foundation.py in __init__(self, items, use_list, match, *rest)
    315         if items is None: items = []
    316         if (use_list is not None) or not _is_array(items):
--> 317             items = list(items) if use_list else _listify(items)
    318         if match is not None:
    319             if is_coll(match): match = len(match)

c:\users\avird\anaconda3\envs\fastten\lib\site-packages\fastcore\foundation.py in _listify(o)
    251     if isinstance(o, list): return o
    252     if isinstance(o, str) or _is_array(o): return [o]
--> 253     if is_iter(o): return list(o)
    254     return [o]
    255 

RuntimeError: "lerp_cuda" not implemented for 'Int'

Maybe the CUDA version of lerp only supports floats. The torch docs example uses floats where they could have used integers. So you need to cast int/long values to float.
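
For example (illustrative, not the notebook's exact fix):

import torch

start = torch.tensor([0, 1, 2])   # int64 labels
end = torch.tensor([10, 11, 12])
# torch.lerp(start, end, 0.3)                       # RuntimeError on integer tensors
out = torch.lerp(start.float(), end.float(), 0.3)   # cast to float first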

Thanks for that. I managed to figure it out after a lot of trial and error. I also wanted to check whether this was a Windows issue, so I tested on AWS and got a similar error:

RuntimeError: "lerp_cuda" not implemented for 'Long'

The error on Colab is maybe a bit more to the point:

RuntimeError: expected dtype long for weights but got dtype long

And it all came down to 1 line! :sweat_smile: It seems that you have to run learn twice (not sure why). The first time is when you want to view the batch (as in the example):

learn = cnn_learner(dls, resnet50, loss_func=CrossEntropyLossFlat(), cbs=cutmix, metrics=[accuracy, error_rate])
learn._do_begin_fit(1)
learn.epoch,learn.training = 0,True
learn.dl = dls.train
b = dls.one_batch()
learn._split(b)
learn('begin_batch')
_,axs = plt.subplots(3,3, figsize=(9,9))
dls.show_batch(b=(cutmix.x,cutmix.y), ctxs=axs.flatten())

As learn was already called, I was then doing learn.fit_one_cycle(1), and that resulted in the errors. It works if you do this instead:

learn = cnn_learner(dls, resnet50, loss_func=CrossEntropyLossFlat(), cbs=cutmix, metrics=[accuracy, error_rate])
learn.fit_one_cycle(1)

Hi,

I’m about to make a variant of cutmix, but I have a question on the implementation choice.

How come you guys went with a callback instead of a transform? Are there any pros/cons? I’m following the chapter 11 siamese image example.

Thanks,
Daniel

AFAIK you need to change the loss function too; therefore, a data augmentation transform alone will not be enough.


Interesting, thanks.

Hi @amritv, could you help me figure out how to overcome this weird issue? I used the workaround above but still get the error.

learn.fit_flat_cos(6, lr, wd=0.1,
                   cbs=[# GradientAccumulation(n_acc=32),
                        # ProgressiveLabelCorrection(),
                        mixup,
                        ReduceLROnPlateau(monitor='accuracy', factor=5, patience=3)])

Hey @champs.jaideep, what's the error you're getting? I'm not getting any errors running this:

learn.fit_flat_cos(1, 1e-3, wd=0.1,
                   cbs=[# GradientAccumulation(n_acc=32),
                        # ProgressiveLabelCorrection(),
                        mixup,
                        ReduceLROnPlateau(monitor='accuracy', factor=5, patience=3)])