Classifier with multiple images as input and multiple labels as output

I’ve tried to tackle the Kaggle “Human Protein Atlas” competition, where the input consists of 4 images per sample (RGB + Yellow) and the output is a multi-label target with 28 possible classes.

To tackle it I had to implement a couple of custom functions to cope with “multi-channel image” datasets.

1) Multi channel image loader

Inspired by “open_image”, this function takes as input a list of image paths and stacks them into a single multi-channel image.
The current version supports only single-channel (grayscale) images as input, but it’s not difficult to change the code to support a list of images with multiple channels.
It supports any number of input images, so you can load, for example, the slices of an MRI scan.

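import PIL.Image
import torch
from fastai.vision import *  # fastai v1: provides Image and pil2tensor

def openMultiChannelImage(fpArr):
    '''Open multiple single-channel images and return a single multi-channel image'''
    mat = None
    nChannels = len(fpArr)
    for i,fp in enumerate(fpArr):
        #print('Loading: ', fp)
        img = PIL.Image.open(fp)  # grayscale image -> tensor of shape (1, H, W)
        chan = pil2tensor(img).float().div_(255)
        if(mat is None):
            mat = torch.zeros((nChannels,chan.shape[1],chan.shape[2]))
        mat[i,:,:] = chan[0]  # copy this image into channel i
    return Image(mat)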

# Usage sample
# v = (train_data_and_labels_df[train_df.columns]).values[0,:]
v = array(['00070df0-bbc3-11e8-b2bc-ac1f6b6435d0', # Object reference - not used here
           # ... followed by the paths of the channel images
          ])

ret = openMultiChannelImage(v[1:])  # one channel per path in v[1:]
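For anyone not on fastai v1, the core of openMultiChannelImage (stack one grayscale image per channel and scale to [0, 1]) can be sketched with NumPy alone. This is a minimal sketch that assumes the images are already decoded into equally sized uint8 arrays (for example with `np.asarray(PIL.Image.open(path))`); `stack_channels` is a hypothetical helper name, not part of fastai.

```python
import numpy as np

def stack_channels(channels):
    """Stack equally sized 2-D grayscale arrays (H, W) into one (C, H, W)
    float32 array scaled to [0, 1], one channel per input image."""
    if len(channels) == 0:
        raise ValueError("need at least one channel")
    h, w = np.asarray(channels[0]).shape
    mat = np.zeros((len(channels), h, w), dtype=np.float32)
    for i, chan in enumerate(channels):
        chan = np.asarray(chan)
        if chan.shape != (h, w):
            raise ValueError(f"channel {i} has shape {chan.shape}, expected {(h, w)}")
        mat[i] = chan.astype(np.float32) / 255.0  # uint8 [0, 255] -> float [0, 1]
    return mat

# Usage: four fake 2x2 uint8 "images" standing in for R, G, B, Y
red, green, blue, yellow = (np.full((2, 2), v, dtype=np.uint8) for v in (0, 51, 102, 255))
img = stack_channels([red, green, blue, yellow])
print(img.shape)  # (4, 2, 2)
```

Any number of channels works the same way, which is what makes the MRI-slices use case possible.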

2) Multi channel dataset

It’s a subclass of “ImageMultiDataset” that takes as input “X”, an array of image-path lists (where the first element is the sample “id” and the others are paths), and as output “Y”, an array of class lists (classes are strings).
It supports (X,Y) pairs for the training and validation sets, and “X”-only samples for the test set, so you can use it for predictions too.

from typing import Any, Collection, Optional
from fastai.vision import *  # early fastai v1 releases that still ship ImageMultiDataset

class MultiChannelDataset(ImageMultiDataset):
    '''A dataset with multi-channel image support.
    x: rows of the form x[0]=id, x[1:]=paths of the channel images
    y: labels, one row per sample: each row is an array of strings representing classes
    '''
    def __init__(self, x:Collection[Any], y:Collection[Collection[str]], classes:Optional[Collection[Any]]=None):
        if((y is not None) and (len(y)>0)):
            assert len(x)==len(y)
        super().__init__(fns=None, labels=y, classes=classes) # No file names! It retrieves X in a different way
        self.x=x # Assign x instead of using fns
        #self.y # Using super class initializer
        #self.loss_func # Using super class initializer
        self.isTest = (y is None) or (len(y)==0)

    def _get_x(self,i):
        # WARNING: a debug print here (e.g. print('loaded x:', self.x[i,0])) seemed to be
        # needed to slow loading down. Probably there is a race condition...
        img = openMultiChannelImage(self.x[i,1:])
        return img # a fastai Image wrapping a (C, H, W) pytorch tensor
    def _get_y(self,i): # Override default behaviour to accommodate a test set without y
        if self.isTest:
            return [0] # no label passed: use the encoded label of the first training item (as in add_test)
        return super()._get_y(i)

    @classmethod
    def create_test_dataset(cls, x:Collection[Any], classes:Collection[Any]):
        return cls(x,[],classes)
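The dataset expects X rows laid out as `[id, path_1, path_2, …]` and Y as lists of class strings, while a multi-label loss ultimately works on multi-hot vectors. The sketch below shows both layouts with plain NumPy. It is illustrative only: `make_x_rows` and `multi_hot` are hypothetical helpers, the `<id>_<colour>.png` file naming is an assumption, and fastai performs its own label encoding internally.

```python
import numpy as np

def make_x_rows(ids, channel_suffixes=("red", "green", "blue", "yellow"), ext=".png"):
    """Build X rows laid out as [id, path_ch1, path_ch2, ...].
    The '<id>_<colour><ext>' naming is an assumption for illustration."""
    return np.array([[sid] + [f"{sid}_{c}{ext}" for c in channel_suffixes] for sid in ids])

def multi_hot(label_lists, classes):
    """Encode each sample's list of class strings as a 0/1 vector over `classes`."""
    index = {c: j for j, c in enumerate(classes)}
    y = np.zeros((len(label_lists), len(classes)), dtype=np.float32)
    for i, labels in enumerate(label_lists):
        for lab in labels:
            y[i, index[lab]] = 1.0  # a KeyError here means an unknown class string
    return y

x = make_x_rows(["a1", "b2"])  # row layout: [id, red, green, blue, yellow paths]
classes = [str(k) for k in range(28)]  # 28 class ids written as strings
y = multi_hot([["0", "25"], ["7"]], classes)
print(y.shape)               # (2, 28)
print(np.flatnonzero(y[0]))  # [ 0 25]
```

A sample can switch on several classes at once, which is why each row of Y is a list rather than a single label.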

This is the link to the public GitHub repo with the complete notebook.

I’ve tried to train:

  • resnet 50: score of .135 and rank of 609/736
  • resnet 34: score of .167 and rank of 594/736

There is a lot of room for improvement.

I hope this helps!


Great suggestion @jeremy!

The idea of extending a pretrained model and initializing its weight tensors to accommodate additional input channels (4th, 5th, …) is great. Not knowing that, I trained the network from scratch (pretrained=False)…

Does it make sense to “duplicate the weights”?
If we’ve got an RGBXY input image, could we copy the weights for the X and Y channels from R and G, as if we had trained the original network with RGBRG channels?

I should try to add that and make the code compatible with the functional Data Block API.

Here is the updated model with the option to use transfer learning with multi-channel images (any number of channels).
It duplicates the RGB weights of the first convolution onto the extra channels (i.e. an RGBY image uses RGBR weights, RGBXYZ uses RGBRGB).
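The weight duplication described above amounts to cycling over the pretrained input channels: input channel c reuses the weights of channel c % 3. A minimal NumPy sketch, assuming the first-conv weights are laid out as (out_channels, in_channels, kH, kW), as in PyTorch (for a torchvision ResNet that is `model.conv1`, with shape (64, 3, 7, 7)); `expand_conv_weights` is a hypothetical name, not the code from the linked notebook.

```python
import numpy as np

def expand_conv_weights(w, n_in):
    """Expand first-conv weights of shape (out, 3, kH, kW) to (out, n_in, kH, kW)
    by cycling over the pretrained input channels: channel c gets w[:, c % 3]."""
    in_ch = w.shape[1]
    idx = [c % in_ch for c in range(n_in)]  # e.g. n_in=4 -> [0, 1, 2, 0] (RGBR)
    return w[:, idx, :, :]  # advanced indexing returns a new array

w_rgb = np.random.default_rng(0).normal(size=(64, 3, 7, 7)).astype(np.float32)
w_rgby = expand_conv_weights(w_rgb, 4)
print(w_rgby.shape)  # (64, 4, 7, 7)
```

In PyTorch the same list indexing works on the pretrained `conv.weight.data` tensor before copying it into a new `nn.Conv2d` with `n_in` input channels. One variant some implementations use is to also multiply the expanded weights by 3 / n_in, so the summed response of the first layer keeps roughly the same magnitude as with 3 channels.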

TODO: Data Block API support…


I am currently trying to figure out how to load it with the Data Block API, but so far I am struggling with the multiple labels.

I tried it with two different custom classes based on ImageItemList and MultiCategoryList, but I could not get the classes to load properly.

What is the proper way to do this after the latest Data Block API updates?


I’ve ported the original example to the data block API. I tried to do that several times over the last few weeks, following the evolution of the library, and I have to say that the current version of fastai is very easy to customize.
The only code you need to read a multi-channel image is:

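import numpy as np
import PIL.Image
import torch
from fastai.vision import *  # fastai v1: Image, ImageItemList, pil2tensor

def openMultiChannelImage(fpArr):
    '''Open multiple single-channel images and return a single multi-channel image'''
    mat = None
    nChannels = len(fpArr)
    for i,fp in enumerate(fpArr):
        img = PIL.Image.open(fp)  # grayscale image -> tensor of shape (1, H, W)
        chan = pil2tensor(img, np.float32).div_(255)
        if(mat is None):
            mat = torch.zeros((nChannels,chan.shape[1],chan.shape[2]))
        mat[i,:,:] = chan[0]  # copy this image into channel i
    return Image(mat)

class MultiChannelImageItemList(ImageItemList):
    def open(self, fn):
        # fn is one item of the list: the paths of all channel images of a sample
        return openMultiChannelImage(fn)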

You can then use the new “MultiChannelImageItemList” to read multi-channel images for both the training and the test set.

il = MultiChannelImageItemList.from_df(path=path, df=train_data_and_labels_df, cols=x_cols)
ils1 = il.random_split_by_pct()
ils2 = ils1.label_from_df(cols=y_cols) 
test_items = MultiChannelImageItemList.from_df(path=path, df=test_df, cols=x_cols) 
ils3 = ils2.add_test(items=test_items)
ils4 = ils3.transform(tfms, size=size)
data = ils4.databunch(bs=bs)

After that you only have to adapt your model to accept a different input (i.e. 4 channels instead of 3).

You can find a complete working example on:



Thank you so much for sharing this! I have wanted to try this with a project I am working on now. I was going to combine 2-3 different grayscale post-processed outputs of the same raw data. To your knowledge, does this approach still work with the current fastai library?


You’re welcome @aksg87: I’m happy to see that you find it useful.
It should work: I’ve run it with fastai version 1.0.56.dev0 and there were no errors.


Hi, do you have any hints on getting something similar to work for segmentation masks rather than images? I.e., I have 1 image but I want to predict 4 different masks on it.


I also have a requirement similar to @wwymak’s, where I am predicting 4 segmentation masks for a single image. Any pointers would be much appreciated.

Hello @njordsir @wwymak, did you manage to work it out? I have the same problem.