How do I use ItemBase correctly?

So here I am asking the same question again.

I am currently working on a U-Net model for the segmentation of cells in microscopy images. To counter the class imbalance and to emphasize the cell boundaries, I computed a pixelwise weight map for each image, which I pass into fastai. To store labels and weights together, I created a new ItemBase class:

    class WeightedLabels(ItemBase):
        """
        Custom ItemBase to store and process labels and pixelwise weights together.
        Also handles the target_size of the labels.
        """

        def __init__(self, lbl: Image, wgt: Image, target_size: Tuple = None):
            self.lbl, self.wgt = lbl, wgt
            self.obj, self.data = (lbl, wgt), [lbl.data, wgt.data]
            self.target_size = target_size

        ...
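To make sure I understand what happens downstream, I checked how PyTorch's default collation treats a list-valued target like the `data` attribute above. A minimal sketch in plain PyTorch, independent of fastai (the toy dataset is made up just to show the batching behavior):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyWeightedDataset(Dataset):
    # Toy stand-in for my pipeline: each target is [label, weight],
    # mirroring the `data` attribute of WeightedLabels above.
    def __len__(self):
        return 4

    def __getitem__(self, i):
        x = torch.zeros(1, 8, 8)        # dummy input image
        lbl = torch.zeros(8, 8).long()  # dummy label mask
        wgt = torch.ones(8, 8)          # dummy pixel weight map
        return x, [lbl, wgt]

dl = DataLoader(ToyWeightedDataset(), batch_size=2)
x, y = next(iter(dl))
print(type(y), len(y))         # <class 'list'> 2
print(y[0].shape, y[1].shape)  # torch.Size([2, 8, 8]) torch.Size([2, 8, 8])
```

So the target arrives at the loss function as a list of two batched tensors, not as one tensor.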

I also defined a custom ItemList and LabelList class as follows:

    class CustomSegmentationLabelList(ImageList):
        "'ItemList' suitable for WeightedLabels, containing labels and pixelwise weights"
        _processor = vision.data.SegmentationProcessor

        def __init__(self,
                     items: Iterator,
                     wghts=None,
                     classes: Collection = None,
                     target_size: Tuple = None,
                     loss_func=CrossEntropyFlat(axis=1),
                     **kwargs):
            super().__init__(items, **kwargs)
            self.copy_new.append('classes')
            self.copy_new.append('wghts')
            self.classes, self.loss_func, self.wghts = classes, loss_func, wghts
            self.target_size = target_size

        def open(self, fn):
            res = io.imread(fn)
            res = pil2tensor(res, np.float32)
            return Image(res)

        def get(self, i):
            fn = super().get(i)
            wt = self.wghts[i]
            return WeightedLabels(fn, self.open(wt), self.target_size)

        def reconstruct(self, t: Tensor):
            return WeightedLabels(Image(t[0]), Image(t[1]), self.target_size)

    class CustomSegmentationItemList(ImageList):
        "'ItemList' suitable for segmentation with a pixelwise weighted loss"
        _label_cls, _square_show_res = CustomSegmentationLabelList, False

        def label_from_funcs(self, get_labels: Callable, get_weights: Callable,
                             label_cls: Callable = None, classes=None,
                             target_size: Tuple = None, **kwargs) -> 'LabelList':
            "Get weights and labels from two functions and store them in a CustomSegmentationLabelList"
            wghts = [get_weights(o) for o in self.items]
            labels = [get_labels(o) for o in self.items]

            if target_size:
                print(f'Masks will be cropped to {target_size}. Choose target_size=None to keep the initial size.')
            else:
                print('Masks will not be cropped.')

            y = CustomSegmentationLabelList(
                labels, wghts, classes, target_size, path=self.path)
            res = self._label_list(x=self, y=y)
            return res

Now I have the problem that learn.predict() always outputs two WxHxC tensors instead of just one. My torch model takes a WxHx1 image as input and outputs a single WxHxC tensor, so something must be wrong with my implementation. Judging from the behavior of learn.predict(), I also worry that fit_one_cycle() trains on the labels AND the weights. My questions regarding this are:

  1. Am I right to assume that my model trains on both the labels and the weights?
  2. How can I fix this so that my model trains only on the labels and uses the weights only inside my loss function?
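My current idea for question 2 would be a loss function that handles the weighting itself. A minimal sketch of what I have in mind, assuming the labels and weights arrive at the loss as separate tensors:

```python
import torch
import torch.nn.functional as F

def weighted_ce_loss(pred, lbl, wgt):
    # Sketch of a pixelwise weighted cross-entropy:
    # pred: BxCxHxW logits, lbl: BxHxW (long), wgt: BxHxW (float).
    per_pixel = F.cross_entropy(pred, lbl, reduction='none')  # BxHxW
    return (per_pixel * wgt).mean()

# With all weights equal to 1 this reduces to the ordinary mean cross-entropy.
pred = torch.randn(2, 2, 4, 4)
lbl = torch.randint(0, 2, (2, 4, 4))
wgt = torch.ones(2, 4, 4)
loss = weighted_ce_loss(pred, lbl, wgt)
```

But I do not know where to unpack my combined target so that only the labels reach the model's training target while the weights reach this loss.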

Please help, even if you regard this as a stupid question. Lastly, here is how I create my DataBunch instance. For every other important class I used the predefined ones from fastai. Let me know if you need any additional code!

    data = (CustomSegmentationItemList.from_df(img_df, IMG_PATH, convert_mode=IMAGE_TYPE)
            .split_by_rand_pct(valid_pct=(1 / N_SPLITS), seed=SEED)
            .label_from_funcs(get_labels, get_weights, target_size=MASK_SHAPE,
                              classes=array(['background', 'cell']))
            .transform(tfms=tfms, tfm_y=True)
            .databunch(bs=BATCH_SIZE))
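For reference, get_labels and get_weights simply map an image path to the corresponding mask and weight files. Simplified, they look roughly like this (the directory layout below is made up for illustration):

```python
from pathlib import Path

# Assumed layout (for illustration only):
#   data/images/img1.png, data/labels/img1.png, data/weights/img1.png
def get_labels(fn):
    fn = Path(fn)
    return fn.parent.parent / 'labels' / fn.name

def get_weights(fn):
    fn = Path(fn)
    return fn.parent.parent / 'weights' / fn.name
```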

This is my second attempt at getting an answer to this question; it should not be a complicated one for the developers of this framework, so any pointers would be much appreciated.