Error Message in learn.predict when using custom Datablock/ Dataloader

Hello everyone,
i need some help with my Fastai pipeline.
I want to do semantic segmentation on a 2 channel input image with augmentation.
I adapted my procedure from the good introduction in How to create a DataBlock for Multispectral Satellite Image Segmentation with the Fastai-v2 | by Maurício Cordeiro | Towards Data Science .
I have 2 channel images which are saved as numpy arrays (.npy).

See my code below (Sorry for all the screenshots):










I tried to predict images in three different ways and also with learn.get_preds() and the dataloader, but it was not successful. The problem seems to be the encodes function for the masks, and images for the augmentation.

When i run: cat, tensor, probs=learn.predict(img)

Following error appears, but i don’t know how to fix this.

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_14397/663310027.py in <module>
----> 1 cat, tensor, probs=learn.predict(img)

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastai/learner.py in predict(self, item, rm_type_tfms, with_input)
    264     def predict(self, item, rm_type_tfms=None, with_input=False):
    265         dl = self.dls.test_dl([item], rm_type_tfms=rm_type_tfms, num_workers=0)
--> 266         inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
    267         i = getattr(self.dls, 'n_inp', -1)
    268         inp = (inp,) if i==1 else tuplify(inp)

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastai/learner.py in get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)
    251         if with_loss: ctx_mgrs.append(self.loss_not_reduced())
    252         with ContextManagers(ctx_mgrs):
--> 253             self._do_epoch_validate(dl=dl)
    254             if act is None: act = getattr(self.loss_func, 'activation', noop)
    255             res = cb.all_tensors()

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastai/learner.py in _do_epoch_validate(self, ds_idx, dl)
    201         if dl is None: dl = self.dls[ds_idx]
    202         self.dl = dl
--> 203         with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
    204 
    205     def _do_epoch(self):

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    161 
    162     def _with_events(self, f, event_type, ex, final=noop):
--> 163         try: self(f'before_{event_type}');  f()
    164         except ex: self(f'after_cancel_{event_type}')
    165         self(f'after_{event_type}');  final()

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastai/learner.py in all_batches(self)
    167     def all_batches(self):
    168         self.n_iter = len(self.dl)
--> 169         for o in enumerate(self.dl): self.one_batch(*o)
    170 
    171     def _do_one_batch(self):

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastai/data/load.py in __iter__(self)
    107         self.before_iter()
    108         self.__idxs=self.get_idxs() # called in context of main process (not workers/subprocesses)
--> 109         for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
    110             if self.device is not None: b = to_device(b, self.device)
    111             yield self.after_batch(b)

~/miniconda3/envs/fastai/lib/python3.9/site-packages/torch/utils/data/dataloader.py in __next__(self)
    519             if self._sampler_iter is None:
    520                 self._reset()
--> 521             data = self._next_data()
    522             self._num_yielded += 1
    523             if self._dataset_kind == _DatasetKind.Iterable and \

~/miniconda3/envs/fastai/lib/python3.9/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    559     def _next_data(self):
    560         index = self._next_index()  # may raise StopIteration
--> 561         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    562         if self._pin_memory:
    563             data = _utils.pin_memory.pin_memory(data)

~/miniconda3/envs/fastai/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     32                 raise StopIteration
     33         else:
---> 34             data = next(self.dataset_iter)
     35         return self.collate_fn(data)
     36 

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastai/data/load.py in create_batches(self, samps)
    116         if self.dataset is not None: self.it = iter(self.dataset)
    117         res = filter(lambda o:o is not None, map(self.do_item, samps))
--> 118         yield from map(self.do_batch, self.chunkify(res))
    119 
    120     def new(self, dataset=None, cls=None, **kwargs):

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastcore/basics.py in chunked(it, chunk_sz, drop_last, n_chunks)
    214     if not isinstance(it, Iterator): it = iter(it)
    215     while True:
--> 216         res = list(itertools.islice(it, chunk_sz))
    217         if res and (len(res)==chunk_sz or not drop_last): yield res
    218         if len(res)<chunk_sz: return

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastai/data/load.py in do_item(self, s)
    131     def prebatched(self): return self.bs is None
    132     def do_item(self, s):
--> 133         try: return self.after_item(self.create_item(s))
    134         except SkipItemException: return None
    135     def chunkify(self, b): return b if self.prebatched else chunked(b, self.bs, self.drop_last)

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastcore/transform.py in __call__(self, o)
    198         self.fs = self.fs.sorted(key='order')
    199 
--> 200     def __call__(self, o): return compose_tfms(o, tfms=self.fs, split_idx=self.split_idx)
    201     def __repr__(self): return f"Pipeline: {' -> '.join([f.name for f in self.fs if f.name != 'noop'])}"
    202     def __getitem__(self,i): return self.fs[i]

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastcore/transform.py in compose_tfms(x, tfms, is_enc, reverse, **kwargs)
    148     for f in tfms:
    149         if not is_enc: f = f.decode
--> 150         x = f(x, **kwargs)
    151     return x
    152 

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastcore/transform.py in __call__(self, x, **kwargs)
    111     "A transform that always take tuples as items"
    112     _retain = True
--> 113     def __call__(self, x, **kwargs): return self._call1(x, '__call__', **kwargs)
    114     def decode(self, x, **kwargs):   return self._call1(x, 'decode', **kwargs)
    115     def _call1(self, x, name, **kwargs):

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastcore/transform.py in _call1(self, x, name, **kwargs)
    115     def _call1(self, x, name, **kwargs):
    116         if not _is_tuple(x): return getattr(super(), name)(x, **kwargs)
--> 117         y = getattr(super(), name)(list(x), **kwargs)
    118         if not self._retain: return y
    119         if is_listy(y) and not isinstance(y, tuple): y = tuple(y)

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastcore/transform.py in __call__(self, x, **kwargs)
     71     @property
     72     def name(self): return getattr(self, '_name', _get_name(self))
---> 73     def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)
     74     def decode  (self, x, **kwargs): return self._call('decodes', x, **kwargs)
     75     def __repr__(self): return f'{self.name}:\nencodes: {self.encodes}decodes: {self.decodes}'

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastcore/transform.py in _call(self, fn, x, split_idx, **kwargs)
     81     def _call(self, fn, x, split_idx=None, **kwargs):
     82         if split_idx!=self.split_idx and self.split_idx is not None: return x
---> 83         return self._do_call(getattr(self, fn), x, **kwargs)
     84 
     85     def _do_call(self, f, x, **kwargs):

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastcore/transform.py in _do_call(self, f, x, **kwargs)
     87             if f is None: return x
     88             ret = f.returns(x) if hasattr(f,'returns') else None
---> 89             return retain_type(f(x, **kwargs), x, ret)
     90         res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
     91         return retain_type(res, x)

~/miniconda3/envs/fastai/lib/python3.9/site-packages/fastcore/dispatch.py in __call__(self, *args, **kwargs)
    116         elif self.inst is not None: f = MethodType(f, self.inst)
    117         elif self.owner is not None: f = MethodType(f, self.owner)
--> 118         return f(*args, **kwargs)
    119 
    120     def __get__(self, inst, owner):

/tmp/ipykernel_14397/3758110305.py in encodes(self, x)
      7 
      8     def encodes(self, x):
----> 9             img,mask = x
     10             img = img/img.max()
     11             aug = self.aug(image=np.array(img.permute(1,2,0)), mask=np.array(mask))

ValueError: not enough values to unpack (expected 2, got 1)

Thank you for your help.

Best regards

Simon

Here is all my code in another format:

%matplotlib inline
import torch
print(torch.__version__)
print(torch.cuda.is_available())

import fastai
print(fastai.__version__)

# other imports
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image

from fastai.vision.all import *
from sklearn.model_selection import StratifiedKFold
import torch
import fastai
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image

from fastai.vision.all import *

imgs_path = Path('./data')
lbls_path = Path('./label1')

def open_im(fn, chnls=None, cls=torch.Tensor):
    im = (np.load(fn)).astype('float32')
    return cls(im)

# The map_filename function makes it easier to map from one folder to another by replacing strings
def map_filename(base_fn, str1, str2):
    return Path(str(base_fn).replace(str1, str2))
# get items from both datasets
items = get_files('./data', extensions='.npy')
masks = get_files('./label1', extensions='.npy')
items_mask = masks
items = items
items
idx=2
img_pipe = Pipeline(open_im)
img = img_pipe(items[idx])

mask_pipe = Pipeline([partial(map_filename, str1='data', str2='label1'), 
                      partial(open_im, cls=TensorMask)])

mask = mask_pipe(items_mask[idx])
print(img.shape, mask.shape)

_, ax = plt.subplots(1, 2, figsize=(12,5))
ax[0].imshow(img.permute(1, 2, 0)[..., :1]/20000)
mask.show(ctx=ax[1])
plt.show()

def show_img(tensor_img, ctx=None):
    ctx = plt.subplot() if ctx is None else ctx
    
    #normalize to fit between 0 and 1
    if tensor_img.max() > 0:
        tensor_img = tensor_img / tensor_img.max()
    
    ctx.imshow(tensor_img.permute(1, 2, 0)[..., :1])
    
# To create this DataBlock we don't need to specify the get_items function 
# because we will pass the list of files as the source
db = DataBlock(blocks=(TransformBlock([open_im, lambda x: x/10000]), 
                       TransformBlock([partial(map_filename, str1='data', str2='label1'), 
                      partial(open_im, cls=TensorMask)])),
               splitter=RandomSplitter(valid_pct=0.2, seed=0)
              )
db.summary(source=items)
  
ds = db.datasets(source=items)
dl = db.dataloaders(source=items, bs=1)
batch = dl.one_batch()
print(batch[0].shape, batch[1].shape)

import albumentations as A

import pdb
class SegmentationAlbumentationsTransform(ItemTransform):
#     split_idx=0
    def __init__(self, aug, **kwargs): 
        super().__init__(**kwargs)
        self.aug = aug
        
    def encodes(self, x):
            img,mask = x
            img = img/img.max()
            aug = self.aug(image=np.array(img.permute(1,2,0)), mask=np.array(mask))

            return TensorImage(aug['image'].transpose(2,0,1)), TensorMask(aug['mask'])
        

aug_pipe = A.Compose([
                      A.ShiftScaleRotate(p=.9),
                      A.HorizontalFlip(),
                      A.RandomCrop(384, 384),
                      A.Rotate(limit=(-90, 90)),
                     # A.RandomBrightnessContrast(contrast_limit=0.0, p=1., brightness_by_max=False)
                    ])

# Create our class with this aug_pipe
aug = SegmentationAlbumentationsTransform(aug_pipe)

# And check the results
idx = 5
aug_number = 4

# Display original and some augmented samples
_, ax = plt.subplots(aug_number+1, 2, figsize=(8,aug_number*4))

show_img(ds[idx][0], ctx=ax[0,0])
ds[idx][1].show(ctx=ax[0,1])

# print(ds[idx][0])

for i in range(1, aug_number+1):
    img, mask = aug.encodes(ds[idx])
    show_img(img, ctx=ax[i,0])
    mask.show(ctx=ax[i,1])



db = DataBlock(blocks=(TransformBlock([open_im]), 
                       TransformBlock([partial(map_filename, str1='data', str2='label1'), 
                      partial(open_im, cls=TensorMask)])),
               splitter=RandomSplitter(valid_pct=0.2),
               item_tfms=aug,
              )


dl = db.dataloaders(items, bs=1)

idx=3
img, mask = dl.do_item(idx)

fig, (ax1, ax2) = plt.subplots(1,2)
tensor_img=img

ax1.imshow(tensor_img.permute(1, 2, 0)[..., :1])

ax2.imshow(mask)
#axs[1].imshow(mask)
plt.show()

print(np.shape(img))
print(np.shape(mask))

def acc_metric(input, target):
    target = target.squeeze(1)
    return (input.argmax(dim=1)==target).float().mean()

def loss_fn(pred, targ):
    targ[targ==255] = 1
    return torch.nn.functional.cross_entropy(pred, targ.squeeze(1).type(torch.long))

db = DataBlock(blocks=(TransformBlock([open_im]), 
                       TransformBlock([partial(map_filename, str1='data', str2='label1'), 
                      partial(open_im, cls=TensorMask)])),
               splitter=RandomSplitter(valid_pct=0.2),
               item_tfms=aug,
              )

dl = db.dataloaders(items, bs=1)

learn = unet_learner(dls = dl, arch = resnet18, pretrained = True,normalize = False,n_in=2, n_out=2,  loss_func=loss_fn, metrics=acc_metric)
learn.lr_find()
learn.fit_one_cycle(20,lr_max=6e-5, wd=0.8)

learn.fine_tune(8)

learn.export()

img,mask=dl.do_item(3)

cat, tensor, probs=learn.predict(items[1], masks[1])

cat, tensor, probs=learn.predict(img)