FastAI for more than 3 channels

I have seen some people (including me) having doubts about how to use fastai with more than a three-channel input. I recently got it working, although it is a very naive solution.

Step 1: Create a custom Dataset using PyTorch and apply your own transforms (I have no idea how to make the default fastai transforms work with more than 3 channels).
In this example we have a 7-channel image.

import random

import numpy as np

def applytransform(x, y, mode):
    # No augmentation on the validation set
    if mode == "val":
        return x, y
    # Augment only about half of the training samples
    if random.random() >= 0.5:
        return x, y
    if random.random() >= 0.5:      # horizontal flip
        for i in range(len(x)):
            x[i] = np.fliplr(x[i])
        y = np.fliplr(y)
    if random.random() >= 0.5:      # vertical flip
        for i in range(len(x)):
            x[i] = np.flipud(x[i])
        y = np.flipud(y)
    if random.random() >= 0.5:      # random 90-degree rotation
        k = random.randint(0, 3)
        for i in range(len(x)):
            x[i] = np.rot90(x[i], k)
        y = np.rot90(y, k)
    return x, y

where the inputs x and y are NumPy arrays.
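As a side note, the per-channel loops above are equivalent to flipping or rotating the whole (C, H, W) stack at once. A small sanity check on toy arrays (the shapes here are made up for illustration):

```python
import numpy as np

# Toy 7-channel image stacked as (C, H, W)
x = np.arange(7 * 4 * 4, dtype=np.float32).reshape(7, 4, 4)

# Flipping every channel with np.fliplr equals flipping the
# whole stack along its last axis
per_channel = np.stack([np.fliplr(c) for c in x], axis=0)
assert np.array_equal(per_channel, np.flip(x, axis=2))

# Likewise, per-channel np.rot90 equals rotating the stack
# in the (H, W) plane with the same k
k = 3
per_channel_rot = np.stack([np.rot90(c, k) for c in x], axis=0)
assert np.array_equal(per_channel_rot, np.rot90(x, k, axes=(1, 2)))
```

Doing it on the whole stack guarantees every channel (and the mask, if included) gets the same transform.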

def open_7_channel_img(file_name, mask_path, mode):
    # Strip the channel suffix: keep only the first four "_"-separated parts
    file_name = "_".join(file_name.split("_")[:4])
    path = f'{im_input}/' + file_name
    mask = np.array(Im.open(mask_path))
    images = []
    for i in range(1, 8):
        im = Im.open(path + "_" + str(i) + ".tif")
        im_array = np.array(im) / 255
        # im_array = im_array - mean_norm[i-1]
        # im_array = im_array / std_norm[i-1]
        images.append(im_array)

    images, mask = applytransform(np.stack(images, axis=0), mask, mode)
    # note: np.array(im) / 255 yields float64, so .float() may be needed
    # on the image tensor before feeding a float32 model
    return torch.from_numpy(images.copy()), torch.from_numpy(mask.copy())

Basically, all inputs share the same base name with a _channelnumber suffix, so we open each channel with PIL and stack them into a single NumPy array.

Now use this function inside the Dataset.
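For completeness, the Dataset from Step 1 can be sketched roughly like this (the class and argument names are my own, and I pass the loader in as an argument; in practice loader would be the open_7_channel_img function above):

```python
import torch
from torch.utils.data import Dataset

class MultiChannelSegDataset(Dataset):
    """Minimal sketch: yields (7-channel image, mask) tensor pairs."""

    def __init__(self, file_names, mask_paths, mode, loader):
        self.file_names = file_names   # base names of the .tif files
        self.mask_paths = mask_paths   # one mask path per image
        self.mode = mode               # "train" or "val"
        self.loader = loader           # e.g. open_7_channel_img

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        # loader returns (image_tensor, mask_tensor)
        return self.loader(self.file_names[idx],
                           self.mask_paths[idx],
                           self.mode)
```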

Step 2: Create a train and a validation set from the base Dataset.
Step 3: Pass the train and validation sets to a DataBunch.
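Step 2 can be as simple as shuffling the file list and slicing off a validation fraction; the train and validation Datasets are then built from the two lists with mode="train" and mode="val" respectively. A sketch (the 20% split and the dummy file names are just an illustration):

```python
import random

# Hypothetical list of base file names
all_files = [f"tile_{i}" for i in range(100)]

random.seed(0)                       # reproducible split
random.shuffle(all_files)
n_val = int(0.2 * len(all_files))    # hold out 20% for validation
val_files = all_files[:n_val]
trn_files = all_files[n_val:]
# build trn and val Datasets from these two disjoint lists
```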

custom_databunch = DataBunch.create(trn, val, dl_tfms=None, bs=50, device=default_device)

Now create the model and wrap it using

Learner(custom_databunch, model)
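On the model side, the only hard requirement for 7-channel input is that the first convolution has in_channels=7. A minimal segmentation-style sketch (the layer sizes are arbitrary; note that pretrained 3-channel weights cannot be loaded directly into such a first layer):

```python
import torch
import torch.nn as nn

# Tiny stand-in model: the first conv takes 7 channels,
# the output is a 2-class per-pixel map
model = nn.Sequential(
    nn.Conv2d(7, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 2, kernel_size=1),
)

x = torch.randn(2, 7, 64, 64)    # batch of two 7-channel images
out = model(x)                   # shape: (2, 2, 64, 64)
```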

This seems to start the training process, but for me the validation loss is stuck no matter what model architecture I use.
On the contrary, if I average the 7 channels into a 3-channel image, I get much better results and the model converges.
Could @jeremy point out what I am missing?