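Copy the source of fastai's SegmentationLabelList into your own custom SegmentationLL class and change just one thing: build `create_func` as `partial(open_mask, div=True)` instead of plain `open_mask`.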
class SegmentationLL(ImageItemList):
    """Custom copy of SegmentationLabelList whose masks are opened with ``div=True``.

    The only change from the copied class is that ``create_func`` wraps
    ``open_mask`` with ``div=True``.
    """
    def __init__(self, items:Iterator, classes:Collection=None, **kwargs):
        super().__init__(items, **kwargs)
        self.classes = classes
        self.loss_func = CrossEntropyFlat()
        # Open every mask through open_mask with div=True.
        self.create_func = partial(open_mask, div=True)
        # Number of classes; len() raises TypeError if classes was left as None.
        self.c = len(self.classes)

    def new(self, items, classes=None, **kwargs):
        """Create a list of the same type, reusing ``self.classes`` unless ``classes`` is given."""
        chosen = ifnone(classes, self.classes)
        return type(self)(items, chosen, **kwargs)
Or, even shorter, inherit everything else and override only `create_func`:
class SegmentationLL(SegmentationLabelList):
    """SegmentationLabelList variant whose masks are opened with ``div=True``.

    Inherits ``classes``, ``c``, ``loss_func`` and ``new`` from
    SegmentationLabelList and overrides only ``create_func``.
    """
    # NOTE: the base must be SegmentationLabelList (the label list that accepts
    # `classes`), not SegmentationItemList (the input item list). With the latter,
    # `classes` is never stored and `c` / `loss_func` are never set.
    def __init__(self, items:Iterator, classes:Collection=None, **kwargs):
        super().__init__(items, classes, **kwargs)
        # Replace the parent's mask opener with one that passes div=True.
        self.create_func = partial(open_mask, div=True)