Not able to predict with learner in segmentation problem

Hi everyone,
I get an error message when I use learn.predict(items[0]), which is stated as follows:
ValueError: not enough values to unpack (expected 2, got 1)
I used the following code for training:

import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image import PIL
from fastai.vision.all import *

items = get_image_files('./Train')
labels = get_image_files('./Train_Label')
def open_tif(fn, chnls=None):
    im = (np.array(Image.open(fn))).astype('float32')
    return PILMask.create(im)
def get_msk(base_fn):
     return Path(str(base_fn).replace('Train', 'Train_Label'))
 
import albumentations as A
 
class SegmentationAlbumentationsTransform(ItemTransform):
    #split_idx = 0,2
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img,mask = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])
     
aug_pipe = A.Compose([                    
                      A.RandomCrop(384, 384),
                     ]) 
tfms1=SegmentationAlbumentationsTransform(aug_pipe)
db = DataBlock(blocks=(TransformBlock([open_tif]), 
                        TransformBlock([partial(get_msk), 
                       partial(open_tif)])),
                item_tfms=tfms1,
                splitter=RandomSplitter(valid_pct=0.2),       
               )
 
def loss_fn(pred, targ):
    return torch.nn.functional.cross_entropy(pred, targ.squeeze(1).type(torch.long))
 
learn = unet_learner(dls = dl, arch = resnet18, pretrained = True, normalize = False, n_in=1, n_out=3, loss_func = loss_fn)
 
learn.fit_one_cycle(20,lr_max=1e-4, wd=0.8)
learn.export("./fastai_unet.pkl")
learn.predict(items[0])

Does somebody understand where the error in the prediction comes from?

Here is the complete error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-97-2ccff2aa1d8f> in <module>
----> 1 learn.predict(items[0])

~/.local/lib/python3.6/site-packages/fastai/learner.py in predict(self, item, rm_type_tfms, with_input)
    264     def predict(self, item, rm_type_tfms=None, with_input=False):
    265         dl = self.dls.test_dl([item], rm_type_tfms=rm_type_tfms, num_workers=0)
--> 266         inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
    267         i = getattr(self.dls, 'n_inp', -1)
    268         inp = (inp,) if i==1 else tuplify(inp)

~/.local/lib/python3.6/site-packages/fastai/learner.py in get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)
    251         if with_loss: ctx_mgrs.append(self.loss_not_reduced())
    252         with ContextManagers(ctx_mgrs):
--> 253             self._do_epoch_validate(dl=dl)
    254             if act is None: act = getattr(self.loss_func, 'activation', noop)
    255             res = cb.all_tensors()

~/.local/lib/python3.6/site-packages/fastai/learner.py in _do_epoch_validate(self, ds_idx, dl)
    201         if dl is None: dl = self.dls[ds_idx]
    202         self.dl = dl
--> 203         with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
    204 
    205     def _do_epoch(self):

~/.local/lib/python3.6/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    161 
    162     def _with_events(self, f, event_type, ex, final=noop):
--> 163         try: self(f'before_{event_type}');  f()
    164         except ex: self(f'after_cancel_{event_type}')
    165         self(f'after_{event_type}');  final()

~/.local/lib/python3.6/site-packages/fastai/learner.py in all_batches(self)
    167     def all_batches(self):
    168         self.n_iter = len(self.dl)
--> 169         for o in enumerate(self.dl): self.one_batch(*o)
    170 
    171     def _do_one_batch(self):

~/.local/lib/python3.6/site-packages/fastai/data/load.py in __iter__(self)
    107         self.before_iter()
    108         self.__idxs=self.get_idxs() # called in context of main process (not workers/subprocesses)
--> 109         for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
    110             if self.device is not None: b = to_device(b, self.device)
    111             yield self.after_batch(b)

~/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
    519             if self._sampler_iter is None:
    520                 self._reset()
--> 521             data = self._next_data()
    522             self._num_yielded += 1
    523             if self._dataset_kind == _DatasetKind.Iterable and \

~/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    559     def _next_data(self):
    560         index = self._next_index()  # may raise StopIteration
--> 561         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    562         if self._pin_memory:
    563             data = _utils.pin_memory.pin_memory(data)

~/.local/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     32                 raise StopIteration
     33         else:
---> 34             data = next(self.dataset_iter)
     35         return self.collate_fn(data)
     36 

~/.local/lib/python3.6/site-packages/fastai/data/load.py in create_batches(self, samps)
    116         if self.dataset is not None: self.it = iter(self.dataset)
    117         res = filter(lambda o:o is not None, map(self.do_item, samps))
--> 118         yield from map(self.do_batch, self.chunkify(res))
    119 
    120     def new(self, dataset=None, cls=None, **kwargs):

~/.local/lib/python3.6/site-packages/fastcore/basics.py in chunked(it, chunk_sz, drop_last, n_chunks)
    214     if not isinstance(it, Iterator): it = iter(it)
    215     while True:
--> 216         res = list(itertools.islice(it, chunk_sz))
    217         if res and (len(res)==chunk_sz or not drop_last): yield res
    218         if len(res)<chunk_sz: return

~/.local/lib/python3.6/site-packages/fastai/data/load.py in do_item(self, s)
    131     def prebatched(self): return self.bs is None
    132     def do_item(self, s):
--> 133         try: return self.after_item(self.create_item(s))
    134         except SkipItemException: return None
    135     def chunkify(self, b): return b if self.prebatched else chunked(b, self.bs, self.drop_last)

~/.local/lib/python3.6/site-packages/fastcore/transform.py in __call__(self, o)
    198         self.fs = self.fs.sorted(key='order')
    199 
--> 200     def __call__(self, o): return compose_tfms(o, tfms=self.fs, split_idx=self.split_idx)
    201     def __repr__(self): return f"Pipeline: {' -> '.join([f.name for f in self.fs if f.name != 'noop'])}"
    202     def __getitem__(self,i): return self.fs[i]

~/.local/lib/python3.6/site-packages/fastcore/transform.py in compose_tfms(x, tfms, is_enc, reverse, **kwargs)
    148     for f in tfms:
    149         if not is_enc: f = f.decode
--> 150         x = f(x, **kwargs)
    151     return x
    152 

~/.local/lib/python3.6/site-packages/fastcore/transform.py in __call__(self, x, **kwargs)
    111     "A transform that always take tuples as items"
    112     _retain = True
--> 113     def __call__(self, x, **kwargs): return self._call1(x, '__call__', **kwargs)
    114     def decode(self, x, **kwargs):   return self._call1(x, 'decode', **kwargs)
    115     def _call1(self, x, name, **kwargs):

~/.local/lib/python3.6/site-packages/fastcore/transform.py in _call1(self, x, name, **kwargs)
    115     def _call1(self, x, name, **kwargs):
    116         if not _is_tuple(x): return getattr(super(), name)(x, **kwargs)
--> 117         y = getattr(super(), name)(list(x), **kwargs)
    118         if not self._retain: return y
    119         if is_listy(y) and not isinstance(y, tuple): y = tuple(y)

~/.local/lib/python3.6/site-packages/fastcore/transform.py in __call__(self, x, **kwargs)
     71     @property
     72     def name(self): return getattr(self, '_name', _get_name(self))
---> 73     def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)
     74     def decode  (self, x, **kwargs): return self._call('decodes', x, **kwargs)
     75     def __repr__(self): return f'{self.name}:\nencodes: {self.encodes}decodes: {self.decodes}'

~/.local/lib/python3.6/site-packages/fastcore/transform.py in _call(self, fn, x, split_idx, **kwargs)
     81     def _call(self, fn, x, split_idx=None, **kwargs):
     82         if split_idx!=self.split_idx and self.split_idx is not None: return x
---> 83         return self._do_call(getattr(self, fn), x, **kwargs)
     84 
     85     def _do_call(self, f, x, **kwargs):

~/.local/lib/python3.6/site-packages/fastcore/transform.py in _do_call(self, f, x, **kwargs)
     87             if f is None: return x
     88             ret = f.returns(x) if hasattr(f,'returns') else None
---> 89             return retain_type(f(x, **kwargs), x, ret)
     90         res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
     91         return retain_type(res, x)

~/.local/lib/python3.6/site-packages/fastcore/dispatch.py in __call__(self, *args, **kwargs)
    116         elif self.inst is not None: f = MethodType(f, self.inst)
    117         elif self.owner is not None: f = MethodType(f, self.owner)
--> 118         return f(*args, **kwargs)
    119 
    120     def __get__(self, inst, owner):

<ipython-input-16-a4a49adafb71> in encodes(self, x)
      3     def __init__(self, aug): self.aug = aug
      4     def encodes(self, x):
----> 5         img,mask = x
      6         aug = self.aug(image=np.array(img), mask=np.array(mask))
      7         return PILImage.create(aug["image"]), PILMask.create(aug["mask"])

ValueError: not enough values to unpack (expected 2, got 1)

You commented out the split_idx on your custom transform, which is extremely important. Albumentations, at least this one, should only be used during training, as currently it’s expecting an input that has both the x and the y (which we do not have on our labels). This is shown in the Albumentations tutorial as well:

If we want to use the augmentation transform we created before, we just need to add one thing to it: we want it to be applied on the training set only, not the validation set. To do this, we specify it should only be applied on a specific idx of our splits by adding split_idx=0 (0 is for the training set, 1 for the validation set):

class SegmentationAlbumentationsTransform(ItemTransform):
    split_idx = 0
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img,mask = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])

Thank you very much for your reply. :slight_smile:
I totally understand that point and i changed split_idx to 0.
However, it did not help. Do you think there is also a problem with the syntax of the inteference? I have two folders with training images and its labels, respectively. The images are 1 channel grayscale images of a large and also changing size (approx. 1000 ,1200 pixels). That is why is use random cropping for the training. My question is how to use some of these images for inteference.

What is the exact error now. It should be different.

The error message is unfortunately still exactly the same.

Have you tried using get_preds.

That isn’t the issue as predict calls get preds.

Can you share what the full code looks now, even if it may be redundant.

I think you might need to create a predict augmentation function, like mentioned here (the 5th post)
So another encodes function.

I tried it, but it didn’t help

Thank you very much for your help. :slight_smile: When I use split_idx=0 in the class SegmentationAlbumentationsTransform the dataloader doesn’t apply the transform (That’s is why commented it again). I am not sure why. Here is the full code:

import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image import PIL
from fastai.vision.all import *

items = get_image_files('./Train')
labels = get_image_files('./Train_Label')

def open_tif(fn, chnls=None):
    im = (np.array(Image.open(fn))).astype('float32')
    return PILMask.create(im)
def get_msk(base_fn):
    return Path(str(base_fn).replace('Train', 'Train_Label'))

class SegmentationAlbumentationsTransform(ItemTransform):
    split_idx = 0
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
     #   print(type(x))
        img,mask = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])
    
aug_pipe = A.Compose([                   
                      A.RandomCrop(384, 384),
                    ])
db = DataBlock(blocks=(TransformBlock([open_tif]), 
                       TransformBlock([partial(get_msk), 
                      partial(open_tif)])),
               item_tfms=tfms1,
               splitter=RandomSplitter(valid_pct=0.2),       
              )

dl = db.dataloaders(items, bs=10)

def loss_fn(pred, targ):
    return torch.nn.functional.cross_entropy(pred, targ.squeeze(1).type(torch.long))

learn = unet_learner(dls = dl, arch = resnet18, pretrained = True, normalize = False, n_in=1, n_out=3, loss_func = loss_fn)

learn.lr_find()

learn.fit_one_cycle(20,lr_max=1e-4, wd=0.8)

learn.predict(items[0])

That’s a good point thank you. I tried the two different encode methods with def encodes(self, x: tuple): and def encodes(self, img: TensorImage): in the encode method. However, when i used this the transforms were not applied. When I used print(type(x)) in the encode method the printed type is class list. Does this help?

I attached a link with three images and their labels from my dataset. Maby this help in order to
reproduce the error.

https://syncandshare.lrz.de/getlink/fiTVzefAAqvsDvSedo4Z4tsi/Data.zip

(It is CT data)

The answer is it does, but only during training, not inference. As you don’t have a label during inference, we don’t want that.

I’ll take a look later tonight at your code, but the answer should be rewriting your transform a bit (And will update here once I have that, just not free to do that currently)

1 Like

Thanks a lot. :slight_smile: I tried the following:

idx=3
img, mask = dl.do_item(idx)

I looked at img and mask. The transform was not applied after using split_idx=0.

Did you have time already to look at the code (sorry to bother you :))?

Thank you for your help.

Best regards

Simon