I have trained a model for 4channel input to do segmentation.
learn.show_results() shows the predictions just fine, but when I try to do inference on a arbitrary picture it failes.
I suspect the problem is related to Albumenations augmentations, because I have previously trained the same dataset using only 3channels input using fastai augmentations. And that worked just fine.
I have been stuck for a week now, and begin to feel desparate.
Here is the callstack from when I try to predict image
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-41-e4ac01516bbf> in <module>
----> 1 prediction = learn.predict(img)
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastai\learner.py in predict(self, item, rm_type_tfms, with_input)
246 def predict(self, item, rm_type_tfms=None, with_input=False):
247 dl = self.dls.test_dl([item], rm_type_tfms=rm_type_tfms, num_workers=0)
--> 248 inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
249 i = getattr(self.dls, 'n_inp', -1)
250 inp = (inp,) if i==1 else tuplify(inp)
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastai\learner.py in get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)
233 if with_loss: ctx_mgrs.append(self.loss_not_reduced())
234 with ContextManagers(ctx_mgrs):
--> 235 self._do_epoch_validate(dl=dl)
236 if act is None: act = getattr(self.loss_func, 'activation', noop)
237 res = cb.all_tensors()
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastai\learner.py in _do_epoch_validate(self, ds_idx, dl)
186 if dl is None: dl = self.dls[ds_idx]
187 self.dl = dl
--> 188 with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
189
190 def _do_epoch(self):
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastai\learner.py in _with_events(self, f, event_type, ex, final)
153
154 def _with_events(self, f, event_type, ex, final=noop):
--> 155 try: self(f'before_{event_type}') ;f()
156 except ex: self(f'after_cancel_{event_type}')
157 finally: self(f'after_{event_type}') ;final()
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastai\learner.py in all_batches(self)
159 def all_batches(self):
160 self.n_iter = len(self.dl)
--> 161 for o in enumerate(self.dl): self.one_batch(*o)
162
163 def _do_one_batch(self):
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastai\data\load.py in __iter__(self)
100 self.before_iter()
101 self.__idxs=self.get_idxs() # called in context of main process (not workers/subprocesses)
--> 102 for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
103 if self.device is not None: b = to_device(b, self.device)
104 yield self.after_batch(b)
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
361
362 def __next__(self):
--> 363 data = self._next_data()
364 self._num_yielded += 1
365 if self._dataset_kind == _DatasetKind.Iterable and \
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
401 def _next_data(self):
402 index = self._next_index() # may raise StopIteration
--> 403 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
404 if self._pin_memory:
405 data = _utils.pin_memory.pin_memory(data)
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
32 raise StopIteration
33 else:
---> 34 data = next(self.dataset_iter)
35 return self.collate_fn(data)
36
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastai\data\load.py in create_batches(self, samps)
109 self.it = iter(self.dataset) if self.dataset is not None else None
110 res = filter(lambda o:o is not None, map(self.do_item, samps))
--> 111 yield from map(self.do_batch, self.chunkify(res))
112
113 def new(self, dataset=None, cls=None, **kwargs):
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\basics.py in chunked(it, chunk_sz, drop_last, n_chunks)
196 if not isinstance(it, Iterator): it = iter(it)
197 while True:
--> 198 res = list(itertools.islice(it, chunk_sz))
199 if res and (len(res)==chunk_sz or not drop_last): yield res
200 if len(res)<chunk_sz: return
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastai\data\load.py in do_item(self, s)
122 def prebatched(self): return self.bs is None
123 def do_item(self, s):
--> 124 try: return self.after_item(self.create_item(s))
125 except SkipItemException: return None
126 def chunkify(self, b): return b if self.prebatched else chunked(b, self.bs, self.drop_last)
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\transform.py in __call__(self, o)
196 self.fs.append(t)
197
--> 198 def __call__(self, o): return compose_tfms(o, tfms=self.fs, split_idx=self.split_idx)
199 def __repr__(self): return f"Pipeline: {' -> '.join([f.name for f in self.fs if f.name != 'noop'])}"
200 def __getitem__(self,i): return self.fs[i]
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\transform.py in compose_tfms(x, tfms, is_enc, reverse, **kwargs)
148 for f in tfms:
149 if not is_enc: f = f.decode
--> 150 x = f(x, **kwargs)
151 return x
152
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\transform.py in __call__(self, x, **kwargs)
111 "A transform that always take tuples as items"
112 _retain = True
--> 113 def __call__(self, x, **kwargs): return self._call1(x, '__call__', **kwargs)
114 def decode(self, x, **kwargs): return self._call1(x, 'decode', **kwargs)
115 def _call1(self, x, name, **kwargs):
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\transform.py in _call1(self, x, name, **kwargs)
115 def _call1(self, x, name, **kwargs):
116 if not _is_tuple(x): return getattr(super(), name)(x, **kwargs)
--> 117 y = getattr(super(), name)(list(x), **kwargs)
118 if not self._retain: return y
119 if is_listy(y) and not isinstance(y, tuple): y = tuple(y)
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\transform.py in __call__(self, x, **kwargs)
71 @property
72 def name(self): return getattr(self, '_name', _get_name(self))
---> 73 def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)
74 def decode (self, x, **kwargs): return self._call('decodes', x, **kwargs)
75 def __repr__(self): return f'{self.name}:\nencodes: {self.encodes}decodes: {self.decodes}'
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\transform.py in _call(self, fn, x, split_idx, **kwargs)
81 def _call(self, fn, x, split_idx=None, **kwargs):
82 if split_idx!=self.split_idx and self.split_idx is not None: return x
---> 83 return self._do_call(getattr(self, fn), x, **kwargs)
84
85 def _do_call(self, f, x, **kwargs):
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\transform.py in _do_call(self, f, x, **kwargs)
87 if f is None: return x
88 ret = f.returns_none(x) if hasattr(f,'returns_none') else None
---> 89 return retain_type(f(x, **kwargs), x, ret)
90 res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
91 return retain_type(res, x)
c:\users\tosv\appdata\local\programs\python\python38\lib\site-packages\fastcore\dispatch.py in __call__(self, *args, **kwargs)
127 elif self.inst is not None: f = MethodType(f, self.inst)
128 elif self.owner is not None: f = MethodType(f, self.owner)
--> 129 return f(*args, **kwargs)
130
131 def __get__(self, inst, owner):
<ipython-input-16-5c3084d812ae> in encodes(self, x)
5
6 def encodes(self, x):
----> 7 img,mask = x
8
9 # for albumentations to work correctly, the channels must be at the last dimension
ValueError: not enough values to unpack (expected 2, got 1)
Here is the full code:
def read_cmv(fn):
f = open(fn,'rb')
cmv_type= int(np.fromfile(f, dtype='>i4', count=1 ))
cmv_payload_len = int(np.fromfile(f, dtype='>u4', count=1))
if cmv_type != 7 :
raise NotImplementedError(f'CMV type {cmv_type} is not implemented')
cmv_time = float(np.fromfile(f, dtype='>f8', count = 1))
cmv_size_x, cmv_size_y = np.fromfile(f, dtype='>u4', count = 2)
# CHW order
return torch.Tensor(np.fromfile(f, dtype = 'f4', count = cmv_size_x*cmv_size_y)).view(cmv_size_y,cmv_size_x).flip([0])
def normalise_cmv(tn):
min_val = -82.4148
max_val = 68.7552
return (tn - min_val)*255./(max_val-min_val)
def get_cmv_filename(imagePath):
return path.joinpath('height', f'{imagePath.stem}_height.cmv')
def open_rgb_cmv(fn):
im_rgb = TensorImage(PILImage.create(fn)).permute((2,0,1)) # HWC to CHW order
im_cmv = normalise_cmv(read_cmv(get_cmv_filename(fn)))
return TensorImage(torch.cat((im_rgb,torch.unsqueeze(im_cmv,0)),0))
import albumentations as A
class AlbumentationsTransform(ItemTransform):
split_idx = None
def __init__(self, aug): self.aug = aug
def encodes(self, x):
img,mask = x
# for albumentations to work correctly, the channels must be at the last dimension
aug = self.aug(image=np.array(img.permute(1,2,0)), mask=np.array(mask))
return TensorImage(aug['image'].transpose(2,0,1)), TensorMask(aug['mask'])
class SegmentationAlbumentationsTransform(ItemTransform):
split_idx = 0
def __init__(self, aug): self.aug = aug
def encodes(self, x):
img,mask = x
# for albumentations to work correctly, the channels must be at the last dimension
aug = self.aug(image=np.array(img.permute(1,2,0)), mask=np.array(mask))
return TensorImage(aug['image'].transpose(2,0,1)), TensorMask(aug['mask'])
def get_y(inPath):
return path.joinpath('labels', f'{inPath.stem}.png')
dataBlock = DataBlock(
blocks=(TransformBlock(open_rgb_cmv,batch_tfms=IntToFloatTensor ),MaskBlock(codes=np.loadtxt(path/'codes.txt', dtype=str))),
get_items=get_image_files,
splitter=RandomSplitter(seed=42, valid_pct=0.3),
get_y=get_y,
item_tfms=[AlbumentationsTransform(A.PadIfNeeded(min_height=511,min_width=511)), SegmentationAlbumentationsTransform(A.Compose([A.ShiftScaleRotate(p=0.9), A.HorizontalFlip(), A.RandomBrightnessContrast(contrast_limit=0.1, brightness_by_max=False)]))],
batch_tfms= Normalize()
)
dls = dataBlock.dataloaders(path/"images", bs=12, num_workers=0)
from fastai.callback.fp16 import *
learn = unet_learner(dls, resnet18, n_in=4,normalize=False,pretrained=False,
metrics=[JaccardCoeff()]).to_fp16()
learn.fit_one_cycle(10, 0.0002)
images = get_image_files(path/"images")
img = open_rgb_cmv(images[22]);
i2f_tfm = IntToFloatTensor()
i2f_tfm(img)
# Here is where it failes
prediction = learn.predict(img)