Error trying unet architecture for classification

I am trying to create an image classifier based on the lessons of part 1. I was successful using a CNN, but I would like to see the results using a U-Net, since the images are medical images.

My code:

from fastai import *
from import *
import os

image_path = Path(os.path.join(os.getcwd(), 'data', 'images'))

data = ImageList.from_folder(image_path)

learn = unet_learner(data, models.resnet34, metrics=accuracy, wd=wd, self_attention=True, pretrained=False, bottle=True)


Here the Error:

ValueError Traceback (most recent call last)
in ()
----> 1 learn.lr_find()
2 learn.recorder.plot()

~\Anaconda3\envs\fastai\lib\site-packages\fastai\ in lr_find(learn, start_lr, end_lr, num_it, stop_div, wd)
30 cb = LRFinder(learn, start_lr, end_lr, num_it, stop_div)
31 epochs = int(np.ceil(num_it/len(
—> 32, start_lr, callbacks=[cb], wd=wd)
34 def to_fp16(learn:Learner, loss_scale:float=None, max_noskip:int=1000, dynamic:bool=False, clip:float=None,

~\Anaconda3\envs\fastai\lib\site-packages\fastai\ in fit(self, epochs, lr, wd, callbacks)
188 if defaults.extra_callbacks is not None: callbacks += defaults.extra_callbacks
189 fit(epochs, self.model, self.loss_func, opt=self.opt,, metrics=self.metrics,
–> 190 callbacks=self.callbacks+callbacks)
192 def create_opt(self, lr:Floats, wd:Floats=0.)->None:

~\Anaconda3\envs\fastai\lib\site-packages\fastai\ in fit(epochs, model, loss_func, opt, data, callbacks, metrics)
91 for xb,yb in progress_bar(data.train_dl, parent=pbar):
92 xb, yb = cb_handler.on_batch_begin(xb, yb)
—> 93 loss = loss_batch(model, xb, yb, loss_func, opt, cb_handler)
94 if cb_handler.on_batch_end(loss): break

~\Anaconda3\envs\fastai\lib\site-packages\fastai\ in loss_batch(model, xb, yb, loss_func, opt, cb_handler)
27 if not loss_func: return to_detach(out), yb[0].detach()
—> 28 loss = loss_func(out, *yb)
30 if opt is not None:

~\Anaconda3\envs\fastai\lib\site-packages\fastai\ in call(self, input, target, **kwargs)
242 if self.floatify: target = target.float()
243 input = input.view(-1,input.shape[-1]) if self.is_2d else input.view(-1)
–> 244 return, target.view(-1), **kwargs)
246 def CrossEntropyFlat(*args, axis:int=-1, **kwargs):

~\Anaconda3\envs\fastai\lib\site-packages\torch\nn\modules\ in call(self, *input, **kwargs)
487 result = self._slow_forward(*input, **kwargs)
488 else:
–> 489 result = self.forward(*input, **kwargs)
490 for hook in self._forward_hooks.values():
491 hook_result = hook(self, input, result)

~\Anaconda3\envs\fastai\lib\site-packages\torch\nn\modules\ in forward(self, input, target)
902 def forward(self, input, target):
903 return F.cross_entropy(input, target, weight=self.weight,
–> 904 ignore_index=self.ignore_index, reduction=self.reduction)

~\Anaconda3\envs\fastai\lib\site-packages\torch\nn\ in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
1968 if size_average is not None or reduce is not None:
1969 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 1970 return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)

~\Anaconda3\envs\fastai\lib\site-packages\torch\nn\ in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
1786 if input.size(0) != target.size(0):
1787                 raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 1788 .format(input.size(0), target.size(0)))
1789 if dim == 2:
1790 ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

ValueError: Expected input batch_size (3072) to match target batch_size (4).

If someone could shed some light on how to solve the error it would be of great help.


From my understanding, Unets are used to solve problems when you want your outputs to be of similar size of your inputs such as image segmentation or generative modeling.

For classifying images, since you would only be using the encoder / downsampling path of the Unet to retrieve a number corresponding to which class does the image belong to, such architecture wouldn’t be suited for the task. Therefore, a CNN would suffice.

Thanks for the enlightening explanation!