ImageImageLabelList Super Resolution Classifier

Hello @jeremy and @sgugger

I’m trying to implement a Super Resolution Classifier. Meaning Do both super-resolution and classification using the embeddings inside DynamicUnet. Creating the model and loss was straight forward for me since that is very similar to Pytorch. However, I’m having a lot of trouble creating the dataset to feed it.

All I really need is an ImageImageList but with two output labels as opposed to just the image output label. I read creating a custom ItemList tutorial and check SegmentationLabelList for reference but still haven’t quite got it to work as desired. Below is my code

class ImageLabel:
    def __init__(self, image, label):
        self.image, self.label = image, label
        self.obj, self.data = (image, label), [image.data, label.data]

    def apply_tfms(self, tfms, **kwargs):
        self.image = self.image.apply_tfms(tfms, **kwargs)
        self.data = [self.image.data, self.label.data]
        return self

    def show(self, ax:plt.Axes, **kwargs):
        "Subclass this method if you want to customize the way this `ItemBase` is shown on `ax`."
        ax.set_title(str(self))
    # def __repr__(self)->str: return f'{self.__class__.__name__} {str(self)}'
    def __eq__(self, other): return recurse_eq(self.data, other.data)

class ImageLabelList(ImageList):
    def __init__(self, items:Iterator, labels2:Collection=None, classes:Collection=None, **kwargs):
        super().__init__(items, **kwargs)
        self.copy_new.append('labels2')
        self.copy_new.append('classes')
        self.labels = labels2
        self.classes = classes

    def get(self, i):
        image = super().get(i)
        o = self.labels[i]
        label = Category(o, self.classes[o])
        return ImageLabel(image, label)

    def analyze_pred(self, pred, thresh: float=0.5): return pred.argmax()

    def reconstruct(self, t: Tensor):
        return ImageLabel(Image(t[0]), Category(t[1], self.classes[t[1]]))


class ImageImageLabelList(ImageList):
    "`ItemList` suitable for `Image` to `Image` tasks."
    _label_cls,_square_show,_square_show_res = ImageLabelList,False,False

    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show the `xs` (inputs) and `ys`(targets)  on a figure of `figsize`."
        axs = subplots(len(xs), 2, imgsize=imgsize, figsize=figsize)
        for i, (x,y) in enumerate(zip(xs,ys[:, 0])):
            x.show(ax=axs[i,0], **kwargs)
            y.show(ax=axs[i,1], **kwargs)
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`."
        title = 'Input / Prediction / Target'
        axs = subplots(len(xs), 3, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=14)
        for i,(x,y,z) in enumerate(zip(xs,ys[:, 0],zs[:, 0])):
            x.show(ax=axs[i,0], **kwargs)
            y.show(ax=axs[i,2], **kwargs)
            z.show(ax=axs[i,1], **kwargs)

This is how I’m running it as a dummy example:

df = pd.read_csv(Path('<path>')/'<csv_file>', header='infer')
classes = df.iloc[:,df_names_to_idx(1, df)].values.squeeze().tolist()
labels = {name: idx if name !=-100 else _id  for idx, name in enumerate(sorted(set(classes)))}
labels = [labels[x] for x in classes]
src2 = (ImageImageLabelList.from_folder(path_lr)
        .split_by_rand_pct(0.1, seed=42)
        .label_from_func(lambda x: path_hr/x.name, labels2=labels, classes=classes)
        .transform(get_transforms(max_zoom=2.), size=64, tfm_y=True)
        .databunch(bs=8).normalize(imagenet_stats, do_y=True))
src2.show_batch(ds_type=DatasetType.Valid, rows=2, figsize=(9,9))

Right now I’m getting collate errors and had to comment def __repr__(self)->str: return f'{self.__class__.__name__} {str(self)}' since it was leading to infinite recursion error.

What am I doing wrong? Also, I would love to be able to have a more straightforward integration of Pytorch dataset and Dataloader(dataset). I love data block API it’s amazing but modifying it or creating a new Dataset seem like a pain point right now.

@avn3r Hey, did you ever happen to get this working?