I’m trying to implement a Super Resolution Classifier. Meaning Do both super-resolution and classification using the embeddings inside DynamicUnet. Creating the model and loss was straight forward for me since that is very similar to Pytorch. However, I’m having a lot of trouble creating the dataset to feed it.
All I really need is an ImageImageList but with two output labels as opposed to just the image output label. I read creating a custom ItemList tutorial and check SegmentationLabelList for reference but still haven’t quite got it to work as desired. Below is my code
class ImageLabel:
def __init__(self, image, label):
self.image, self.label = image, label
self.obj, self.data = (image, label), [image.data, label.data]
def apply_tfms(self, tfms, **kwargs):
self.image = self.image.apply_tfms(tfms, **kwargs)
self.data = [self.image.data, self.label.data]
return self
def show(self, ax:plt.Axes, **kwargs):
"Subclass this method if you want to customize the way this `ItemBase` is shown on `ax`."
ax.set_title(str(self))
# def __repr__(self)->str: return f'{self.__class__.__name__} {str(self)}'
def __eq__(self, other): return recurse_eq(self.data, other.data)
class ImageLabelList(ImageList):
def __init__(self, items:Iterator, labels2:Collection=None, classes:Collection=None, **kwargs):
super().__init__(items, **kwargs)
self.copy_new.append('labels2')
self.copy_new.append('classes')
self.labels = labels2
self.classes = classes
def get(self, i):
image = super().get(i)
o = self.labels[i]
label = Category(o, self.classes[o])
return ImageLabel(image, label)
def analyze_pred(self, pred, thresh: float=0.5): return pred.argmax()
def reconstruct(self, t: Tensor):
return ImageLabel(Image(t[0]), Category(t[1], self.classes[t[1]]))
class ImageImageLabelList(ImageList):
"`ItemList` suitable for `Image` to `Image` tasks."
_label_cls,_square_show,_square_show_res = ImageLabelList,False,False
def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
"Show the `xs` (inputs) and `ys`(targets) on a figure of `figsize`."
axs = subplots(len(xs), 2, imgsize=imgsize, figsize=figsize)
for i, (x,y) in enumerate(zip(xs,ys[:, 0])):
x.show(ax=axs[i,0], **kwargs)
y.show(ax=axs[i,1], **kwargs)
plt.tight_layout()
def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
"Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`."
title = 'Input / Prediction / Target'
axs = subplots(len(xs), 3, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=14)
for i,(x,y,z) in enumerate(zip(xs,ys[:, 0],zs[:, 0])):
x.show(ax=axs[i,0], **kwargs)
y.show(ax=axs[i,2], **kwargs)
z.show(ax=axs[i,1], **kwargs)
This is how I’m running it as a dummy example:
df = pd.read_csv(Path('<path>')/'<csv_file>', header='infer')
classes = df.iloc[:,df_names_to_idx(1, df)].values.squeeze().tolist()
labels = {name: idx if name !=-100 else _id for idx, name in enumerate(sorted(set(classes)))}
labels = [labels[x] for x in classes]
src2 = (ImageImageLabelList.from_folder(path_lr)
.split_by_rand_pct(0.1, seed=42)
.label_from_func(lambda x: path_hr/x.name, labels2=labels, classes=classes)
.transform(get_transforms(max_zoom=2.), size=64, tfm_y=True)
.databunch(bs=8).normalize(imagenet_stats, do_y=True))
src2.show_batch(ds_type=DatasetType.Valid, rows=2, figsize=(9,9))
Right now I’m getting collate errors and had to comment def __repr__(self)->str: return f'{self.__class__.__name__} {str(self)}'
since it was leading to infinite recursion error.
What am I doing wrong? Also, I would love to be able to have a more straightforward integration of Pytorch dataset and Dataloader(dataset)
. I love data block API it’s amazing but modifying it or creating a new Dataset seem like a pain point right now.