RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x512 and 1280x2)

Hi everyone!
I'm working on the RSNA Breast Cancer Detection competition on Kaggle and I'm getting an error while processing the images, but I don't understand how to reshape them.
Which number should I reshape from?

My code:

```python
from fastai.vision.all import *
from timm.models.layers import SelectAdaptivePool2d

# TRAIN_IMAGE_DIR, get_items, label_func, splitting_func, pfbeta_torch and
# pfbeta_torch_thresh are defined elsewhere in the notebook.

def get_dataloaders():
    train_image_path = TRAIN_IMAGE_DIR

    dblock = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_items,
        get_y=label_func,
        splitter=splitting_func,
        batch_tfms=[Flip()],
    )
    dsets = dblock.datasets(train_image_path)
    return dblock.dataloaders(train_image_path, batch_size=32)

def get_learner(arch=resnet18):
    learner = vision_learner(
        get_dataloaders(),
        arch,
        custom_head=nn.Sequential(
            SelectAdaptivePool2d(pool_type='avg', flatten=Flatten()),
            nn.Linear(1280, 2),
        ),
        metrics=[
            error_rate,
            AccumMetric(pfbeta_torch, activation=ActivationType.Softmax, flatten=False),
            AccumMetric(pfbeta_torch_thresh, activation=ActivationType.Softmax, flatten=False),
        ],
        loss_func=CrossEntropyLossFlat(weight=torch.tensor([1, 50]).float()),
        pretrained=True,
        normalize=False,
    ).to_fp16()
    return learner
```

Error:

```
    138         for module in self:
--> 139             input = module(input)
    140         return input

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/linear.py in forward(self, input)
    112 
    113     def forward(self, input: Tensor) -> Tensor:
--> 114         return F.linear(input, self.weight, self.bias)
    115 
    116     def extra_repr(self) -> str:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x512 and 1280x2)
```

Hello,

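You don't need to reshape the images; the mismatch is in the head. ResNet-18's body outputs 512 channels, so after average pooling and flattening each image becomes a 512-dimensional vector. That is the 32x512 matrix in the message (a batch of 32 images). The linear classifier you defined in the head, however, expects 1280 input features (its transposed weight is the 1280x2 matrix). Changing 1280 to 512 in nn.Linear(1280, 2) should fix the problem.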
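To see where the two numbers in the message come from, here is a minimal sketch in plain PyTorch. It assumes 224x224 inputs (so ResNet-18's last feature map is 512x7x7) and uses a random tensor plus a hypothetical one-layer `toy_body` in place of the real backbone, so nothing has to be downloaded. `nn.AdaptiveAvgPool2d(1)` followed by `nn.Flatten()` stands in for timm's `SelectAdaptivePool2d(pool_type='avg', ...)`, and `infer_num_features` is an illustrative helper, not a fastai or timm function.

```python
import torch
from torch import nn

torch.manual_seed(0)

# 1) Reproduce the shapes from the traceback.
# A random tensor stands in for ResNet-18's final feature map (512 channels,
# 7x7 spatially for 224x224 inputs), so no pretrained weights are needed.
features = torch.randn(32, 512, 7, 7)
pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())
pooled = pool(features)  # shape (32, 512): one 512-dim vector per image

mismatch_error = None
try:
    nn.Linear(1280, 2)(pooled)  # head from the question: expects 1280 features
except RuntimeError as err:
    mismatch_error = err  # "mat1 and mat2 shapes cannot be multiplied ..."

logits = nn.Linear(512, 2)(pooled)  # the fix: in_features = 512
print(pooled.shape, logits.shape)  # torch.Size([32, 512]) torch.Size([32, 2])


# 2) Avoid hard-coding the width: measure it with a dummy forward pass.
def infer_num_features(body: nn.Module, input_size=(3, 224, 224)) -> int:
    """Return the number of channels `body` outputs for one dummy image."""
    was_training = body.training
    body.eval()
    with torch.no_grad():
        out = body(torch.zeros(1, *input_size))
    body.train(was_training)  # restore the original train/eval mode
    return out.shape[1]


# Hypothetical toy backbone: one strided conv that, like ResNet-18's body,
# maps a 3x224x224 image to a 512x7x7 feature map.
toy_body = nn.Sequential(nn.Conv2d(3, 512, kernel_size=32, stride=32))
n_features = infer_num_features(toy_body)
head = nn.Sequential(
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(n_features, 2)
)
model = nn.Sequential(toy_body, head)
print(n_features, model(torch.zeros(4, 3, 224, 224)).shape)  # 512 torch.Size([4, 2])
```

Measuring the width with a dummy forward pass keeps the custom head valid if you later switch `arch` to a backbone with a different feature width; fastai does something similar (`num_features_model`) when it builds its default head.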


Thanks a lot, BobMcDear!
