@sgugger
This is as far as I got, but I’m sure @neuronq can take it from here.
class SegmentationTileItemList(SegmentationItemList):
    """Item list that serves each source image as several tiles (segments).

    Index ``i`` maps to source image ``i // segments_per_image`` and to tile
    ``i % segments_per_image`` within that image. Subclasses must implement
    :meth:`get_image_segment` to cut the tile out of the full image.

    NOTE(review): the list length is presumably still driven by ``self.items``
    (one entry per source image), so tile indices beyond ``len(self.items)``
    may never be requested unless the items are repeated or ``__len__`` is
    adjusted -- confirm against how this list is built and split.
    NOTE(review): ``segments_per_image`` is the first positional parameter, so
    any factory that passes ``items`` positionally would bind it here instead
    -- confirm callers pass ``segments_per_image`` explicitly.
    """

    def __init__(self, segments_per_image: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Number of tiles each source image is split into; used as the divisor
        # when mapping a flat index to (image, tile).
        self.segments_per_image = segments_per_image

    def get_image_segment(self, full_image, segment_idx: int):
        """Return tile ``segment_idx`` of ``full_image``.

        Must be overridden. The returned object needs a ``size`` attribute,
        because :meth:`get` records it in ``self.sizes``.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_image_segment()"
        )

    def get(self, i: int):
        """Open the source image behind index ``i`` and return its tile.

        Also records the tile's size in ``self.sizes[i]``.
        """
        # The original referenced an undefined name ``index`` here; ``i`` is
        # the flat tile index.
        image_idx, segment_idx = divmod(i, self.segments_per_image)
        # Read the raw item (e.g. a filename) directly. The parent
        # ImageList.get in fastai v1 already opens the file and returns an
        # Image, so passing its result to self.open would fail.
        fn = self.items[image_idx]
        full_image = self.open(fn)
        res = self.get_image_segment(full_image, segment_idx)
        self.sizes[i] = res.size
        return res
class SegmentationTileLabelList(SegmentationLabelList):
    """Label list that serves each source mask as several tiles (segments).

    Index ``i`` maps to source label ``i // segments_per_label`` and to tile
    ``i % segments_per_label`` within that label. Subclasses must implement
    :meth:`get_label_segment` to cut the tile out of the full label.

    NOTE(review): the tiling here must match the image list's tiling (same
    tile count and tile order), or images and labels will be misaligned --
    confirm the two classes are configured consistently.
    NOTE(review): ``segments_per_label`` is the first positional parameter, so
    any factory that passes ``items`` positionally would bind it here instead
    -- confirm callers pass ``segments_per_label`` explicitly.
    """

    def __init__(self, segments_per_label: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Number of tiles each source label is split into; used as the divisor
        # when mapping a flat index to (label, tile).
        self.segments_per_label = segments_per_label

    def get_label_segment(self, full_label, segment_idx: int):
        """Return tile ``segment_idx`` of ``full_label``.

        Must be overridden. The returned object needs a ``size`` attribute,
        because :meth:`get` records it in ``self.sizes``.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_label_segment()"
        )

    def get(self, i: int):
        """Open the source label behind index ``i`` and return its tile.

        Also records the tile's size in ``self.sizes[i]``.
        """
        # The original referenced an undefined name ``index`` here; ``i`` is
        # the flat tile index.
        label_idx, segment_idx = divmod(i, self.segments_per_label)
        # Read the raw item (e.g. a filename) directly. The inherited
        # ImageList.get in fastai v1 already opens the file and returns the
        # opened object, so passing its result to self.open would fail.
        fn = self.items[label_idx]
        full_label = self.open(fn)
        res = self.get_label_segment(full_label, segment_idx)
        self.sizes[i] = res.size
        return res