Unet edge case and a suggested fix

Hi,

I recently finished Part 1 and wanted to work on a segmentation problem, so I chose the TGS Salt Identification Challenge from Kaggle. After building the DataBunch and the Learner, I reached a point where the code broke whenever I called lr_find() or fit(). I saw this as a good opportunity to understand the inner workings of DynamicUnet, and digging into the details gave me a much better understanding of the library. So I forgot all about the challenge and focused on understanding and solving the problem :sweat_smile:

I put together a notebook with test cases that illustrate the issue, and I attempted a fix.

Please have a look at the notebook and let me know what you think of the suggested fix. I would like to contribute back to this great community, and I hope this will become my first PR!

Sorry for the long intro, but I felt really excited about this experience.

The issue:
The input images in the challenge are 101x101. I wanted to use them as-is, without any padding or resizing, and noticed that whenever the input size is an odd number, the following error occurs:

RuntimeError                              Traceback (most recent call last)
<ipython-input-7-ab8ad06f4220> in <module>()
----> 1 test_dynamic_unet_shape(model(),image(size=101))

4 frames
<ipython-input-5-99fbcbb6d64e> in test_dynamic_unet_shape(model, image)
      1 def test_dynamic_unet_shape(model, image):
----> 2     pred = model(image)
      3     print(pred.shape[-2:])

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/fastai/layers.py in forward(self, x)
    154         for l in self.layers:
    155             res.orig = x
--> 156             nres = l(res)
    157             # We have to remove res.orig to avoid hanging refs and therefore memory leaks
    158             res.orig = None

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/fastai/layers.py in forward(self, x)
    171         self.dense=dense
    172 
--> 173     def forward(self, x): return torch.cat([x,x.orig], dim=1) if self.dense else (x+x.orig)
    174 
    175 def res_block(nf, dense:bool=False, norm_type:Optional[NormType]=NormType.Batch, bottle:bool=False, **conv_kwargs):

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 102 and 101 in dimension 2 at /pytorch/aten/src/TH/generic/THTensor.cpp:711
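To see where the 102 comes from, here is a minimal pure-Python sketch of the size arithmetic. It is a simplified model, not fastai code: it assumes five stride-2 stages that each follow the Conv2d/MaxPool2d output formula with kernel 3, stride 2, padding 1 (a ResNet stem's 7x7, stride 2, padding 3 conv gives the same size), that every UnetBlock snaps its upsampled output to its skip connection (the interpolation fallback in UnetBlock), and that a final x2 upsample happens before merging with the input. The function names are hypothetical.

```python
def conv_out(n, kernel=3, stride=2, padding=1):
    """Spatial output size of a Conv2d/MaxPool2d along one axis."""
    return (n + 2 * padding - kernel) // stride + 1


def encoder_sizes(n, stages=5):
    """Sizes after each stride-2 encoder stage (each rounds up when halving)."""
    sizes = []
    for _ in range(stages):
        n = conv_out(n)
        sizes.append(n)
    return sizes


def final_output_size(n, stages=5):
    # Inside the decoder every block is resized to its skip connection,
    # so the decoder ends at the size of the first encoder stage.
    first_stage = encoder_sizes(n, stages)[0]
    # The last x2 upsample has no skip connection to snap to.
    return 2 * first_stage


for size in (101, 100, 128):
    print(size, encoder_sizes(size), final_output_size(size))
```

Under these assumptions the first stage computes ceil(n/2), so any even size survives the final doubling, while an odd size always comes back one pixel larger, which matches the "Got 102 and 101" in the traceback.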

I solved this by dynamically passing the actual input size when creating the model, and using interpolation to fix the resulting 1-pixel difference. Details are in the notebook.
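The notebook isn't reproduced here, so the following is only a toy sketch of the idea in plain PyTorch, not the actual fix and not fastai's DynamicUnet. TinyUnet and ResizeTo are hypothetical names. The model receives img_size at construction, runs a dummy tensor through its layers once to find the output size, and adds a nearest-neighbour resize only when that size differs from img_size, so that the final concatenation with the input lines up.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResizeTo(nn.Module):
    """Hypothetical helper: resize a feature map to a fixed spatial size."""

    def __init__(self, size):
        super().__init__()
        self.size = tuple(size)

    def forward(self, x):
        return F.interpolate(x, size=self.size, mode='nearest')


class TinyUnet(nn.Module):
    """Toy one-level U-Net (hypothetical; not fastai's DynamicUnet).

    img_size is passed at construction so the model can check once,
    with a dummy forward pass, whether the upsampled output matches it.
    """

    def __init__(self, in_ch=3, img_size=(101, 101)):
        super().__init__()
        # Halves H and W, rounding up for odd sizes
        self.down = nn.Conv2d(in_ch, 8, kernel_size=3, stride=2, padding=1)
        # Doubles H and W
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        with torch.no_grad():
            dummy = torch.zeros(1, in_ch, *img_size)
            out_size = tuple(self.up(self.down(dummy)).shape[-2:])
        # Add the resize step only when the round trip misses img_size
        self.fix = ResizeTo(img_size) if out_size != tuple(img_size) else nn.Identity()
        self.head = nn.Conv2d(8 + in_ch, 1, kernel_size=1)

    def forward(self, x):
        y = self.fix(self.up(self.down(x)))
        # Dense merge with the raw input, similar in spirit to MergeLayer(dense=True)
        return self.head(torch.cat([y, x], dim=1))


for size in [(101, 101), (128, 128)]:
    model = TinyUnet(img_size=size)
    out = model(torch.randn(2, 3, *size))
    print(size, type(model.fix).__name__, tuple(out.shape))
```

Because the resize target is fixed at construction, a model built this way expects inputs of exactly img_size; sizes that already round-trip (like 128) get an identity layer and pay no interpolation cost.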

Thanks,

Yes, DynamicUnet only works with input sizes that are a multiple of 32. I'm not too fond of adding interpolations, since we went out of our way to avoid them, but I don't see another way around it. If we were to apply this fix, I think a warning at init when the sizes don't match (telling the user to use a multiple of 32 for better performance) would be great.
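A minimal sketch of such an init-time check, assuming img_size is a (height, width) tuple; check_img_size is a hypothetical helper name, not part of fastai, and the message wording is only a suggestion.

```python
import warnings


def check_img_size(img_size, multiple=32):
    """Hypothetical helper a U-Net could call from __init__.

    Warns when any spatial dimension is not a multiple of `multiple`.
    Returns True when no warning was needed.
    """
    bad = [s for s in img_size if s % multiple != 0]
    if bad:
        warnings.warn(
            f"img_size={tuple(img_size)} is not a multiple of {multiple}; "
            "the output will be resized with interpolation. "
            f"Use a multiple of {multiple} for better performance.",
            UserWarning,
        )
        return False
    return True


check_img_size((128, 128))  # silent
check_img_size((101, 101))  # emits a UserWarning
```

Using the standard warnings module lets users silence or escalate the message with the usual filters instead of a custom flag.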

I see, I wasn't aware of that. I had seen interpolations being used in UnetBlock, but with input sizes that are a multiple of 32, these lines are never triggered:

        # ssh = s.shape[-2:] is the spatial size of the stored skip connection s
        if ssh != up_out.shape[-2:]:
            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')

Seeing the lines above, I assumed DynamicUnet supports any input size, and that there is no harm in using interpolation to fix the 1-pixel difference between the input and the output of the decoder in my case.
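A quick check of what nearest-neighbour interpolation does to a 1-pixel mismatch, assuming PyTorch's F.interpolate with mode='nearest'; the tensor sizes are chosen to match the traceback (102 vs 101), and the values are just a running index so rows can be compared.

```python
import torch
import torch.nn.functional as F

# Decoder output one pixel too large, as in the traceback (102 vs 101)
up_out = torch.arange(102 * 102, dtype=torch.float32).reshape(1, 1, 102, 102)
inp = torch.zeros(1, 3, 101, 101)

# Snap the decoder output to the input's spatial size
fixed = F.interpolate(up_out, size=inp.shape[-2:], mode='nearest')

# The dense merge (concatenation along channels) now works
merged = torch.cat([fixed, inp], dim=1)
print(tuple(fixed.shape), tuple(merged.shape))

# For a 102 -> 101 shrink, 'nearest' keeps source rows/cols 0..100,
# which is the same as dropping the last row and column
print(torch.equal(fixed, up_out[..., :101, :101]))
```

So in this 1-pixel case the resize neither blends nor invents values; it simply discards the last row and column of the decoder output. Larger mismatches would instead drop (or, when enlarging, repeat) rows and columns spread across the whole map.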

Oh, I had forgotten we already had that. Then by all means, suggest a PR with your fix.
