I had to create a custom PyTorch model because I want to feed it three input images through a DataLoader. I created a basic U-Net and passed an instance of it to a fastai `Learner`. However, when I try to find a suitable learning rate with `lr_find()`, I get the following error:
ValueError Traceback (most recent call last)
----> 1 learn.lr_find()
~/anaconda3/lib/python3.8/site-packages/fastai/callback/schedule.py in lr_find(self, start_lr, end_lr, num_it, stop_div, show_plot, suggestions)
220 n_epoch = num_it//len(self.dls.train) + 1
221 cb=LRFinder(start_lr=start_lr, end_lr=end_lr, num_it=num_it, stop_div=stop_div)
--> 222 with self.no_logging(): self.fit(n_epoch, cbs=cb)
223 if show_plot: self.recorder.plot_lr_find()
224 if suggestions:
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
209 self.opt.set_hypers(lr=self.lr if lr is None else lr)
210 self.n_epoch = n_epoch
--> 211 self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
212
213 def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
158
159 def _with_events(self, f, event_type, ex, final=noop):
--> 160 try: self(f'before_{event_type}'); f()
161 except ex: self(f'after_cancel_{event_type}')
162 self(f'after_{event_type}'); final()
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _do_fit(self)
200 for epoch in range(self.n_epoch):
201 self.epoch=epoch
--> 202 self._with_events(self._do_epoch, 'epoch', CancelEpochException)
203
204 def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
158
159 def _with_events(self, f, event_type, ex, final=noop):
--> 160 try: self(f'before_{event_type}'); f()
161 except ex: self(f'after_cancel_{event_type}')
162 self(f'after_{event_type}'); final()
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _do_epoch(self)
194
195 def _do_epoch(self):
--> 196 self._do_epoch_train()
197 self._do_epoch_validate()
198
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _do_epoch_train(self)
186 def _do_epoch_train(self):
187 self.dl = self.dls.train
--> 188 self._with_events(self.all_batches, 'train', CancelTrainException)
189
190 def _do_epoch_validate(self, ds_idx=1, dl=None):
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
158
159 def _with_events(self, f, event_type, ex, final=noop):
--> 160 try: self(f'before_{event_type}'); f()
161 except ex: self(f'after_cancel_{event_type}')
162 self(f'after_{event_type}'); final()
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in all_batches(self)
164 def all_batches(self):
165 self.n_iter = len(self.dl)
--> 166 for o in enumerate(self.dl): self.one_batch(*o)
167
168 def _do_one_batch(self):
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in one_batch(self, i, b)
182 self.iter = i
183 self._split(b)
--> 184 self._with_events(self._do_one_batch, 'batch', CancelBatchException)
185
186 def _do_epoch_train(self):
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
158
159 def _with_events(self, f, event_type, ex, final=noop):
--> 160 try: self(f'before_{event_type}'); f()
161 except ex: self(f'after_cancel_{event_type}')
162 self(f'after_{event_type}'); final()
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _do_one_batch(self)
170 self('after_pred')
171 if len(self.yb):
--> 172 self.loss_grad = self.loss_func(self.pred, *self.yb)
173 self.loss = self.loss_grad.clone()
174 self('after_loss')
~/anaconda3/lib/python3.8/site-packages/fastai/losses.py in __call__(self, inp, targ, **kwargs)
33 if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
34 if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
---> 35 return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
36
37 # Cell
~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
959
960 def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 961 return F.cross_entropy(input, target, weight=self.weight,
962 ignore_index=self.ignore_index, reduction=self.reduction)
963
~/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2460 tens_ops = (input, target)
2461 if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
-> 2462 return handle_torch_function(
2463 cross_entropy, tens_ops, input, target, weight=weight,
2464 size_average=size_average, ignore_index=ignore_index, reduce=reduce,
~/anaconda3/lib/python3.8/site-packages/torch/overrides.py in handle_torch_function(public_api, relevant_args, *args, **kwargs)
1058 # Use `public_api` instead of `implementation` so __torch_function__
1059 # implementations can do equality/identity comparisons.
-> 1060 result = overloaded_arg.__torch_function__(public_api, types, args, kwargs)
1061
1062 if result is not NotImplemented:
~/anaconda3/lib/python3.8/site-packages/fastai/torch_core.py in __torch_function__(self, func, types, args, kwargs)
323 convert=False
324 if _torch_handled(args, self._opt, func): convert,types = type(self),(torch.Tensor,)
--> 325 res = super().__torch_function__(func, types, args=args, kwargs=kwargs)
326 if convert: res = convert(res)
327 if isinstance(res, TensorBase): res.set_meta(self, as_copy=True)
~/anaconda3/lib/python3.8/site-packages/torch/tensor.py in __torch_function__(cls, func, types, args, kwargs)
993
994 with _C.DisableTorchFunction():
--> 995 ret = func(*args, **kwargs)
996 return _convert(ret, cls)
997
~/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2466 if size_average is not None or reduce is not None:
2467 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2468 return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
2469
2470
~/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
2259
2260 if input.size(0) != target.size(0):
-> 2261 raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
2262 .format(input.size(0), target.size(0)))
2263 if dim == 2:
ValueError: Expected input batch_size (13068) to match target batch_size (4356).
The U-Net class:
class ThreeInputsUnet(Module):
    """U-Net-style model that receives three input images per sample.

    Meant to be handed (as an instance) to a fastai ``Learner`` whose
    DataBlock declares ``n_inp=3``, so that ``forward`` is called with three
    input tensors per batch.

    Args:
        n_channels: total number of input channels the network expects
            (``9`` at the instantiation site -- presumably three 3-channel
            images; TODO confirm).
        n_classes: number of segmentation classes predicted per pixel.
        bilinear: flag stored on the instance; how it is used lives in the
            elided part of the code.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        # Fixed: the original ``super(ThreeInputs).__init__()`` named an
        # undefined class (NameError) and used the unbound one-argument form
        # of ``super``. fastai's ``Module`` does not strictly need this call,
        # but the zero-argument form is the correct, harmless spelling.
        super().__init__()
        # NOTE(review): torchvision's resnet18 expects 3-channel input and ends
        # in a 1000-way classifier head; confirm how it is adapted in the
        # elided code given n_channels=9 and a per-pixel segmentation target.
        self.model = models.resnet18()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        #...
        #other stuff

    def forward(self, x1, x2, x3):
        """Run the network on the three input images of a batch.

        The output must line up with the mask target for fastai's flattened
        cross-entropy. NOTE(review): the reported error has 3x as many
        predictions as targets (13068 = 3 * 4356). Check that x1, x2, x3 are
        fused (e.g. concatenated along the channel axis, matching
        ``n_channels=9``) rather than stacked along the batch axis.
        """
        #forward logic
        ...

    def fit(self, data_loader, epochs, validation_data=None):
        """Custom training loop (elided).

        Not used by fastai's ``Learner``, which runs its own ``Learner.fit``.
        """
        #fitting logic
        ...

    def _train_iteration(self, data_loader):
        """One training iteration over ``data_loader`` (elided)."""
        #iteration logic
        ...

    def predict(self, X):
        """Predict on ``X`` (elided)."""
        #prediction logic
        ...
Instantiation:
# Build the model: 9 input channels in total (presumably three 3-channel
# images -- TODO confirm) and 2 output classes.
tiunet = ThreeInputsUnet(n_channels=9, n_classes=2)
Creating the Learner:
# Pass the model *instance* built above, not the class. Learner invokes its
# model on every batch. Calling the class ThreeInputsUnet would instead build a
# brand-new module, with the batch tensors bound to n_channels / n_classes /
# bilinear. No loss_func is given, so fastai's default taken from the
# DataLoaders is used.
learn = Learner(dls_train, model=tiunet, opt_func=ranger, metrics=acc_camvid)
DataBlock:
# NOTE(review): no ``get_x``/``getters`` is supplied, so each of the three
# ImageBlocks is fed the same item returned by ``get_image_files``. The three
# "inputs" are therefore presumably the same image three times. If that is not
# intended, pass ``get_x=[fn1, fn2, fn3]`` (one function per input image).
field_train = DataBlock(
    blocks=(ImageBlock, ImageBlock, ImageBlock, MaskBlock(codes)),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.3),
    get_y=get_train_msk,
    n_inp=3,  # first three blocks are model inputs; the MaskBlock is the target
    batch_tfms=aug_transforms(size=quarter, do_flip=False),
)
DataLoaders:
# Build the train/valid DataLoaders from the DataBlock. With bs=1, each batch
# holds a single sample (three input images plus one mask).
dls_train = field_train.dataloaders(source=train_image_path, bs=1, num_workers=4)
I would be grateful for any help and am happy to add more information if needed.