Error when calling transform on SegmentationLabelList

I am trying to load data for a segmentation task, using the very classic following function:

class MultiClassSegList(SegmentationLabelList):
    def open(self, id_rles):
        image_id, rles = id_rles[0], id_rles[1:]
        shape = open_image(self.path/image_id).shape[-2:]       
        final_mask = torch.zeros((1, *shape)).long()
        for k, rle in enumerate(rles):
            if isinstance(rle, str):
                mask = open_mask_rle(rle, shape).data.permute(0, 2, 1)
                final_mask += (k+1)*mask
        return ImageSegment(final_mask)

def load_data(path, csv, bs=32, size=(128, 800)):
    train_list = (SegmentationItemList.
                  from_csv(path, csv).
                  split_by_rand_pct(valid_pct=0.2).
                  label_from_df(cols=list(range(5)), label_cls=MultiClassSegList, classes=[0, 1, 2, 3, 4]).
                  transform(get_transforms(), size=size, tfm_y=True).
                  databunch(bs=bs, num_workers=0).
                  normalize(imagenet_stats))
    return train_list

I keep getting the following error:

RuntimeError                              Traceback (most recent call last)
/opt/conda/lib/python3.6/site-packages/fastai/data_block.py in _check_kwargs(ds, tfms, **kwargs)
    590         x = ds[0]
--> 591         try: x.apply_tfms(tfms, **kwargs)
    592         except Exception as e:

/opt/conda/lib/python3.6/site-packages/fastai/vision/image.py in apply_tfms(self, tfms, do_resolve, xtra, size, resize_method, mult, padding_mode, mode, remove_out)
    121                 if resize_method in (ResizeMethod.CROP,ResizeMethod.PAD):
--> 122                     x = tfm(x, size=_get_crop_target(size,mult=mult), padding_mode=padding_mode)
    123             else: x = tfm(x)

/opt/conda/lib/python3.6/site-packages/fastai/vision/image.py in __call__(self, x, *args, **kwargs)
    517         "Randomly execute our tfm on `x`."
--> 518         return self.tfm(x, *args, **{**self.resolved, **kwargs}) if self.do_run else x
    519 

/opt/conda/lib/python3.6/site-packages/fastai/vision/image.py in __call__(self, p, is_random, use_on_y, *args, **kwargs)
    463         "Calc now if `args` passed; else create a transform called prob `p` if `random`."
--> 464         if args: return self.calc(*args, **kwargs)
    465         else: return RandTransform(self, kwargs=kwargs, is_random=is_random, use_on_y=use_on_y, p=p)

/opt/conda/lib/python3.6/site-packages/fastai/vision/image.py in calc(self, x, *args, **kwargs)
    468         "Apply to image `x`, wrapping it if necessary."
--> 469         if self._wrap: return getattr(x, self._wrap)(self.func, *args, **kwargs)
    470         else:          return self.func(x, *args, **kwargs)

/opt/conda/lib/python3.6/site-packages/fastai/vision/image.py in pixel(self, func, *args, **kwargs)
    171         "Equivalent to `image.px = func(image.px)`."
--> 172         self.px = func(self.px, *args, **kwargs)
    173         return self

/opt/conda/lib/python3.6/site-packages/fastai/vision/image.py in px(self)
    144         "Get the tensor pixel buffer."
--> 145         self.refresh()
    146         return self._px

/opt/conda/lib/python3.6/site-packages/fastai/vision/image.py in refresh(self)
    229         self.sample_kwargs['mode'] = 'nearest'
--> 230         return super().refresh()
    231 

/opt/conda/lib/python3.6/site-packages/fastai/vision/image.py in refresh(self)
    131         if self._affine_mat is not None or self._flow is not None:
--> 132             self._px = _grid_sample(self._px, self.flow, **self.sample_kwargs)
    133             self.sample_kwargs = {}

/opt/conda/lib/python3.6/site-packages/fastai/vision/image.py in _grid_sample(x, coords, mode, padding_mode, remove_out)
    534         if d>1 and d>z: x = F.interpolate(x[None], scale_factor=1/d, mode='area')[0]
--> 535     return F.grid_sample(x[None], coords, mode=mode, padding_mode=padding_mode)[0]
    536 

/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py in grid_sample(input, grid, mode, padding_mode)
   2716 
-> 2717     return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum)
   2718 

RuntimeError: grid_sampler(): expected input and grid to have same dtype, but input has long and grid has float

During handling of the above exception, another exception occurred:

Exception                                 Traceback (most recent call last)
<ipython-input-39-e7670957d52d> in <module>
----> 1 db = load_data(TRAIN_PATH, LABELS, bs=BATCH_SIZE, size=TRAIN_SIZE)

<ipython-input-36-15e1d1ba275a> in load_data(path, csv, bs, size)
      4                   split_by_rand_pct(valid_pct=0.2).
      5                   label_from_df(cols=list(range(5)), label_cls=MultiClassSegList, classes=[0, 1, 2, 3, 4]).
----> 6                   transform(get_transforms(), size=size, tfm_y=True).
      7                   databunch(bs=bs, num_workers=0).
      8                   normalize(imagenet_stats))

/opt/conda/lib/python3.6/site-packages/fastai/data_block.py in transform(self, tfms, **kwargs)
    500         if not tfms: tfms=(None,None)
    501         assert is_listy(tfms) and len(tfms) == 2, "Please pass a list of two lists of transforms (train and valid)."
--> 502         self.train.transform(tfms[0], **kwargs)
    503         self.valid.transform(tfms[1], **kwargs)
    504         if self.test: self.test.transform(tfms[1], **kwargs)

/opt/conda/lib/python3.6/site-packages/fastai/data_block.py in transform(self, tfms, tfm_y, **kwargs)
    722         if tfm_y is None: tfm_y = self.tfm_y
    723         tfms_y = None if tfms is None else list(filter(lambda t: getattr(t, 'use_on_y', True), listify(tfms)))
--> 724         if tfm_y: _check_kwargs(self.y, tfms_y, **kwargs)
    725         self.tfms,self.tfmargs  = tfms,kwargs
    726         self.tfm_y,self.tfms_y,self.tfmargs_y = tfm_y,tfms_y,kwargs

/opt/conda/lib/python3.6/site-packages/fastai/data_block.py in _check_kwargs(ds, tfms, **kwargs)
    591         try: x.apply_tfms(tfms, **kwargs)
    592         except Exception as e:
--> 593             raise Exception(f"It's not possible to apply those transforms to your dataset:\n {e}")
    594 
    595 class LabelList(Dataset):

Exception: It's not possible to apply those transforms to your dataset:
 grid_sampler(): expected input and grid to have same dtype, but input has long and grid has float

I understand my labels are of type long and my inputs of type float, but I don’t understand how it suddenly becomes a problem, while it works perfectly with the camvid notebook from the course for instance. There is probably a problem with my custom class but I can’t see where it comes from.

Found it, even though ImageSegment.data returns a long tensor, it stores a float tensor in px, so I need to call ImageSegment on a float tensor instead of a long.

1 Like