Modifying a 3 band model for additional bands

Thanks Malcolm, I’m running all of this in a fresh conda environment so your sample code worked without a problem.

After playing around for a while I realised that your code uses the data type ‘torch.Tensor’, while the tutorial I was following uses ‘fastai.torch_core.TensorImage’. All I needed to do was change two lines in the tutorial so they return a plain ‘torch.Tensor’ instead of the fastai version. The lines I changed were:

Within the ‘open_ms_tif’ function

# return TensorImage(ms_img) 
return torch.from_numpy(ms_img)

and from within the ‘SegmentationAlbumentationsTransform’ function

# return TensorImage(aug['image'].transpose(2,0,1)), TensorMask(aug['mask'])
return torch.from_numpy(aug['image'].transpose(2,0,1)), TensorMask(aug['mask'])
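
For anyone who wants to check the conversion on its own, here is a minimal sketch of what that second return line does to the image. The array is a made-up stand-in for a 4-band augmented image (the name `hwc_image` and the 64x64 size are only for illustration), not data from the tutorial.

```python
import numpy as np
import torch

# Hypothetical stand-in for a 4-band augmented image in
# (height, width, channels) order, as Albumentations returns it.
hwc_image = np.random.rand(64, 64, 4).astype(np.float32)

# Reorder the axes to (channels, height, width), then wrap the array
# as a plain torch.Tensor rather than a fastai TensorImage.
chw_tensor = torch.from_numpy(hwc_image.transpose(2, 0, 1))

print(type(chw_tensor))
print(chw_tensor.shape)
```

Note that torch.from_numpy shares memory with the NumPy array instead of copying it, so changing one in place changes the other.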

Then it worked perfectly with n_in=4 and pretrained=True :smiley:
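
As far as I understand it, using n_in=4 together with pretrained=True means the first convolution has to accept an extra band while keeping the pretrained filters for the original three. Below is a rough sketch of that idea in plain PyTorch. It is not fastai's actual code: the layer sizes are typical ResNet values I picked for illustration, the "pretrained" layer is just randomly initialised, and starting the extra band at zero is my assumption.

```python
import torch
import torch.nn as nn

# Stand-in for a pretrained 3-band first layer (random weights here,
# not a real checkpoint).
rgb_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

# New first layer that accepts 4 bands but is otherwise identical.
ms_conv = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)

with torch.no_grad():
    # Reuse the existing filters for the first three bands.
    ms_conv.weight[:, :3].copy_(rgb_conv.weight)
    # Assumption: the extra band starts from zero and is learned in training.
    ms_conv.weight[:, 3:].zero_()

# A batch of two hypothetical 4-band images, 128x128 pixels each.
batch = torch.rand(2, 4, 128, 128)
out = ms_conv(batch)
print(out.shape)
```

The point is only that the 3-band weights carry over unchanged and the new band gets its own slice of the weight tensor.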

So it appears that this info from Zachary on this thread is still correct:

“The second is an issue with the newest pytorch, it won’t just readily accept types anymore like it used to before (so long as it was a tensor).”

Thanks for all the help Malcolm :+1:, this would have been very difficult without you.
