Using fastai v1 on 4+ channel images (RGBY)

Hello, I’m currently participating in the Human Protein Atlas Image Classification Kaggle competition and was wondering what the best way would be to load 4-channel images using fastai v1.

Right now, we’ve written an ImageDataset class whose instance is wrapped in a PyTorch DataLoader, which is then passed into an ImageDataBunch.

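import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from fastai.vision import ImageDataBunch

NUM_CLASSES = 28  # the Human Protein Atlas competition uses class ids 0-27

class ImageDataset(Dataset):
    def __init__(self, csv_path, matrices_path, transform=None):
        self.csv_file = pd.read_csv(csv_path)  # columns: Id, Target
        self.matrices_path = matrices_path
        self.transform = transform

    def __len__(self):
        return len(self.csv_file)

    def __getitem__(self, idx):
        # Each image is stored as a pre-stacked RGBY array in <Id>.npy
        path = os.path.join(self.matrices_path, self.csv_file.iloc[idx, 0] + ".npy")
        im = torch.Tensor(np.load(path))  # keeps the stored layout, e.g. (H, W, C)
        if self.transform is not None:
            im = self.transform(im)
        # "Target" holds space-separated class ids; encode them as a fixed-length
        # multi-hot vector so the DataLoader can stack labels into one batch tensor
        label = torch.zeros(NUM_CLASSES)
        for s in self.csv_file["Target"][idx].split():
            label[int(s)] = 1.0
        return im, label

dataset = ImageDataset(label_csv, matrixpath)

dataloader = DataLoader(dataset, batch_size=4,
                        shuffle=True, num_workers=4)
bunch = ImageDataBunch(dataloader, dataloader)
bunch.show_batch(rows=3)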

Is this the best way to implement this in fastai v1? We’ve seen people successfully train models on 4-channel images with fastai v0.7 here: https://www.kaggle.com/zhugds/resnet34-with-rgby-fast-ai-fork/comments
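Passing the same DataLoader as both the training and the validation loader means validation metrics are computed on training data. A minimal sketch of a separate split, using only NumPy; the helper train_valid_indices and its valid_pct and seed parameters are hypothetical, not a fastai API:

```python
import numpy as np

def train_valid_indices(n_items: int, valid_pct: float = 0.2, seed: int = 42):
    """Hypothetical helper: shuffle row indices 0..n_items-1 and split them
    into (train_idx, valid_idx) arrays with no overlap."""
    rng = np.random.default_rng(seed)         # seeded so the split is reproducible
    idx = rng.permutation(n_items)            # shuffled copy of 0..n_items-1
    n_valid = int(round(n_items * valid_pct))
    return idx[n_valid:], idx[:n_valid]

# Example: 10 rows of the label CSV, 20% held out for validation
train_idx, valid_idx = train_valid_indices(10, valid_pct=0.2)
```

The index arrays can then be wrapped with torch.utils.data.Subset(dataset, train_idx) and Subset(dataset, valid_idx), each given its own DataLoader, and the two loaders passed to ImageDataBunch in place of the duplicated one.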


Assuming that your im = np.load(path) returns a NumPy array with layout (height, width, n_channels), the following will work, using the definition of pil2tensor below:

tensor = pil2tensor(im, np.float32)

If the channels are 8-bit, you should also divide by 255:

tensor = pil2tensor(im, np.float32).div(255.)

This is a new version of pil2tensor that is on its way in a PR (https://github.com/fastai/fastai/pull/1085):

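import numpy as np
import torch

TensorImage = torch.Tensor  # stands in for fastai v1's TensorImage type alias

def pil2tensor(image, dtype: np.dtype) -> TensorImage:
    "Convert PIL style image array to torch style image tensor."
    a = np.asarray(image)
    if a.ndim == 2: a = np.expand_dims(a, 2)  # grayscale (H, W) -> (H, W, 1)
    a = np.transpose(a, (1, 0, 2))            # (H, W, C) -> (W, H, C)
    a = np.transpose(a, (2, 1, 0))            # (W, H, C) -> (C, H, W)
    return torch.from_numpy(a.astype(dtype, copy=False))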
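To see what the two transposes in pil2tensor do to a 4-channel image, here is a NumPy-only sketch on a synthetic RGBY array. It assumes, as above, that the .npy files hold 8-bit data in (height, width, channels) layout; the helper name hwc_to_chw is hypothetical and not part of fastai.

```python
import numpy as np

def hwc_to_chw(a: np.ndarray) -> np.ndarray:
    """Hypothetical helper: reorder (H, W, C) to (C, H, W) in a single transpose,
    the same net reordering pil2tensor performs before torch.from_numpy."""
    if a.ndim == 2:                 # grayscale (H, W): add a trailing channel axis
        a = np.expand_dims(a, 2)
    return np.transpose(a, (2, 0, 1))

# Synthetic stand-in for one RGBY image: height 3, width 5, 4 channels, 8-bit
rgby = np.arange(3 * 5 * 4, dtype=np.uint8).reshape(3, 5, 4)

# Channel-first float image scaled to [0, 1], as with .div(255.) on the tensor
chw = hwc_to_chw(rgby).astype(np.float32) / 255.0

# The two-step transpose from pil2tensor: (H, W, C) -> (W, H, C) -> (C, H, W)
two_step = np.transpose(np.transpose(rgby, (1, 0, 2)), (2, 1, 0))
```

Both routes give the same (4, 3, 5) array, and channel 3 (yellow) of the result is exactly rgby[:, :, 3], so the fourth channel survives the conversion unchanged; torch.from_numpy(chw) would then give the channel-first float tensor PyTorch models expect.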
