Error when feeding a custom DataLoader to the model

Hi all,
I am working with multispectral imagery, following the approach in this Medium article:
https://towardsdatascience.com/how-to-create-a-datablock-for-multispectral-satellite-image-segmentation-with-the-fastai-v2-bc5e82f4eb5

After preparing the data, I am running into the following issues:

1. I am not able to feed the data to the U-Net model. I am feeding in 4 bands (one additional band along with RGB), but the U-Net expects 3 bands. Can anyone suggest which model I should use, or what modifications I should make and which steps to follow?

2. When I try to work with only three bands (removing one of the bands), the model asks for a loss function. When I supply `CrossEntropyLossFlat()`, I get the following error:
   `Expected input batch_size (4088) to match target batch_size (2088968).`
   FYI:
   - my 16 input images have shape 3 x 1022 x 1022 (when I use 3 channels);
   - my 16 target masks have shape 1022 x 1022.
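The two numbers in the error can be reproduced by hand. fastai's `CrossEntropyLossFlat` moves the class axis (axis 1) of the prediction to the end and merges everything else into rows, while the target is flattened completely, so the two counts have to be equal. With `bs=2` from the `dataloaders` call, 2 x 1022 x 1022 = 2,088,968, which is exactly the target size in the error; a full-resolution prediction of shape (2, 2, 1022, 1022) would give the same count. The 4,088 rows on the input side therefore suggest that the model output does not have the masks' 1022 x 1022 spatial size, which is worth checking on a single batch. A small sketch of the arithmetic (the helper names are made up for illustration):

```python
from math import prod

def flat_input_rows(pred_shape, axis=1):
    """Rows of the prediction after CrossEntropyLossFlat-style flattening:
    the class axis goes last and all other axes are merged into rows."""
    return prod(pred_shape) // pred_shape[axis]

def flat_target_rows(target_shape):
    """The target is flattened completely: one row per pixel label."""
    return prod(target_shape)

bs, n_classes, h, w = 2, 2, 1022, 1022
print(flat_target_rows((bs, h, w)))            # 2088968, the target size in the error
print(flat_target_rows((bs, 1, h, w)))         # 2088968, a singleton channel axis changes nothing
print(flat_input_rows((bs, n_classes, h, w)))  # 2088968, what the prediction would need to give
```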

This is the code I am using:

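```python
from fastai.vision.all import *  # DataBlock, TransformBlock, unet_learner, etc.
# MSTensorImage and open_npy come from the Medium article linked above

def get_lbl_fn(img_fn: Path):
    # the mask has the same file name, in a 'masks' folder next to the image folder
    lbl_path = img_fn.parent.parent/'masks'
    lbl_name = img_fn.name
    return lbl_path/lbl_name
```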
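`get_lbl_fn` derives each mask path from the image path, so a missing or misnamed mask only surfaces as an error deep inside the DataBlock pipeline. A small standalone check, sketched here with `pathlib` only, can list images without a matching mask first. The `images` folder name and the file name in the example are hypothetical. Note also that fastai's `get_files` recurses into subfolders by default, so if `imgs_path` were the common parent of the image folder and `masks`, the masks would be collected as inputs too.

```python
from pathlib import Path

def get_lbl_fn(img_fn: Path) -> Path:
    # same rule as in the post: .../<image folder>/<name> -> .../masks/<name>
    return img_fn.parent.parent / 'masks' / img_fn.name

def missing_masks(img_dir: Path, ext: str = '.npy'):
    """Return the image files in img_dir whose mask (per get_lbl_fn) does not exist."""
    return [p for p in sorted(Path(img_dir).glob(f'*{ext}')) if not get_lbl_fn(p).exists()]

# hypothetical layout: <root>/images/*.npy and <root>/masks/*.npy
print(get_lbl_fn(Path('data/images/tile_001.npy')).as_posix())  # data/masks/tile_001.npy
# print(missing_masks(Path('data/images')))  # run this against your own image folder
```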

```python
db = DataBlock(
    blocks=(
        # input: channels 3, 2, 1 of the image .npy, read channels-first
        TransformBlock(type_tfms=partial(MSTensorImage.create, chnls=[3, 2, 1], chnls_first=True)),
        # target: mask path from get_lbl_fn, opened with chnls=[1] as a TensorMask
        TransformBlock(type_tfms=[get_lbl_fn, partial(open_npy, chnls=[1], cls=TensorMask)],
                       item_tfms=AddMaskCodes(codes=['CLEAR', 'CLOUD'])),
    ),
    get_items=partial(get_files, extensions='.npy'),
    splitter=RandomSplitter(valid_pct=0.1),
)

db.summary(source=imgs_path)

dls = db.dataloaders(source=imgs_path, bs=2, num_workers=0)
```
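Both issues come down to tensor shapes, so it may help to inspect one batch before building the learner. The function below is a hypothetical pure-Python helper (not part of fastai) that compares image and mask shapes against what a U-Net with `n_in` input channels needs; with fastai, the shapes would come from `dls.one_batch()`.

```python
def check_batch_shapes(x_shape, y_shape, n_in=3):
    """Return a list of problems in a segmentation batch's shapes (empty list = OK).

    x_shape: image batch, expected (bs, C, H, W)
    y_shape: mask batch, expected (bs, H, W) or (bs, 1, H, W)
    n_in:    input channels the model's first layer expects (3 for a stock ResNet)
    """
    x_shape, y_shape = tuple(x_shape), tuple(y_shape)
    if len(x_shape) != 4:
        return [f"expected a 4-D image batch (bs, C, H, W), got {x_shape}"]
    bs, c, h, w = x_shape
    problems = []
    if c != n_in:
        problems.append(f"images have {c} channels but the model expects {n_in}")
    if len(y_shape) == 4:
        if y_shape[1] != 1:
            problems.append(f"mask channel axis should have size 1, got {y_shape[1]}")
        y_shape = (y_shape[0],) + y_shape[2:]  # drop the channel axis before comparing
    if y_shape != (bs, h, w):
        problems.append(f"mask shape {y_shape} does not match images {(bs, h, w)}")
    return problems

# Assumed fastai usage: xb, yb = dls.one_batch(); print(check_batch_shapes(xb.shape, yb.shape))
print(check_batch_shapes((2, 4, 1022, 1022), (2, 1, 1022, 1022)))
# ['images have 4 channels but the model expects 3']
print(check_batch_shapes((2, 3, 1022, 1022), (2, 1022, 1022)))  # []
```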

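```python
def acc_metric(input, target):
    # drop a singleton channel axis if the masks are (bs, 1, H, W); no-op for (bs, H, W)
    target = target.squeeze(1)
    return (input.argmax(dim=1) == target).float().mean()
```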
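The `squeeze(1)` in `acc_metric` matters more than it looks: if the masks carry a singleton channel axis, comparing the `argmax` output of shape (bs, H, W) with a (bs, 1, H, W) target does not raise an error but broadcasts to (bs, bs, H, W), so the accuracy is computed over the wrong pairs. A NumPy re-statement of the same metric, used only to show the shapes (NumPy's broadcasting rules are the same as PyTorch's here); `pixel_accuracy` is a hypothetical name:

```python
import numpy as np

def pixel_accuracy(pred, target):
    """pred: (bs, n_classes, H, W) scores; target: (bs, H, W) or (bs, 1, H, W) labels."""
    if target.ndim == pred.ndim:  # (bs, 1, H, W): drop the channel axis, like target.squeeze(1)
        target = target.squeeze(1)
    return float((pred.argmax(axis=1) == target).mean())

pred = np.zeros((2, 2, 4, 4))
pred[:, 1] = 1.0                            # every pixel predicted as class 1
target = np.ones((2, 1, 4, 4), dtype=int)   # every pixel labelled class 1

print(pixel_accuracy(pred, target))            # 1.0
print((pred.argmax(axis=1) == target).shape)   # (2, 2, 4, 4): silent broadcast without the squeeze
```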

```python
wd = 1e-2  # weight decay
lr = 1e-3  # learning rate

# create the U-Net learner; it is sized according to the data (resolution, for example)
learn = unet_learner(dls, models.resnet34, normalize=True, pretrained=True,
                     loss_func=CrossEntropyLossFlat(), n_out=2, metrics=acc_metric, wd=wd)

# train the model for 3 epochs
learn.fit_one_cycle(3, lr)
```
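For the 4-band case in point 1, one common approach (not from the article, and a heuristic rather than a fixed recipe) is to keep the pretrained ResNet but widen its first convolution from 3 to 4 input channels: copy the pretrained RGB filters and initialise the filters for the extra band, here with the mean of the RGB filters. The NumPy sketch below only shows that weight surgery on an array shaped like ResNet-34's first-layer weight (64 x 3 x 7 x 7); `widen_first_conv` is a hypothetical name, and the random array stands in for real pretrained weights. In PyTorch the result would be loaded into a new `nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)` that replaces the original first layer. It may also be worth checking whether your installed fastai version's `unet_learner` accepts an `n_in` argument before doing this by hand.

```python
import numpy as np

def widen_first_conv(weight: np.ndarray, n_in: int) -> np.ndarray:
    """Adapt a pretrained first-layer conv weight (out_c, 3, kH, kW) to n_in >= 3 input channels.

    The pretrained RGB filters are kept unchanged; every extra channel gets the
    mean of the RGB filters (one common heuristic, assumed here).
    """
    in_c = weight.shape[1]
    if n_in < in_c:
        raise ValueError("this sketch only adds input channels")
    mean_filter = weight.mean(axis=1, keepdims=True)     # (out_c, 1, kH, kW)
    extra = np.repeat(mean_filter, n_in - in_c, axis=1)  # (out_c, n_in - in_c, kH, kW)
    return np.concatenate([weight, extra], axis=1)

rng = np.random.default_rng(0)
rgb_w = rng.standard_normal((64, 3, 7, 7)).astype(np.float32)  # stand-in for pretrained weights
w4 = widen_first_conv(rgb_w, 4)
print(w4.shape)  # (64, 4, 7, 7)
```

Zero-initialising the extra channel, so the network initially behaves exactly like the RGB model, is another common choice; which works better for a given band is an empirical question.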

Any help would be much appreciated.
Thank you.