Here is how I did it:
# Reshape the CSV so that each image gets one row: its image id in one column and its RLEs in the 4 others.
def change_csv(old, new, n_classes=4):
    """Reshape a long-format RLE CSV into one row per image.

    The input CSV has the columns ``ImageId_ClassId`` (values of the form
    ``"<image id>_<class digit>"``) and ``EncodedPixels``. The output has
    ``ImageId`` followed by ``EncodedPixels_1`` ... ``EncodedPixels_<n_classes>``,
    with one row per image, sorted by image id. Missing or empty RLEs are
    filled with ``' '``. Class ids above ``n_classes`` are dropped.

    Args:
        old: Path or buffer of the input CSV (anything ``pd.read_csv`` accepts).
        new: Path or buffer for the output CSV (anything ``DataFrame.to_csv``
            accepts).
        n_classes: Number of per-class RLE columns to emit (class ids
            ``1..n_classes``).

    Returns:
        The reshaped DataFrame, which is also written to ``new``.

    Raises:
        ValueError: If an ``ImageId_ClassId`` value does not end in
            ``_<digit>``, or if an (image, class) pair appears more than once.
    """
    df = pd.read_csv(old)
    # Split off the trailing "_<digit>". The greedy (.+) keeps any earlier
    # underscores as part of the image id.
    parts = df['ImageId_ClassId'].str.extract(r'^(.+)_(\d)$')
    unparsed = parts[0].isna()
    if unparsed.any():
        bad = df.loc[unparsed, 'ImageId_ClassId'].tolist()
        raise ValueError(f'unparseable ImageId_ClassId values: {bad[:5]}')
    df = df.assign(ImageId=parts[0], ClassId=parts[1].astype(int))
    # Place each RLE by its class id rather than by its position among the
    # image's rows. Images with missing or shuffled rows still land in the
    # right columns.
    wide = df.pivot(index='ImageId', columns='ClassId', values='EncodedPixels')
    wide = wide.reindex(columns=range(1, n_classes + 1))
    wide.columns = [f'EncodedPixels_{k}' for k in wide.columns]
    # reset_index turns the ImageId index back into the first column.
    wide = wide.reset_index().fillna(value=' ')
    wide.to_csv(new, index=False)
    return wide
class MultiClassSegList(SegmentationLabelList):
    """Label list that merges per-class RLE masks into one single-channel mask.

    Each item is a row of the reshaped CSV: ``(image_id, rle_1, ..., rle_n)``.
    Pixels covered by ``rle_k`` get value ``k``. Uncovered pixels stay 0
    (background).
    """

    def open(self, id_rles):
        """Build the combined label mask for one row.

        Args:
            id_rles: Sequence whose first element is the image file name
                (relative to ``self.path``) and whose remaining elements are
                the RLEs for classes 1..n. Non-string entries (e.g. NaN) are
                skipped.

        Returns:
            An ``ImageSegment`` wrapping a ``(1, H, W)`` float tensor whose
            values are class ids in ``0..n``.
        """
        image_id, rles = id_rles[0], id_rles[1:]
        # Only the last two (spatial) dimensions of the image are needed.
        shape = open_image(self.path / image_id).shape[-2:]
        final_mask = torch.zeros((1, *shape))
        for class_id, rle in enumerate(rles, start=1):
            if not isinstance(rle, str):
                continue
            # permute(0, 2, 1) swaps the last two axes. Presumably the decoder
            # output is transposed relative to (H, W); confirm against the
            # dataset's RLE convention.
            mask = open_mask_rle(rle, shape).px.permute(0, 2, 1)
            # Assign rather than sum. The result is identical when masks are
            # disjoint, but overlapping masks can no longer produce
            # out-of-range labels (e.g. 1 + 4 = 5). The later class wins.
            final_mask[mask > 0] = class_id
        return ImageSegment(final_mask)
def load_data(path, csv, bs=32, size=(128, 800), valid_pct=0.2, num_workers=0):
    """Build a normalized segmentation DataBunch from the reshaped CSV.

    Args:
        path: Base directory passed to ``SegmentationItemList.from_csv``.
        csv: Name of the CSV file (as produced by ``change_csv``) under ``path``.
        bs: Batch size.
        size: Target size for the ``transform`` step, applied to images and
            masks alike.
        valid_pct: Fraction of items randomly held out for validation.
        num_workers: Number of DataLoader worker processes.

    Returns:
        The DataBunch, normalized with ``imagenet_stats``.
    """
    data = (
        SegmentationItemList
        .from_csv(path, csv)
        .split_by_rand_pct(valid_pct=valid_pct)
        # Column 0 (the image id) is deliberately part of the label columns:
        # MultiClassSegList.open reads it to locate the image and get its size.
        .label_from_df(cols=list(range(5)), label_cls=MultiClassSegList,
                       classes=[0, 1, 2, 3, 4])
        .transform(size=size, tfm_y=True)
        .databunch(bs=bs, num_workers=num_workers)
        .normalize(imagenet_stats)
    )
    return data
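Basically, your approach should work as long as you combine (e.g. sum) the per-class masks into a single-channel mask in which each pixel holds its class value from 1 to 4, with 0 for background. Note that summing is only safe when the masks do not overlap; overlapping pixels would otherwise get values outside that range.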