How to get class labels from predictions

I was trying to predict on a test set using learn.get_preds(ds_type=DatasetType.Test)
and I got the predictions. But I need the labels of the top k predicted classes. Can anyone tell me how to do that?


In fastai v1.0.27 you can get the class-to-index mapping. You can use torch.topk to get the indexes of the highest k values, then use these indexes to look up the class labels.
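A minimal sketch of the torch.topk step, in case it helps. The `classes` list and the `preds` tensor here are made-up stand-ins: in practice `preds` would be the first element returned by learn.get_preds, and `classes` would be whatever class list your data object holds.

```python
import torch

# Hypothetical stand-ins for the real class list and prediction tensor.
classes = ["cat", "dog", "horse", "bird"]
preds = torch.tensor([[0.10, 0.60, 0.25, 0.05],
                      [0.70, 0.05, 0.05, 0.20]])

k = 2
# topk returns the k largest values per row and their column indexes,
# sorted from highest to lowest.
top_probs, top_idxs = torch.topk(preds, k, dim=1)

# Map each index back to its class label.
top_labels = [[classes[i] for i in row] for row in top_idxs.tolist()]
print(top_labels)  # [['dog', 'horse'], ['cat', 'bird']]
```

The order of `classes` has to match the order the model was trained with, otherwise the indexes map to the wrong labels.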


Thanks, I didn’t know fastai had such a function.

No problem :) We are all learning new things from each other.

Getting this error while running learn.get_preds(ds_type=DatasetType.Test)

IndexError                                Traceback (most recent call last)
<ipython-input-10-ce77ea7f5e1e> in <module>
----> 1 learn.get_preds(ds_type=DatasetType.Test)

/opt/anaconda3/lib/python3.6/site-packages/fastai/ in get_preds(self, ds_type, with_loss, n_batch, pbar)
    209         lf = self.loss_func if with_loss else None
    210         return get_preds(self.model, self.dl(ds_type), cb_handler=CallbackHandler(self.callbacks),
--> 211                          activ=_loss_func2activ(self.loss_func), loss_func=lf, n_batch=n_batch, pbar=pbar)
    213     def pred_batch(self, ds_type:DatasetType=DatasetType.Valid, pbar:Optional[PBar]=None) -> List[Tensor]:

/opt/anaconda3/lib/python3.6/site-packages/fastai/ in get_preds(model, dl, pbar, cb_handler, activ, loss_func, n_batch)
     36     "Tuple of predictions and targets, and optional losses (if `loss_func`) using `dl`, max batches `n_batch`."
     37     res = [ for o in
---> 38            zip(*validate(model, dl, cb_handler=cb_handler, pbar=pbar, average=False, n_batch=n_batch))]
     39     if loss_func is not None: res.append(calc_loss(res[0], res[1], loss_func))
     40     if activ is not None: res[0] = activ(res[0])

/opt/anaconda3/lib/python3.6/site-packages/fastai/ in validate(model, dl, loss_func, cb_handler, pbar, average, n_batch)
     47     with torch.no_grad():
     48         val_losses,nums = [],[]
---> 49         for xb,yb in progress_bar(dl, parent=pbar, leave=(pbar is not None)):
     50             if cb_handler: xb, yb = cb_handler.on_batch_begin(xb, yb, train=False)
     51             val_losses.append(loss_batch(model, xb, yb, loss_func, cb_handler=cb_handler))

/opt/anaconda3/lib/python3.6/site-packages/fastprogress/ in __iter__(self)
     63         self.update(0)
     64         try:
---> 65             for i,o in enumerate(self._gen):
     66                 yield o
     67                 if self.auto_update: self.update(i+1)

/opt/anaconda3/lib/python3.6/site-packages/fastai/ in __iter__(self)
     45     def __iter__(self):
     46         "Process and returns items from `DataLoader`."
---> 47         for b in self.dl:
     48             y = b[1][0] if is_listy(b[1]) else b[1]
     49             if not self.skip_size1 or y.size(0) != 1: yield self.proc_batch(b)

/opt/anaconda3/lib/python3.6/site-packages/torch/utils/data/ in __next__(self)
    635                 self.reorder_dict[idx] = batch
    636                 continue
--> 637             return self._process_next_batch(batch)
    639     next = __next__  # Python 2 compatibility

/opt/anaconda3/lib/python3.6/site-packages/torch/utils/data/ in _process_next_batch(self, batch)
    656         self._put_indices()
    657         if isinstance(batch, ExceptionWrapper):
--> 658             raise batch.exc_type(batch.exc_msg)
    659         return batch

IndexError: Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.6/site-packages/torch/utils/data/", line 138, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/opt/anaconda3/lib/python3.6/site-packages/torch/utils/data/", line 138, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/opt/anaconda3/lib/python3.6/site-packages/fastai/", line 415, in __getitem__
    if self.item is None: x,y = self.x[idxs],self.y[idxs]
  File "/opt/anaconda3/lib/python3.6/site-packages/fastai/", line 82, in __getitem__
    if isinstance(try_int(idxs), int): return self.get(idxs)
  File "/opt/anaconda3/lib/python3.6/site-packages/fastai/vision/", line 289, in get
    fn = super().get(i)
  File "/opt/anaconda3/lib/python3.6/site-packages/fastai/", line 52, in get
    def get(self, i)->Any: return self.items[i]
IndexError: index 0 is out of bounds for axis 0 with size 0

Do you know the reason for integer-encoding the classes?
To me the most intuitive encoding would be a binary encoding, such that for each image the label would be a vector of length len(classes), with 1s wherever the class is labeled positive.
I am simply wondering why the almighty creators of fastai chose to use dictionaries instead.

Don’t you think that would be quite memory-intensive compared to encoding them with dictionaries? Think of a situation where there are many classes: the one-hot vectors would be very sparse compared to a dictionary.
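To put rough numbers on this, here is a small sketch comparing the two encodings for single-label data. The sizes are arbitrary and the `nbytes` helper is just for illustration; this is not how fastai stores labels internally.

```python
import torch
import torch.nn.functional as F

n_samples, n_classes = 1000, 500  # arbitrary illustrative sizes

# Integer encoding: one class index (int64) per sample.
idx_labels = torch.randint(0, n_classes, (n_samples,))

# Binary (one-hot) encoding: one vector of length n_classes per sample.
onehot_labels = F.one_hot(idx_labels, num_classes=n_classes)

def nbytes(t):
    # Illustrative helper: raw storage size of a tensor in bytes.
    return t.nelement() * t.element_size()

idx_bytes = nbytes(idx_labels)
onehot_bytes = nbytes(onehot_labels)
print(idx_bytes, onehot_bytes)

# The integer labels can always be recovered from the one-hot form.
recovered = onehot_labels.argmax(dim=1)
```

With the same dtype, the one-hot version takes n_classes times more space, and all but one entry in each row is zero.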
