Multi-band image classifier

Hi All :wave:

I have been working on a mapping project in which it would be useful to train an image classifier on imagery with more than 3 bands/channels. I’m currently working with imagery that has 2 classes and 6 bands. I think I mostly have it sorted; however, I can’t work out how to tell the model to output a prediction for each class (2 predictions) instead of outputting a single prediction for each image.

This is how I’m loading in the data.

# open an image and convert it to a tensor
def open_img(path):
    # read the multi-band image as float32 scaled to [0, 1]
    # (io.imread here is skimage.io.imread; swap in whatever reader matches your file format)
    ms_img = io.imread(path).astype('float32') / 255.0
    im = torch.from_numpy(ms_img)
    return im

# get the image label from the folder name
def get_label(path):
    label = os.path.basename(os.path.dirname(path))
    return label 

db = DataBlock(blocks=(TransformBlock(open_img), CategoryBlock),
               get_items = get_image_files,
               get_y= get_label,
               splitter=RandomSplitter(valid_pct=0.2, seed=42))

ds = db.datasets(source=path)
dl = db.dataloaders(source=path, bs=4)
batch = dl.one_batch()
print(batch[0].shape, batch[1])

#torch.Size([4, 6, 1000, 1000]) TensorCategory([0, 0, 1, 0], device='cuda:0')

Then I’m setting up the learner like this

# a throwaway 'loss' that just prints the raw model outputs and targets
def print_input(predictions, targets):
    print(predictions)
    print(targets)

learn = cnn_learner(dl, resnet18, n_in=6, n_out=1, metrics=error_rate, loss_func=print_input).to_fp16()

#tensor([[-1.1084],
#        [-2.2383],
#        [ 1.8320],
#        [ 2.2969]], device='cuda:0', grad_fn=<CopyBackwards>)
#TensorCategory([1, 0, 0, 0], device='cuda:0')

So I think my problem is that the output above is only giving me one prediction for each input image, and what I want is two predictions for each image, one for each class.

Also, I believe I should be using ‘CrossEntropyLossFlat’ as the loss function; however, I needed a way to see what the model was outputting, which is why I added ‘print_input’ as the loss function (I get that this is a bit odd but I’m getting desperate :laughing:).
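For what it’s worth, a less drastic way to peek at the raw outputs is a wrapper that prints and then delegates to the real loss, so training still works. A sketch in plain PyTorch (F.cross_entropy is the loss that CrossEntropyLossFlat wraps; the extra flattening fastai adds is omitted here):

```python
import torch
import torch.nn.functional as F

def debug_loss(predictions, targets):
    # peek at the raw model outputs, then hand off to the real loss
    print(predictions.shape, targets.shape)
    return F.cross_entropy(predictions, targets)

# toy batch: 4 images, 2 classes
preds = torch.randn(4, 2)
targs = torch.tensor([1, 0, 0, 0])
loss = debug_loss(preds, targs)
print(loss.item())
```

Passing `loss_func=debug_loss` to the learner would print shapes every batch while still returning a trainable loss.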

I believe what I’m after is for the model to output a prediction for each class, like this.

tensor([[-1.1084, -1.1084],
        [-2.2383, -2.2383],
        [ 1.8320,  1.8320],
        [ 2.2969,  2.2969]], device='cuda:0', grad_fn=<CopyBackwards>)
TensorCategory([1, 0, 0, 0], device='cuda:0')
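That [bs, n_classes] shape is exactly what cross-entropy needs. A quick sanity check in plain PyTorch (using the single-column outputs from above, duplicated into two columns to mimic a 2-class head) shows that the [bs, 1] shape fails outright, because a target of class 1 is out of bounds when there is only one class column:

```python
import torch
import torch.nn.functional as F

targets = torch.tensor([1, 0, 0, 0])

one_col = torch.tensor([[-1.1084], [-2.2383], [1.8320], [2.2969]])  # n_out=1 shape
two_col = one_col.repeat(1, 2)                                      # n_out=2 shape

loss = F.cross_entropy(two_col, targets)  # works: [4, 2] logits vs [4] targets
print(loss.item())  # ~0.6931: equal logits give p=0.5 per class

try:
    F.cross_entropy(one_col, targets)     # fails: only one class column
except Exception as e:
    print(type(e).__name__)
```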

If anyone could let me know what I’m doing wrong here it would be greatly appreciated.

Thanks :+1:

OK, so I just needed to sleep on it. It turns out all I needed to do was remove the ‘n_out’ option from the learner, or set it to the number of classes, either explicitly with ‘n_out=2’ or dynamically with ‘n_out=len(dl.vocab)’.
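In plain PyTorch terms, n_out just sets the out_features of the final linear layer in the head, so the output goes from [bs, 1] to [bs, n_classes]. A minimal sketch (the 512 feature width is only illustrative, not necessarily what fastai’s resnet18 head actually sees):

```python
import torch
from torch import nn

bs, n_features, n_classes = 4, 512, 2  # feature width is illustrative
features = torch.randn(bs, n_features)

head_one = nn.Linear(n_features, 1)          # what n_out=1 builds
head_two = nn.Linear(n_features, n_classes)  # what n_out=len(dl.vocab) builds

print(head_one(features).shape)  # torch.Size([4, 1]) -> one logit per image
print(head_two(features).shape)  # torch.Size([4, 2]) -> one logit per class
```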
Once I have this script all cleaned up I will post a link to it, to hopefully help someone out in the future.


I just finished up writing an end-to-end walk-through on how to handle this situation with fastai v2: Multispectral image classification with Transfer Learning

Hi @Nickelberry,

I’m trying to do segmentation for Sentinel2 multi-spectral images. I found your article and notebook super helpful to get started. Thank you.

I also used @cordmaur articles / notebooks as reference. Thanks @cordmaur.

Setting up the training pipeline proved straightforward with these references.

I’m wondering how you dealt with missing data. The images I’m working with have a lot of missing pixels, which I’m replacing with nan values. This (appears to) result in model outputs with -inf values, which causes the loss functions to blow up.

Did you face this problem?

Hi @restlessronin, I’m glad you found my work useful :slight_smile:

I have had this exact issue when building a Sentinel 2 cloud-masking model. It depends on why you have nan values: if the satellite didn’t ‘see’ an area of an image, I reclassify those pixels to a 0 value; if your nan values are really high reflectance values (such as clouds) that have been clipped, I would reclassify them as some large number.
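A sketch of that reclassification step with NumPy’s nan_to_num (the 10.0 for clipped pixels is just a placeholder; use whatever “large” value fits your reflectance scale):

```python
import numpy as np

band = np.array([[0.12, np.nan, 0.30],
                 [np.nan, 0.85, np.nan]], dtype='float32')

# unseen pixels -> 0
seen_as_zero = np.nan_to_num(band, nan=0.0)

# clipped high-reflectance pixels (e.g. clouds) -> a large value
seen_as_cloud = np.nan_to_num(band, nan=10.0)  # placeholder "large" value

print(np.isnan(seen_as_zero).any())  # no nans left to poison the loss
```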

@Nickelberry Thanks a lot for the pointers. I made the changes and everything seems to run properly now.

I had to make some small tweaks to the multi channel segmentation code (@cordmaur code), but it’s at least running.

Once again. Many thanks. Much appreciated. :pray:

Hi @Nickelberry, quick follow up. I wound up using what I learnt from your notebooks as well as those of @cordmaur to create a library that simplifies the visualization, brightening, augmentation, re-weighting etc. In case it’s useful for others, I’m posting the link here

repo: GitHub - restlessronin/fastgs: Geospatial (Sentinel2 Multi-Spectral) support for fastai
docs: fastgs - Welcome to fastgs

It’s still a work in progress so feedback (and PRs) are welcome


@restlessronin Amazing, this looks very useful! I’m currently working on a Sentinel 2 project at work; I will give fastgs a go and see what happens. One question before I go digging into the code: how are you dealing with the different pixel sizes of the different bands?

I’m working with data where the pixel sizes have already been normalized, but I’m happy to consider ways to support that.
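For anyone who does need to handle mixed resolutions, one common approach is to resample the coarser bands onto the finest grid before stacking. A sketch with torch.nn.functional.interpolate (bilinear is one reasonable choice for reflectance bands; nearest may suit categorical masks better):

```python
import torch
import torch.nn.functional as F

# e.g. a Sentinel-2 tile: a 10 m band on a 100x100 grid, a 20 m band on 50x50
band_10m = torch.rand(1, 1, 100, 100)
band_20m = torch.rand(1, 1, 50, 50)

# upsample the 20 m band onto the 10 m grid, then stack along the channel dim
band_20m_up = F.interpolate(band_20m, size=(100, 100), mode='bilinear',
                            align_corners=False)
stacked = torch.cat([band_10m, band_20m_up], dim=1)
print(stacked.shape)  # torch.Size([1, 2, 100, 100])
```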