Unet_binary segmentation

Copy the source code of SegmentationLabelList into your own custom SegmentationLL class and change it slightly, so that masks are opened with div=True:

class SegmentationLL(ImageItemList):
    """Custom label list for segmentation masks opened with ``div=True``.

    Records the class collection, uses ``CrossEntropyFlat`` as the loss
    function, and creates items through ``open_mask`` with ``div=True``
    pre-bound.
    """

    def __init__(self, items:Iterator, classes:Collection=None, **kwargs):
        """Initialise the list from `items` and the collection of `classes`.

        Extra keyword arguments are forwarded unchanged to the parent class.
        """
        super().__init__(items, **kwargs)
        # Build both helpers before touching any attribute.
        loss = CrossEntropyFlat()
        mask_opener = partial(open_mask, div=True)
        self.classes = classes
        self.loss_func = loss
        self.create_func = mask_opener
        # NOTE: `classes` must be sized; leaving it at the default None fails in len().
        self.c = len(self.classes)

    def new(self, items, classes=None, **kwargs):
        """Return a new instance of the same class built from `items`.

        `classes` is passed through `ifnone` together with `self.classes`,
        so the current classes are reused when none are supplied.
        """
        chosen_classes = ifnone(classes, self.classes)
        return self.__class__(items, chosen_classes, **kwargs)

Or, even shorter, subclass the existing label list and override only create_func:

class SegmentationLL(SegmentationLabelList):
    """Segmentation label list whose masks are opened with ``div=True``.

    Derives from ``SegmentationLabelList`` (the label-side list whose
    constructor accepts ``classes``), not from ``SegmentationItemList``:
    the constructor below forwards ``classes`` positionally to the parent,
    which only matches the label list's ``(items, classes, **kwargs)``
    signature. Everything except ``create_func`` is left to the parent.
    """

    def __init__(self, items:Iterator, classes:Collection=None, **kwargs):
        """Initialise via the parent, then replace the item-creation function.

        `items`, `classes` and any extra keyword arguments are forwarded to
        the parent constructor unchanged.
        """
        super().__init__(items, classes, **kwargs)
        # Override only how a mask is opened: `open_mask` with div=True bound.
        # NOTE(review): div=True presumably rescales 0/255 masks to 0/1 for
        # binary segmentation; confirm against open_mask's documentation.
        self.create_func = partial(open_mask, div=True)
3 Likes