TypeError: loss_batch() missing 1 required positional argument: 'yb'

I’m having trouble figuring out why I’m getting this error. It occurs during callbacks, but only on the validation set. If I understand correctly, `yb` should be the targets. I’ve checked the datasets and dataloaders, and the targets are returned as they should be. It could be a problem with how I’m using the data block API, but I can’t tell where. I’m trying to train on ImageNet. My code looks like this:

def imagenet_from_folders(datapath=DATAPATH, size=64, bs=32, nw=4):
        '''Return an imagenet databunch, constructed from image files
        inside labeled folders.
        datapath = Path(datapath)
        meta = datapath / 'devkit/data/'
        wnid2id = json.load(open(meta/'id2label.json'))
        val_anno = [int(x) - 1 for x in 
        train_list = (ImageList.from_folder(datapath/'train')
                      .label_from_func(lambda x: wnid2id[x.parent.name])
        val_list = (ImageList.from_folder(datapath/'val')
        data = (LabelLists(datapath, train_list, val_list)
                .transform(get_transforms(), size=size)
                .databunch(bs=bs, num_workers=nw)
        return data

data = imagenet_from_folders(size=128, bs=96, nw=8)
head = torch.nn.Sequential(torch.nn.AdaptiveAvgPool2d(1),
                           torch.nn.Linear(512, 1000))
learner = cnn_learner(data, models.resnet18, pretrained=False, custom_head=head, metrics=[loss_batch, accuracy])
learner.fit_one_cycle(1, 5e-5, 3e-3)

And the resulting exception:

TypeError                                 Traceback (most recent call last)
<ipython-input-28-920e0078111b> in <module>
----> 1 learner.fit_one_cycle(1, 5e-5, 3e-3)

~/fai/lib/python3.7/site-packages/fastai/train.py in fit_one_cycle(learn, cyc_len, max_lr, moms, div_factor, pct_start, final_div, wd, callbacks, tot_epochs, start_epoch)
     20     callbacks.append(OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor, pct_start=pct_start,
     21                                        final_div=final_div, tot_epochs=tot_epochs, start_epoch=start_epoch))
---> 22     learn.fit(cyc_len, max_lr, wd=wd, callbacks=callbacks)
     24 def lr_find(learn:Learner, start_lr:Floats=1e-7, end_lr:Floats=10, num_it:int=100, stop_div:bool=True, wd:float=None):

~/fai/lib/python3.7/site-packages/fastai/basic_train.py in fit(self, epochs, lr, wd, callbacks)
    200         callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
    201         self.cb_fns_registered = True
--> 202         fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
    204     def create_opt(self, lr:Floats, wd:Floats=0.)->None:

~/fai/lib/python3.7/site-packages/fastai/basic_train.py in fit(epochs, learn, callbacks, metrics)
    104             if not cb_handler.skip_validate and not learn.data.empty_val:
    105                 val_loss = validate(learn.model, learn.data.valid_dl, loss_func=learn.loss_func,
--> 106                                        cb_handler=cb_handler, pbar=pbar)
    107             else: val_loss=None
    108             if cb_handler.on_epoch_end(val_loss): break

~/fai/lib/python3.7/site-packages/fastai/basic_train.py in validate(model, dl, loss_func, cb_handler, pbar, average, n_batch)
     61             if not is_listy(yb): yb = [yb]
     62             nums.append(first_el(yb).shape[0])
---> 63             if cb_handler and cb_handler.on_batch_end(val_losses[-1]): break
     64             if n_batch and (len(nums)>=n_batch): break
     65         nums = np.array(nums, dtype=np.float32)

~/fai/lib/python3.7/site-packages/fastai/callback.py in on_batch_end(self, loss)
    306         "Handle end of processing one batch with `loss`."
    307         self.state_dict['last_loss'] = loss
--> 308         self('batch_end', call_mets = not self.state_dict['train'])
    309         if self.state_dict['train']:
    310             self.state_dict['iteration'] += 1

~/fai/lib/python3.7/site-packages/fastai/callback.py in __call__(self, cb_name, call_mets, **kwargs)
    248         "Call through to all of the `CallbakHandler` functions."
    249         if call_mets:
--> 250             for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)
    251         for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)

~/fai/lib/python3.7/site-packages/fastai/callback.py in _call_and_update(self, cb, cb_name, **kwargs)
    239     def _call_and_update(self, cb, cb_name, **kwargs)->None:
    240         "Call `cb_name` on `cb` and update the inner state."
--> 241         new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
    242         for k,v in new.items():
    243             if k not in self.state_dict:

~/fai/lib/python3.7/site-packages/fastai/callback.py in on_batch_end(self, last_output, last_target, **kwargs)
    342         if not is_listy(last_target): last_target=[last_target]
    343         self.count += first_el(last_target).size(0)
--> 344         val = self.func(last_output, *last_target)
    345         if self.world:
    346             val = val.clone()

TypeError: loss_batch() missing 1 required positional argument: 'yb'

Any help would be greatly appreciated!

Your usage of the data block API is indeed unusual, but that isn’t what causes this error. Still, you should label your training and validation sets the same way (for example, both from their parent folder names). Then you can build the whole DataBunch in a single chain:

data = (ImageList.
            label_from_func(lambda x: wnid2id[x.parent.name]).
            transform(get_transforms(), size=size).
            databunch(bs=bs, num_workers=nw).

Your problem is that you are passing `loss_batch` as a metric. Its signature is `loss_batch(model:nn.Module, xb:Tensor, yb:Tensor, loss_func:OptLossFunc=None, opt:OptOptimizer=None, cb_handler:Optional[CallbackHandler]=None)->Tuple[Union[Tensor,int,float,str]]`. The metric callback calls it as `func(last_output, *last_target)`, as the last frame of your traceback shows. So the model’s output gets bound to `model`, the target gets bound to `xb`, and nothing is left for `yb`, hence the "missing 1 required positional argument: 'yb'" error.

The validation loss is always computed, so you don’t need to pass it as a metric. A metric should take the model output and the target, e.g. `def metric(input, target, **kwargs)`, as `accuracy` does.

Yes, thank you! I’m not sure why I thought I needed to pass that in.

You’re probably right. I should restructure the file system up front so that the train and validation images are stored the same way, which would simplify using the data block API.

Thanks very much.

1 Like