I am trying to implement KFold c.v.
using both PyTorch
and Fastai
.
for fold, (train_idx, valid_idx) in enumerate(kfold.split(glomer_dataset)):
print(f'FOLD {fold}')
print('--------------------------------')
train_subsampler = SubsetRandomSampler(train_idx)
valid_subsampler = SubsetRandomSampler(valid_idx)
train_loader = DataLoader(my_dataset, batch_size=BS, sampler=train_subsampler)
valid_loader = DataLoader(my_dataset, batch_size=BS, sampler=valid_subsampler)
dls = DataLoaders(train_loader, valid_loader)
opt = ranger
learn = unet_learner(dls, resnet34, loss=nn.CrossEntropyLoss() , self_attention=True, act_cls=Mish, opt_func=opt)
lr = 3e-4
learn.fit_flat_cos(3, slice(lr))
When I run the code above, I get this error:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[10], line 14
11 dls = DataLoaders(train_loader, valid_loader)
13 opt = ranger
---> 14 learn = unet_learner(dls, resnet34, loss=nn.CrossEntropyLoss() , self_attention=True, act_cls=Mish, opt_func=opt)
15 lr = 3e-4
16 learn.fit_flat_cos(3, slice(lr))
File /opt/conda/lib/python3.10/site-packages/fastai/vision/learner.py:262, in unet_learner(dls, arch, normalize, n_out, pretrained, config, loss_func, opt_func, lr, splitter, cbs, metrics, path, model_dir, wd, wd_bn_bias, train_bn, moms, **kwargs)
260 meta = model_meta.get(arch, _default_meta)
261 n_in = kwargs['n_in'] if 'n_in' in kwargs else 3
--> 262 if normalize: _add_norm(dls, meta, pretrained, n_in)
264 n_out = ifnone(n_out, get_c(dls))
265 assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
File /opt/conda/lib/python3.10/site-packages/fastai/vision/learner.py:196, in _add_norm(dls, meta, pretrained, n_in)
194 if stats is None: return
195 if n_in != len(stats[0]): return
--> 196 if not dls.after_batch.fs.filter(risinstance(Normalize)):
197 dls.add_tfms([Normalize.from_stats(*stats)],'after_batch')
File /opt/conda/lib/python3.10/site-packages/fastcore/basics.py:496, in GetAttr.__getattr__(self, k)
494 if self._component_attr_filter(k):
495 attr = getattr(self,self._default,None)
--> 496 if attr is not None: return getattr(attr,k)
497 raise AttributeError(k)
AttributeError: 'DataLoader' object has no attribute 'after_batch'
I assume it is due to a compatibility issue between PyTorch DataLoader and Fastai, but I am not sure.