Can you please tell us what you’re doing to get this? I don’t think anyone has specified how to recreate it, so we can’t test for it and fix it…
Sure thing…
So I’ve been playing around with Tanishq Abraham’s SETI Breakthrough Listen starter notebook on Kaggle (which seems to use your timm_learner code – thanks for that too!), and am wrapping a fastai DataLoaders around two PyTorch DataLoaders; after augmentation the data have shape (1, 256, 256) and dtype=torch.float32.
Using the timm_learner and roc_auc as defined in the Kaggle notebook, I call:
mixup = MixUp(1.)
learn = timm_learner(dls, 'resnext50_32x4d', pretrained=True, n_in=1, n_out=1,
                     metrics=[roc_auc], opt_func=ranger,
                     cbs=[mixup, WandbCallback(log_model=True)],
                     loss_func=BCEWithLogitsLossFlat()).to_fp16()
_, lr_steep = learn.lr_find()
I get the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-36-7a54b2a9310c> in <module>
----> 1 _, lr_steep = learn.lr_find()
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastai/callback/schedule.py in lr_find(self, start_lr, end_lr, num_it, stop_div, show_plot, suggestions)
220 n_epoch = num_it//len(self.dls.train) + 1
221 cb=LRFinder(start_lr=start_lr, end_lr=end_lr, num_it=num_it, stop_div=stop_div)
--> 222 with self.no_logging(): self.fit(n_epoch, cbs=cb)
223 if show_plot: self.recorder.plot_lr_find()
224 if suggestions:
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
209 self.opt.set_hypers(lr=self.lr if lr is None else lr)
210 self.n_epoch = n_epoch
--> 211 self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
212
213 def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
158
159 def _with_events(self, f, event_type, ex, final=noop):
--> 160 try: self(f'before_{event_type}'); f()
161 except ex: self(f'after_cancel_{event_type}')
162 self(f'after_{event_type}'); final()
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastai/learner.py in _do_fit(self)
200 for epoch in range(self.n_epoch):
201 self.epoch=epoch
--> 202 self._with_events(self._do_epoch, 'epoch', CancelEpochException)
203
204 def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
158
159 def _with_events(self, f, event_type, ex, final=noop):
--> 160 try: self(f'before_{event_type}'); f()
161 except ex: self(f'after_cancel_{event_type}')
162 self(f'after_{event_type}'); final()
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastai/learner.py in _do_epoch(self)
194
195 def _do_epoch(self):
--> 196 self._do_epoch_train()
197 self._do_epoch_validate()
198
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastai/learner.py in _do_epoch_train(self)
186 def _do_epoch_train(self):
187 self.dl = self.dls.train
--> 188 self._with_events(self.all_batches, 'train', CancelTrainException)
189
190 def _do_epoch_validate(self, ds_idx=1, dl=None):
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
158
159 def _with_events(self, f, event_type, ex, final=noop):
--> 160 try: self(f'before_{event_type}'); f()
161 except ex: self(f'after_cancel_{event_type}')
162 self(f'after_{event_type}'); final()
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastai/learner.py in all_batches(self)
164 def all_batches(self):
165 self.n_iter = len(self.dl)
--> 166 for o in enumerate(self.dl): self.one_batch(*o)
167
168 def _do_one_batch(self):
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastai/learner.py in one_batch(self, i, b)
182 self.iter = i
183 self._split(b)
--> 184 self._with_events(self._do_one_batch, 'batch', CancelBatchException)
185
186 def _do_epoch_train(self):
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
158
159 def _with_events(self, f, event_type, ex, final=noop):
--> 160 try: self(f'before_{event_type}'); f()
161 except ex: self(f'after_cancel_{event_type}')
162 self(f'after_{event_type}'); final()
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastai/learner.py in __call__(self, event_name)
139
140 def ordered_cbs(self, event): return [cb for cb in self.cbs.sorted('order') if hasattr(cb, event)]
--> 141 def __call__(self, event_name): L(event_name).map(self._call_one)
142
143 def _call_one(self, event_name):
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastcore/foundation.py in map(self, f, gen, *args, **kwargs)
152 def range(cls, a, b=None, step=None): return cls(range_of(a, b=b, step=step))
153
--> 154 def map(self, f, *args, gen=False, **kwargs): return self._new(map_ex(self, f, *args, gen=gen, **kwargs))
155 def argwhere(self, f, negate=False, **kwargs): return self._new(argwhere(self, f, negate, **kwargs))
156 def filter(self, f=noop, negate=False, gen=False, **kwargs):
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastcore/basics.py in map_ex(iterable, f, gen, *args, **kwargs)
664 res = map(g, iterable)
665 if gen: return res
--> 666 return list(res)
667
668 # Cell
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastcore/basics.py in __call__(self, *args, **kwargs)
649 if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
650 fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 651 return self.func(*fargs, **kwargs)
652
653 # Cell
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastai/learner.py in _call_one(self, event_name)
143 def _call_one(self, event_name):
144 if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
--> 145 for cb in self.cbs.sorted('order'): cb(event_name)
146
147 def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastai/callback/core.py in __call__(self, event_name)
42 (self.run_valid and not getattr(self, 'training', False)))
43 res = None
---> 44 if self.run and _run: res = getattr(self, event_name, noop)()
45 if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
46 return res
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastai/callback/mixup.py in before_batch(self)
47 if not self.stack_y:
48 ny_dims = len(self.y.size())
---> 49 self.learn.yb = tuple(L(self.yb1,self.yb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=ny_dims-1)))
50
51 # Cell
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastcore/foundation.py in map_zip(self, f, cycled, *args, **kwargs)
176 def zip(self, cycled=False): return self._new((zip_cycle if cycled else zip)(*self))
177 def zipwith(self, *rest, cycled=False): return self._new([self, *rest]).zip(cycled=cycled)
--> 178 def map_zip(self, f, *args, cycled=False, **kwargs): return self.zip(cycled=cycled).starmap(f, *args, **kwargs)
179 def map_zipwith(self, f, *rest, cycled=False, **kwargs): return self.zipwith(*rest, cycled=cycled).starmap(f, **kwargs)
180 def shuffle(self):
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastcore/foundation.py in starmap(self, f, *args, **kwargs)
173 return self.map(lambda o: o.get(k,default) if isinstance(o, dict) else nested_attr(o,k,default))
174
--> 175 def starmap(self, f, *args, **kwargs): return self._new(itertools.starmap(partial(f,*args,**kwargs), self))
176 def zip(self, cycled=False): return self._new((zip_cycle if cycled else zip)(*self))
177 def zipwith(self, *rest, cycled=False): return self._new([self, *rest]).zip(cycled=cycled)
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastcore/foundation.py in _new(self, items, *args, **kwargs)
108 @property
109 def _xtra(self): return None
--> 110 def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
111 def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
112 def copy(self): return self._new(self.items.copy())
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
95 def __call__(cls, x=None, *args, **kwargs):
96 if not args and not kwargs and x is not None and isinstance(x,cls): return x
---> 97 return super().__call__(x, *args, **kwargs)
98
99 # Cell
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
103 def __init__(self, items=None, *rest, use_list=False, match=None):
104 if (use_list is not None) or not is_array(items):
--> 105 items = listify(items, *rest, use_list=use_list, match=match)
106 super().__init__(items)
107
~/anaconda3/envs/fastai-v227/lib/python3.8/site-packages/fastcore/basics.py in listify(o, use_list, match, *rest)
54 elif isinstance(o, list): res = o
55 elif isinstance(o, str) or is_array(o): res = [o]
---> 56 elif is_iter(o): res = list(o)
57 else: res = [o]
58 if match is not None:
RuntimeError: expected dtype long int for `weights` but got dtype float
I think the errors were identical on v2.4.0 and v2.3.1, but I’d be happy to reproduce for one or both if you’d like… Let me know if I can provide anything else that would be helpful, and thanks a ton…
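For reference, here’s a minimal sketch of the dtype clash I believe MixUp’s torch.lerp call is hitting, outside of fastai – integer targets meet a float weight (the exact error message varies by PyTorch version):

import torch

yb  = torch.tensor([0, 1, 0, 1])   # int64 labels, as produced by torch.tensor(label)
yb1 = torch.tensor([1, 0, 1, 1])   # shuffled copy, also int64
lam = torch.rand(4)                # float weights sampled from the Beta distribution

# torch.lerp(yb1, yb, lam)                 # raises: integer inputs vs. float weight
torch.lerp(yb1.float(), yb.float(), lam)   # fine once the labels are floats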
How are you building the dataloaders? (I’m unfamiliar with that comp – is it single-label classification? Segmentation? etc.)
The comp’s just binary classification (i.e. alien technosignatures present or absent), and the labels are presented in a two-column CSV containing the filename stem of the .npy representing the signal and a 0 or 1, respectively. I’m presently using a barely-modified version of the starter notebook’s PyTorch Dataset:
import numpy as np
import pandas as pd
import skimage.transform
import torch

class SETIDatasetBaseline:
    """
    `spatial` indicates whether you want to stack the spectrograms vertically
    `sixchan` indicates whether you want to preserve the nod-offs; if False, they are ignored
    """
    def __init__(self, df:pd.DataFrame, spatial:bool=True, sixchan:bool=True, transform=None):
        self.df = df
        self.spatial = spatial
        self.sixchan = sixchan
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        label = self.df.iloc[index].target
        filename = self.df.iloc[index].path
        data = np.load(filename).astype(np.float32)
        if not self.sixchan: data = data[::2].astype(np.float32)  # so this ignores the nod-offs
        if self.spatial:
            data = np.vstack(data).transpose((1, 0))
            data = skimage.transform.resize(data, (256, 256))
            data_tensor = torch.tensor(data).float().unsqueeze(0)  # adds the channel dummy dimension, should now be CxHxW
            if self.transform:
                data_tensor = self.transform(data_tensor)  # apply transforms
        else:
            data = np.transpose(data, (1, 2, 0))
            data = skimage.transform.resize(data, (256, 256))
            data = np.transpose(data, (2, 0, 1)).astype(np.float32)
            data_tensor = torch.tensor(data).float()
            if self.transform:
                data_tensor = self.transform(data_tensor)
        return (data_tensor, torch.tensor(label))
To get the individual DataLoaders (with just basic torchvision.transforms):
trainset = SETIDatasetBaseline(train_df, transform=cv_transforms)
validset = SETIDatasetBaseline(valid_df, transform=cv_transforms)
And then
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, num_workers=4, pin_memory=True)
validloader = torch.utils.data.DataLoader(validset, batch_size=64, num_workers=4, pin_memory=True)
dls = DataLoaders(trainloader, validloader)
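(By the way, cv_transforms isn’t defined in this post – a plausible stand-in using basic tensor-friendly torchvision transforms would be something like:)

from torchvision import transforms

# Hypothetical stand-in for cv_transforms; the real definition isn't shown here.
# These ops accept the (C, H, W) tensors that SETIDatasetBaseline produces.
cv_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
])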
I suspect that this is not a bug, but rather that I’m making an error w.r.t. numeric typecasting – i.e. MixUp is working with floats whereas BCEWithLogitsLossFlat requires Long ints. I’ll poke at it and post something later…
Is this a multi-label classification task?
If so, note that MixUp only works with single-label classification in fastai – meaning a loss function of nn.CrossEntropyLoss (or CrossEntropyLossFlat).
Got this to work – the issue, I think, was that my labels were the wrong numeric type, namely integers (0 or 1). When I typecast them to torch.float in my call to torch.tensor, the Learner using BCEWithLogitsLossFlat worked fine when fitting.
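Concretely, the fix amounts to one change in the Dataset’s __getitem__ above:

# return float labels so BCEWithLogitsLossFlat (and MixUp's torch.lerp)
# receive floating-point targets
return (data_tensor, torch.tensor(label, dtype=torch.float))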
Why stop at just mixing two examples?
I changed how the loss is calculated to make mixup work for segmentation problems. The main idea is to broadcast lam to the size of the predictions instead of the batch size. The same change, applied to the manifold mixup V2 code by @nestorDemeure, can make it work for segmentation as well.
Although the code works, I’m not sure this is the correct way to apply mixup to segmentation problems – any suggestions or corrections are appreciated.
from torch.distributions.beta import Beta
from fastai.vision.all import *
from fastai.callback.mixup import reduce_loss

class MixHandler_Segmentation(Callback):
    "A handler class for implementing `MixUp` style scheduling"
    run_valid = False
    def __init__(self, alpha=0.5):
        self.distrib = Beta(tensor(alpha), tensor(alpha))

    def before_train(self):
        self.stack_y = getattr(self.learn.loss_func, 'y_int', False)
        if self.stack_y: self.old_lf,self.learn.loss_func = self.learn.loss_func,self.lf

    def after_train(self):
        if self.stack_y: self.learn.loss_func = self.old_lf

    def lf(self, pred, *yb):
        if not self.training: return self.old_lf(pred, *yb)
        with NoneReduce(self.old_lf) as lf:
            sz = yb[0].size()
            # broadcast lam from (bs,) to one weight per element of the flattened loss
            new_lam = self.lam.expand(*sz[1:], sz[0]).transpose(0,-1).contiguous().view(-1)
            loss = torch.lerp(lf(pred,*self.yb1), lf(pred,*yb), new_lam)
            # loss = torch.lerp(lf(pred,*self.yb1), lf(pred,*yb), self.lam)
        return reduce_loss(loss, getattr(self.old_lf, 'reduction', 'mean'))

class MixUp_Segmentation(MixHandler_Segmentation):
    "Implementation of https://arxiv.org/abs/1710.09412"
    def __init__(self, alpha=.4): super().__init__(alpha)

    def before_batch(self):
        lam = self.distrib.sample((self.y.size(0),)).squeeze().to(self.x.device)
        lam = torch.stack([lam, 1-lam], 1)
        self.lam = lam.max(1)[0]
        shuffle = torch.randperm(self.y.size(0)).to(self.x.device)
        xb1,self.yb1 = tuple(L(self.xb).itemgot(shuffle)),tuple(L(self.yb).itemgot(shuffle))
        nx_dims = len(self.x.size())
        self.learn.xb = tuple(L(xb1,self.xb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=nx_dims-1)))
        if not self.stack_y:
            ny_dims = len(self.y.size())
            self.learn.yb = tuple(L(self.yb1,self.yb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=ny_dims-1)))
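Usage is the same as any other fastai callback – e.g., with a hypothetical segmentation setup:

# hypothetical usage; dls is a segmentation DataLoaders
learn = unet_learner(dls, resnet34, cbs=MixUp_Segmentation(alpha=0.4))
learn.fit_one_cycle(5)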
Hi! Is mixup used by fastai by default?
It is not; you need to pass the callback in yourself.
Something like:
learn.fit_one_cycle(..., cbs=[MixUp()])
Oh! I read that somewhere – it was wrong. So how are augmentations applied? If you could point me to some link with a full explanation I’d really appreciate it; I’m really interested in how it works.
It would be awesome if you could share the code that modifies the mixup callback to work with text embeddings.
Why do we call the mixup approach a data augmentation technique?
With classic augmentation techniques (e.g., jittering, scaling, magnitude warping) we double or quadruple the data: if the original dataset contained 4,000 samples, there will be 8,000 samples after augmentation.
On the other hand, as I understand it, mixup does not add data but rather mixes the samples and their labels, and trains on these new mixed samples to produce a more regularized model. Am I correct? If so, why is mixup referred to as data augmentation, given that we only mix samples and never artificially increase the dataset size?
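Here’s my mental model in code – each call mixes the batch with a freshly sampled lam, so every epoch the model sees "virtual" samples it has never literally seen before, even though len(dataset) never changes:

import torch
from torch.distributions.beta import Beta

def mixup_batch(x, y, alpha=0.4):
    # one shared lam per batch, as in the original paper
    lam = Beta(torch.tensor(alpha), torch.tensor(alpha)).sample()
    idx = torch.randperm(x.size(0))
    return torch.lerp(x[idx], x, lam), torch.lerp(y[idx], y, lam)

x = torch.randn(4, 1, 8, 8)            # toy batch
y = torch.tensor([0., 1., 0., 1.])     # float labels
xb, yb = mixup_batch(x, y)             # a new virtual batch on every call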