Segmentation Learn.predict() failes when using Albumentations

I have trained a model for 4channel input to do segmentation.

learn.show_results() shows the predictions just fine, but when I try to do inference on a arbitrary picture it failes.

I suspect the problem is related to Albumenations augmentations, because I have previously trained the same dataset using only 3channels input using fastai augmentations. And that worked just fine.

I have been stuck for a week now, and begin to feel desparate.

Here is the callstack from when I try to predict image

ValueError                                Traceback (most recent call last)
<ipython-input-41-e4ac01516bbf> in <module>
----> 1 prediction = learn.predict(img)

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastai\ in predict(self, item, rm_type_tfms, with_input)
    246     def predict(self, item, rm_type_tfms=None, with_input=False):
    247         dl = self.dls.test_dl([item], rm_type_tfms=rm_type_tfms, num_workers=0)
--> 248         inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
    249         i = getattr(self.dls, 'n_inp', -1)
    250         inp = (inp,) if i==1 else tuplify(inp)

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastai\ in get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)
    233         if with_loss: ctx_mgrs.append(self.loss_not_reduced())
    234         with ContextManagers(ctx_mgrs):
--> 235             self._do_epoch_validate(dl=dl)
    236             if act is None: act = getattr(self.loss_func, 'activation', noop)
    237             res = cb.all_tensors()

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastai\ in _do_epoch_validate(self, ds_idx, dl)
    186         if dl is None: dl = self.dls[ds_idx]
    187         self.dl = dl
--> 188         with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
    190     def _do_epoch(self):

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastai\ in _with_events(self, f, event_type, ex, final)
    154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
    157         finally:   self(f'after_{event_type}')        ;final()

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastai\ in all_batches(self)
    159     def all_batches(self):
    160         self.n_iter = len(self.dl)
--> 161         for o in enumerate(self.dl): self.one_batch(*o)
    163     def _do_one_batch(self):

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastai\data\ in __iter__(self)
    100         self.before_iter()
    101         self.__idxs=self.get_idxs() # called in context of main process (not workers/subprocesses)
--> 102         for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
    103             if self.device is not None: b = to_device(b, self.device)
    104             yield self.after_batch(b)

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\torch\utils\data\ in __next__(self)
    362     def __next__(self):
--> 363         data = self._next_data()
    364         self._num_yielded += 1
    365         if self._dataset_kind == _DatasetKind.Iterable and \

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\torch\utils\data\ in _next_data(self)
    401     def _next_data(self):
    402         index = self._next_index()  # may raise StopIteration
--> 403         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    404         if self._pin_memory:
    405             data = _utils.pin_memory.pin_memory(data)

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\torch\utils\data\_utils\ in fetch(self, possibly_batched_index)
     32                 raise StopIteration
     33         else:
---> 34             data = next(self.dataset_iter)
     35         return self.collate_fn(data)

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastai\data\ in create_batches(self, samps)
    109 = iter(self.dataset) if self.dataset is not None else None
    110         res = filter(lambda o:o is not None, map(self.do_item, samps))
--> 111         yield from map(self.do_batch, self.chunkify(res))
    113     def new(self, dataset=None, cls=None, **kwargs):

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\ in chunked(it, chunk_sz, drop_last, n_chunks)
    196     if not isinstance(it, Iterator): it = iter(it)
    197     while True:
--> 198         res = list(itertools.islice(it, chunk_sz))
    199         if res and (len(res)==chunk_sz or not drop_last): yield res
    200         if len(res)<chunk_sz: return

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastai\data\ in do_item(self, s)
    122     def prebatched(self): return is None
    123     def do_item(self, s):
--> 124         try: return self.after_item(self.create_item(s))
    125         except SkipItemException: return None
    126     def chunkify(self, b): return b if self.prebatched else chunked(b,, self.drop_last)

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\ in __call__(self, o)
    196         self.fs.append(t)
--> 198     def __call__(self, o): return compose_tfms(o, tfms=self.fs, split_idx=self.split_idx)
    199     def __repr__(self): return f"Pipeline: {' -> '.join([ for f in self.fs if != 'noop'])}"
    200     def __getitem__(self,i): return self.fs[i]

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\ in compose_tfms(x, tfms, is_enc, reverse, **kwargs)
    148     for f in tfms:
    149         if not is_enc: f = f.decode
--> 150         x = f(x, **kwargs)
    151     return x

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\ in __call__(self, x, **kwargs)
    111     "A transform that always take tuples as items"
    112     _retain = True
--> 113     def __call__(self, x, **kwargs): return self._call1(x, '__call__', **kwargs)
    114     def decode(self, x, **kwargs):   return self._call1(x, 'decode', **kwargs)
    115     def _call1(self, x, name, **kwargs):

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\ in _call1(self, x, name, **kwargs)
    115     def _call1(self, x, name, **kwargs):
    116         if not _is_tuple(x): return getattr(super(), name)(x, **kwargs)
--> 117         y = getattr(super(), name)(list(x), **kwargs)
    118         if not self._retain: return y
    119         if is_listy(y) and not isinstance(y, tuple): y = tuple(y)

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\ in __call__(self, x, **kwargs)
     71     @property
     72     def name(self): return getattr(self, '_name', _get_name(self))
---> 73     def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)
     74     def decode  (self, x, **kwargs): return self._call('decodes', x, **kwargs)
     75     def __repr__(self): return f'{}:\nencodes: {self.encodes}decodes: {self.decodes}'

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\ in _call(self, fn, x, split_idx, **kwargs)
     81     def _call(self, fn, x, split_idx=None, **kwargs):
     82         if split_idx!=self.split_idx and self.split_idx is not None: return x
---> 83         return self._do_call(getattr(self, fn), x, **kwargs)
     85     def _do_call(self, f, x, **kwargs):

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\ in _do_call(self, f, x, **kwargs)
     87             if f is None: return x
     88             ret = f.returns_none(x) if hasattr(f,'returns_none') else None
---> 89             return retain_type(f(x, **kwargs), x, ret)
     90         res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
     91         return retain_type(res, x)

c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\ in __call__(self, *args, **kwargs)
    127         elif self.inst is not None: f = MethodType(f, self.inst)
    128         elif self.owner is not None: f = MethodType(f, self.owner)
--> 129         return f(*args, **kwargs)
    131     def __get__(self, inst, owner):

<ipython-input-16-5c3084d812ae> in encodes(self, x)
      6     def encodes(self, x):
----> 7          img,mask = x
      9          # for albumentations to work correctly, the channels must be at the last dimension

ValueError: not enough values to unpack (expected 2, got 1)

Here is the full code:

def read_cmv(fn):
    f = open(fn,'rb')
    cmv_type= int(np.fromfile(f, dtype='>i4', count=1 ))
    cmv_payload_len = int(np.fromfile(f, dtype='>u4', count=1))

    if cmv_type != 7 : 
        raise NotImplementedError(f'CMV type {cmv_type} is not implemented') 

    cmv_time = float(np.fromfile(f, dtype='>f8', count = 1))
    cmv_size_x, cmv_size_y  = np.fromfile(f, dtype='>u4', count = 2)
    # CHW order
    return torch.Tensor(np.fromfile(f, dtype = 'f4', count = cmv_size_x*cmv_size_y)).view(cmv_size_y,cmv_size_x).flip([0])

def normalise_cmv(tn):
    min_val = -82.4148
    max_val = 68.7552
    return (tn - min_val)*255./(max_val-min_val)

def get_cmv_filename(imagePath):
    return path.joinpath('height', f'{imagePath.stem}_height.cmv')

def open_rgb_cmv(fn):
    im_rgb = TensorImage(PILImage.create(fn)).permute((2,0,1)) # HWC to CHW order
    im_cmv = normalise_cmv(read_cmv(get_cmv_filename(fn)))
    return TensorImage(,torch.unsqueeze(im_cmv,0)),0))

import albumentations as A

class AlbumentationsTransform(ItemTransform):
    split_idx = None
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
         img,mask = x

         # for albumentations to work correctly, the channels must be at the last dimension
         aug = self.aug(image=np.array(img.permute(1,2,0)), mask=np.array(mask))
         return TensorImage(aug['image'].transpose(2,0,1)), TensorMask(aug['mask'])

class SegmentationAlbumentationsTransform(ItemTransform):
    split_idx = 0
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img,mask = x

        # for albumentations to work correctly, the channels must be at the last dimension
        aug = self.aug(image=np.array(img.permute(1,2,0)), mask=np.array(mask))
        return TensorImage(aug['image'].transpose(2,0,1)), TensorMask(aug['mask'])

def get_y(inPath):
    return path.joinpath('labels', f'{inPath.stem}.png')

dataBlock = DataBlock( 
    blocks=(TransformBlock(open_rgb_cmv,batch_tfms=IntToFloatTensor ),MaskBlock(codes=np.loadtxt(path/'codes.txt', dtype=str))),
    splitter=RandomSplitter(seed=42, valid_pct=0.3),
    item_tfms=[AlbumentationsTransform(A.PadIfNeeded(min_height=511,min_width=511)), SegmentationAlbumentationsTransform(A.Compose([A.ShiftScaleRotate(p=0.9), A.HorizontalFlip(), A.RandomBrightnessContrast(contrast_limit=0.1, brightness_by_max=False)]))],
    batch_tfms= Normalize()

dls = dataBlock.dataloaders(path/"images", bs=12, num_workers=0)

from fastai.callback.fp16 import *

learn = unet_learner(dls, resnet18, n_in=4,normalize=False,pretrained=False,

learn.fit_one_cycle(10,  0.0002)

images = get_image_files(path/"images")
img = open_rgb_cmv(images[22]);
i2f_tfm = IntToFloatTensor()

# Here is where it failes
prediction = learn.predict(img)

Maybe I’m missing something but when the transform is run on the validation set, why would we expect a mask to be returned? Isn’t the mask the label which would not be part of the predict() sequence?

Sorry, I’m not sure what you mean. Are you talking about my custom transforms (AlbumenationsTransform and SegmentationAlbumenationTransform)?

(FYI. Im a noob at this. Started learning 6 weeks ago )

I tried to follow the example here

Some of my input does not have size 511x511. So I try to pad the input using AlbumentationsTransform(A.PadIfNeeded(min_height=511,min_width=511)) (which has split_idx = none)

And then the idea was to not augment the validations, so I created SegmentationAlbumentationsTransform with has split_idx=0.

What puzzles me is that learn.show_results() works, and learn.predict() fails. In my mind show_result() does something similar as predict under the hood.

No need to apologize. I’m no expert myself. But what may be happening is that when you call learn.predict(img) and than img goes through the Albumentations transforms that you have written, the encodes() expect for x to be 2-item tuple - an img and a mask. But the mask is the label that the img is trying to predict and it not being supplied with learn.predict(). The fact that show_results() works aligns with this theory because show_results() pulls a batch of images and label masks and so it has both available to it. Again, I may be off on all this but I just don’t see how x could be unpacked into an image and a label mask when you are not providing a label mask.

Any suggestions on what I can try to do?

I don’t think you want to apply those transforms to the validation, do you? If not, I think both Albumentations classes you wrote should split_idx = 0 as a class variable so that it only gets applied at training time to the training images. Does it work if you do that?

You should follow @Patrick’s advice and set the split_idx to 0 (this is also shown later in the Albumentations tutorial):

 class SegmentationAlbumentationsTransform(ItemTransform):
    split_idx = 0
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img,mask = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])

Some of my input does not have size 511x511, e.g 449x511. So I’m thinking that I always need to pad the input to 511x511. Thats why I have two classes , AlbumentationsTransform and SegmentationAlbumentationsTransform.

AlbumentationsTransform set split_idx = None and I only supply it with A.PadIfNeeded(min_height=511,min_width=511).

The rest of the augmentations goes into SegmentationAlbumentationsTransform, with split_idx = 0. Which I only want during training.

But AlbumentationsTransform does not seem to work during learn.predict(), as it’s encodes() expects a tuple. So I’m currently fishing for idéas on how I can pad the input during learn.predict().

Any suggestions are greatly appreciated.

I solved my issue by creating a new dataloader while predicting. With different item_tfms.

class AlbumentationsPredictionTransform(ItemTransform):
    split_idx = None
    def __init__(self, aug): self.aug = aug
    def encodes(self, img: TensorImage):
        #For albumentations to work correctly, the channels must be at the last dimension. (Permute)
        aug_img = self.aug(image=np.array(img.permute(1,2,0)))['image']
        return TensorImage(aug_img.transpose(2,0,1))
    def encodes(self, msk: PILMask):
        #For albumentations to work correctly, the channels must be at the last dimension. (Permute)
        aug_msk = self.aug(image=np.array(msk))['image']
        return PILMask.create(aug_msk)

    def encodes(self, img: PILImage):
        aug_img = self.valid_aug(image=np.array(img))['image']
        return PILImage.create(aug_img)

pred_dataBlock = DataBlock( 
    blocks=(TransformBlock(open_rgb_cmv,batch_tfms=IntToFloatTensor ),MaskBlock(codes=np.loadtxt(path/'codes.txt', dtype=str))),
    splitter=RandomSplitter(seed=42, valid_pct=0.3),
    batch_tfms= Normalize()

learn.dls = pred_dataBlock.dataloaders(path/"images", bs=12, num_workers=0)

p= learn.predict(images[90])

Yes - had this issue also. Another solution I found, after much effort is below.

class SegmentationAlbumentationsTransform(ItemTransform):
#    split_idx=0
    def __init__(self, aug, **kwargs): 
        self.aug = aug
    def encodes(self, x: tuple):   #<== add the word 'tuple'
#        img = img/img.max()
        img,mask = x
        aug = self.aug(image=np.array(img.permute(1,2,0)), mask=np.array(mask))
        the_ret = TensorImage(aug['image'].transpose(2,0,1)), TensorMask(aug['mask'])        
        return the_ret
    def encodes(self, img: TensorImage):   #<== add this entire function (now used by learn.predict)
        #For albumentations to work correctly, the channels must be at the last dimension. (Permute)
        aug_img = self.aug(image=np.array(img.permute(1,2,0)))
        return TensorImage(aug_img['image'].transpose(2,0,1))    

So now you don’t need class AlbumentationsPredictionTransform or pred_dataBlock any more and can continue using your original ‘learn’.

1 Like

Brilliant! Thanks for sharing that solution!