Using only one part of the LabelList/ItemBase class to update model layers

I am currently working on a U-Net model for the segmentation of cells in microscopy images. To counter class imbalance and to emphasize cell boundaries, I computed a pixel-wise weight map for each image, which I pass into fastai. I therefore created a custom ItemBase class that stores labels and weights together:

class WeightedLabels(ItemBase):
    """
    Custom ItemBase to store and process labels and pixel-wise weights together.
    Also handles the target_size of the labels.
    """

    def __init__(self, lbl: Image, wgt: Image, target_size: Tuple = None):
        self.lbl, self.wgt = lbl, wgt
        # obj keeps the two Image objects, data keeps their raw tensors
        self.obj, self.data = (lbl, wgt), [lbl.data, wgt.data]

        self.target_size = target_size

    ...
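The weight maps themselves are not shown in the post. As a hedged illustration only: the original U-Net paper combines a class-balancing term with a term w0·exp(−(d1+d2)²/(2σ²)) built from the distances d1, d2 to the two nearest cells. The NumPy sketch below swaps that distance term for a much simpler constant bonus on pixels whose 4-neighbours carry a different label; `weight_map`, `class_balance_weights`, `boundary_pixels` and `boundary_weight=5.0` are hypothetical names and values, not the author's code.

```python
import numpy as np

def class_balance_weights(mask: np.ndarray) -> np.ndarray:
    """Weight each pixel by the inverse frequency of its class, normalised to mean 1."""
    classes, counts = np.unique(mask, return_counts=True)
    w = np.zeros(mask.shape, dtype=np.float64)
    for c, n in zip(classes, counts):
        w[mask == c] = mask.size / n  # inverse class frequency
    return w / w.mean()

def boundary_pixels(mask: np.ndarray) -> np.ndarray:
    """True where at least one 4-neighbour carries a different label."""
    p = np.pad(mask, 1, mode="edge")  # edge padding: the image border is not a boundary
    centre = p[1:-1, 1:-1]
    return ((p[:-2, 1:-1] != centre) | (p[2:, 1:-1] != centre) |
            (p[1:-1, :-2] != centre) | (p[1:-1, 2:] != centre))

def weight_map(mask: np.ndarray, boundary_weight: float = 5.0) -> np.ndarray:
    """Hypothetical weight map: class balance plus a constant bonus on boundaries
    (a simplified stand-in for the distance-based term of the U-Net paper)."""
    return class_balance_weights(mask) + boundary_weight * boundary_pixels(mask)
```

For a binary mask this cannot separate touching cells, which is exactly what the distance-based term in the paper targets, so treat it as a starting point only.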

I use extensive augmentation (elastic deformation, mirroring and rotation), applied identically to the input image, the labels and the weights. The loss is a custom cross-entropy loss that multiplies each pixel's loss by its weight and averages the result.
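The loss described above can be sketched as follows, assuming logits of shape (B, C, H, W), integer labels of shape (B, H, W) and a float weight map of the same shape; `weighted_pixel_ce` is a hypothetical name, and NumPy is used only so the arithmetic can be checked without a GPU.

```python
import numpy as np

def log_softmax(logits: np.ndarray, axis: int = 1) -> np.ndarray:
    """Numerically stable log-softmax along the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def weighted_pixel_ce(logits: np.ndarray, labels: np.ndarray,
                      weights: np.ndarray) -> float:
    """Pixel-wise weighted cross-entropy, averaged over all pixels.

    logits: (B, C, H, W) floats, labels: (B, H, W) ints, weights: (B, H, W) floats.
    """
    logp = log_softmax(logits, axis=1)
    # log-probability of the true class at every pixel -> (B, H, W)
    picked = np.take_along_axis(logp, labels[:, None, :, :], axis=1)[:, 0]
    per_pixel = -picked
    return float((weights * per_pixel).mean())
```

In PyTorch the same quantity is `(weights * F.cross_entropy(logits, labels, reduction='none')).mean()`. Dividing by `weights.sum()` instead of taking the plain mean is a common alternative that keeps the loss scale independent of the overall weight magnitude.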

My problem is that performance is poor: my IoU never exceeds 57%, and I suspect this is because fastai is trying to predict the weights as well. My questions are:

  1. Am I right to assume my model tries to predict both?
  2. If so, how do I tell the learner which part of the target to use for updating the layers, and to predict only the labels, while still applying augmentation to both?

This may also help: when I test my model, I always get back two identical tensors (probably one for the weights and one for the labels). Thank you in advance, guys (and girls)!

I was asked to post more code, so here are CustomSegmentationLabelList and CustomSegmentationItemList. Please tell me what else you’d like to see:

class CustomSegmentationLabelList(ImageList):
    "'ItemList' suitable for WeightedLabels containing labels and pixel weights"
    _processor = vision.data.SegmentationProcessor

    def __init__(self,
                 items: Iterator,
                 wghts=None,
                 classes: Collection = None,
                 target_size: Tuple = None,
                 loss_func=CrossEntropyFlat(axis=1),
                 **kwargs):

        super().__init__(items, **kwargs)
        self.copy_new.append('classes')
        self.copy_new.append('wghts')
        self.classes, self.loss_func, self.wghts = classes, loss_func, wghts
        self.target_size = target_size

    def open(self, fn):
        # read the file with scikit-image and convert it to a float tensor
        res = io.imread(fn)
        res = pil2tensor(res, np.float32)
        return Image(res)

    def get(self, i):
        lbl = super().get(i)  # label image, opened through self.open
        wt = self.wghts[i]    # path of the matching weight map
        return WeightedLabels(lbl, self.open(wt), self.target_size)

    def reconstruct(self, t: Tensor):
        return WeightedLabels(Image(t[0]), Image(t[1]), self.target_size)

class CustomSegmentationItemList(ImageList):
    "'ItemList' suitable for segmentation with pixel-wise weighted loss"
    _label_cls, _square_show_res = CustomSegmentationLabelList, False

    def label_from_funcs(self, get_labels: Callable, get_weights: Callable,
                         label_cls: Callable = None, classes=None,
                         target_size: Tuple = None, **kwargs) -> 'LabelList':
        "Get weights and labels from two functions. Saves them in a CustomSegmentationLabelList"
        kwargs = {}  # note: discards any extra keyword arguments passed in
        wghts = [get_weights(o) for o in self.items]
        labels = [get_labels(o) for o in self.items]

        if target_size:
            print(f"Masks will be cropped to {target_size}. "
                  "Choose 'target_size=None' to keep the initial size.")
        else:
            print('Masks will not be cropped.')

        y = CustomSegmentationLabelList(
            labels, wghts, classes, target_size, path=self.path)
        res = self._label_list(x=self, y=y)
        return res
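`label_from_funcs` expects `get_labels` and `get_weights`, which are not shown. Since `get` opens each weight with `self.open(wt)`, both functions presumably return file paths. A minimal sketch, assuming a hypothetical layout in which labels and weight maps live in sibling folders (`data/labels`, `data/weights`) and share the image's file name:

```python
from pathlib import Path

# Hypothetical folder layout; adjust to wherever the masks and weight maps live.
LABEL_DIR = Path("data/labels")
WEIGHT_DIR = Path("data/weights")

def get_labels(img_path) -> Path:
    """Map an image path to the label file with the same name."""
    return LABEL_DIR / Path(img_path).name

def get_weights(img_path) -> Path:
    """Map an image path to the precomputed weight map with the same name."""
    return WEIGHT_DIR / Path(img_path).name
```

Because both list comprehensions in `label_from_funcs` iterate over `self.items` in the same order, `labels[i]` and `wghts[i]` stay paired with image `i`.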

And here is the part where I create my DataBunch object:

data = (CustomSegmentationItemList.from_df(img_df,IMG_PATH, convert_mode=IMAGE_TYPE)
  .split_by_rand_pct(valid_pct=(1/N_SPLITS),seed=SEED)
  .label_from_funcs(get_labels, get_weights, target_size=MASK_SHAPE, classes = array(['background','cell']))
  .transform(tfms=tfms, tfm_y=True)
  .databunch(bs=BATCH_SIZE))
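With `tfm_y=True` the target must receive the same geometric transforms as the image, and for a two-part target that means the label and the weight map have to share one random draw. The NumPy sketch below shows the idea for flips and 90° rotations only (elastic deformation follows the same rule: sample the displacement field once and apply it to all three arrays); `joint_flip_rot90` is a hypothetical helper, not part of fastai.

```python
import numpy as np

def joint_flip_rot90(img, lbl, wgt, rng):
    """Apply one random rotation/mirror, drawn once, to image, label and weights."""
    k = int(rng.integers(4))      # number of 90-degree rotations
    flip = bool(rng.integers(2))  # mirror along the last axis or not
    out = []
    for a in (img, lbl, wgt):
        a = np.rot90(a, k, axes=(-2, -1))  # rotate the spatial (H, W) axes
        if flip:
            a = a[..., ::-1]
        out.append(np.ascontiguousarray(a))
    return tuple(out)
```

Drawing `k` and `flip` outside the loop is the whole point: if each array sampled its own parameters, the weights would no longer line up with the labels they are meant to emphasize.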

My model is a U-Net with input WxHx1 and output WxHxC; since this is a binary segmentation, C=2.
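With C=2 output channels, a hard mask comes from an argmax over the channel axis, and an IoU like the one quoted above compares that mask with the labels only; the weight map plays no part in it. A NumPy sketch, assuming channel-first logits (B, C, H, W) as PyTorch produces them and class 1 = 'cell' (matching `classes=['background','cell']`); `binary_iou` is a hypothetical name:

```python
import numpy as np

def binary_iou(logits: np.ndarray, labels: np.ndarray, cell_class: int = 1) -> float:
    """IoU of the 'cell' class. logits: (B, C, H, W), labels: (B, H, W)."""
    pred = logits.argmax(axis=1)  # (B, H, W) predicted class per pixel
    p, t = pred == cell_class, labels == cell_class
    union = np.logical_or(p, t).sum()
    if union == 0:  # no cell predicted and none present
        return 1.0
    return float(np.logical_and(p, t).sum() / union)
```

This averages over one batch; a dataset-level IoU would accumulate intersection and union across all batches before dividing.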