Several input images for one target: ValueError: Expected input batch_size (13068) to match target batch_size (4356)

I had to create a custom PyTorch model because I want to give it three input images through a DataLoader. I created a basic U-Net and passed it to a fastai Learner. However, when trying to find a suitable learning rate with lr_find(), I get the error mentioned in the title (note that 13068 is exactly 3 × 4356):

ValueError                                Traceback (most recent call last)

 ----> 1 learn.lr_find()

~/anaconda3/lib/python3.8/site-packages/fastai/callback/schedule.py in lr_find(self, start_lr, end_lr, num_it, stop_div, show_plot, suggestions)
220     n_epoch = num_it//len(self.dls.train) + 1
221     cb=LRFinder(start_lr=start_lr, end_lr=end_lr, num_it=num_it, stop_div=stop_div)
--> 222     with self.no_logging(): self.fit(n_epoch, cbs=cb)
223     if show_plot: self.recorder.plot_lr_find()
224     if suggestions:

~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
209             self.opt.set_hypers(lr=self.lr if lr is None else lr)
210             self.n_epoch = n_epoch
--> 211             self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
212 
213     def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None

~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
158 
159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
161         except ex: self(f'after_cancel_{event_type}')
162         self(f'after_{event_type}');  final()

~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _do_fit(self)
200         for epoch in range(self.n_epoch):
201             self.epoch=epoch
--> 202             self._with_events(self._do_epoch, 'epoch', CancelEpochException)
203 
204     def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):

 ~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
158 
159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
161         except ex: self(f'after_cancel_{event_type}')
162         self(f'after_{event_type}');  final()

~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _do_epoch(self)
194 
195     def _do_epoch(self):
--> 196         self._do_epoch_train()
197         self._do_epoch_validate()
198 

~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _do_epoch_train(self)
186     def _do_epoch_train(self):
187         self.dl = self.dls.train
--> 188         self._with_events(self.all_batches, 'train', CancelTrainException)
189 
190     def _do_epoch_validate(self, ds_idx=1, dl=None):

~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
158 
159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
161         except ex: self(f'after_cancel_{event_type}')
162         self(f'after_{event_type}');  final()

~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in all_batches(self)
164     def all_batches(self):
165         self.n_iter = len(self.dl)
--> 166         for o in enumerate(self.dl): self.one_batch(*o)
167 
168     def _do_one_batch(self):

~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in one_batch(self, i, b)
182         self.iter = i
183         self._split(b)
--> 184         self._with_events(self._do_one_batch, 'batch', CancelBatchException)
185 
186     def _do_epoch_train(self):

~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
158 
159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
161         except ex: self(f'after_cancel_{event_type}')
162         self(f'after_{event_type}');  final()

~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _do_one_batch(self)
170         self('after_pred')
171         if len(self.yb):
--> 172             self.loss_grad = self.loss_func(self.pred, *self.yb)
173             self.loss = self.loss_grad.clone()
174         self('after_loss')

~/anaconda3/lib/python3.8/site-packages/fastai/losses.py in __call__(self, inp, targ, **kwargs)
 33         if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
 34         if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
---> 35         return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
 36 
 37 # Cell

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725             result = self._slow_forward(*input, **kwargs)
726         else:
--> 727             result = self.forward(*input, **kwargs)
728         for hook in itertools.chain(
729                 _global_forward_hooks.values(),

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
959 
960     def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 961         return F.cross_entropy(input, target, weight=self.weight,
962                                ignore_index=self.ignore_index, reduction=self.reduction)
963 

~/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2460         tens_ops = (input, target)
2461         if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
-> 2462             return handle_torch_function(
2463                 cross_entropy, tens_ops, input, target, weight=weight,
2464                 size_average=size_average, ignore_index=ignore_index, reduce=reduce,

~/anaconda3/lib/python3.8/site-packages/torch/overrides.py in handle_torch_function(public_api, relevant_args, *args, **kwargs)
1058         # Use `public_api` instead of `implementation` so __torch_function__
1059         # implementations can do equality/identity comparisons.
-> 1060         result = overloaded_arg.__torch_function__(public_api, types, args, kwargs)
1061 
1062         if result is not NotImplemented:

~/anaconda3/lib/python3.8/site-packages/fastai/torch_core.py in __torch_function__(self, func, types, args, kwargs)
323         convert=False
324         if _torch_handled(args, self._opt, func): convert,types = type(self),(torch.Tensor,)
--> 325         res = super().__torch_function__(func, types, args=args, kwargs=kwargs)
326         if convert: res = convert(res)
327         if isinstance(res, TensorBase): res.set_meta(self, as_copy=True)

~/anaconda3/lib/python3.8/site-packages/torch/tensor.py in __torch_function__(cls, func, types, args, kwargs)
993 
994         with _C.DisableTorchFunction():
--> 995             ret = func(*args, **kwargs)
996             return _convert(ret, cls)
997 

~/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2466     if size_average is not None or reduce is not None:
2467         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2468     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
2469 
2470 

~/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
2259 
2260     if input.size(0) != target.size(0):
-> 2261         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
2262                          .format(input.size(0), target.size(0)))
2263     if dim == 2:

ValueError: Expected input batch_size (13068) to match target batch_size (4356).

The U-Net class:

class ThreeInputsUnet(Module):
    """U-Net-style segmentation model that receives three input images per sample.

    In this post it is built as ``ThreeInputsUnet(9, 2)``; presumably the
    three images are 3-channel each, giving 9 input channels in total, and
    there are 2 mask classes -- TODO confirm against the full model code.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        # The original ``super(ThreeInputs).__init__()`` referenced a name that
        # is not defined (NameError). Even with the right class name, the
        # one-argument form of ``super`` returns an unbound proxy, so the
        # parent initializer would never run on ``self``. Use zero-arg super().
        super().__init__()
        self.model = models.resnet18()

        self.n_channels = n_channels  # presumably 3 images x 3 channels = 9
        self.n_classes = n_classes    # number of output (mask) classes
        self.bilinear = bilinear

        # ...
        # other stuff

    def forward(self, x1, x2, x3):
        # forward logic (elided in the post)
        # NOTE(review): lr_find fails with input batch 13068 vs target batch
        # 4356 -- exactly 3x. That ratio is consistent with x1, x2, x3 being
        # concatenated along the batch axis (dim=0). Given n_channels=9 they
        # presumably belong on the channel axis instead, e.g.
        # ``torch.cat([x1, x2, x3], dim=1)`` -- confirm against the real body.
        ...

    def fit(self, data_loader, epochs, validation_data=None):
        # fitting logic (elided in the post)
        ...

    def _train_iteration(self, data_loader):
        # iteration logic (elided in the post)
        ...

    def predict(self, X):
        # prediction logic (elided in the post)
        ...

Instantiation:
# Build the model instance: 9 input channels in total, 2 output classes.
tiunet = ThreeInputsUnet(n_channels=9, n_classes=2)
Creating the Learner:
# Pass the model *instance* built above, not the ThreeInputsUnet class itself:
# Learner calls ``model(*xb)``, so passing the class would construct a new
# module from the input batches instead of producing predictions.
learn = Learner(dls_train, model=tiunet, opt_func=ranger, metrics=acc_camvid)

DataBlock:

# Three image inputs followed by one segmentation-mask target (n_inp=3).
# NOTE(review): no ``get_x`` is supplied, so presumably every one of the three
# ImageBlock inputs receives the same item from get_image_files (the same image
# three times). If the inputs should be three different files, pass
# ``get_x=[getter_1, getter_2, getter_3]`` -- confirm against fastai's DataBlock.
field_train = DataBlock(
    blocks=(ImageBlock,) * 3 + (MaskBlock(codes),),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.3),  # 30% of items go to validation
    get_y=get_train_msk,
    n_inp=3,
    batch_tfms=list(aug_transforms(size=quarter, do_flip=False)),
)

DataLoaders:
# Build train/valid DataLoaders from the images under train_image_path.
dls_train = field_train.dataloaders(
    train_image_path,
    bs=1,          # one sample per batch
    num_workers=4,
)
I would be grateful for any help and am happy to add more information if needed.