The backbone of the U-Net is a ResNet34. If you type `learn.model`, it will print the different layers for you.
In PyTorch the pretrained ResNets are all structured very similarly:
Stem (ConvLayer + BatchNorm + ReLU)
|
V
Block 1 (Multiple ConvLayers + BatchNorm + ReLU)
|
V
Block 2 (Multiple ConvLayers + BatchNorm + ReLU)
|
V
Block 3 (Multiple ConvLayers + BatchNorm + ReLU)
|
V
Block 4 (Multiple ConvLayers + BatchNorm + ReLU)
|
V
Head (Flatten input, Linear Layers)
If you type `model[0]` you'll get the stem, `model[1]` will give you the first block, and `model[-1]` will give you the last layer (the head).
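The indexing works because the cut model behaves like a sequential container. A minimal plain-Python stand-in (no PyTorch required; the layer names are just labels, not fastai's) shows the idea:

```python
class Sequential:
    """Barebones container supporting integer indexing, like nn.Sequential."""
    def __init__(self, *layers):
        self.layers = list(layers)

    def __getitem__(self, idx):
        return self.layers[idx]

# Stand-in for the ResNet layout described above.
resnet = Sequential("Stem", "Block 1", "Block 2", "Block 3", "Block 4", "Head")

print(resnet[0])   # the stem
print(resnet[1])   # the first block
print(resnet[-1])  # the head
```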
If you pass `resnet34` to `unet_learner`, the following will happen:
- The `resnet34` is constructed and, if `pretrained=True`, the weights are loaded as well.
- The `create_body` function of fastai cuts off the head of the model, leaving only the stem and the encoder blocks.
- Hooks are placed into the model after each major block (so after the stem, block 1, block 2, …), provided the size of the feature maps changed between blocks. They catch the output of each block so it can be used in the skip connections.
- A BatchNorm, a ReLU and a double convolutional layer are appended to the encoder.
- Multiple U-Net blocks are added to the model.
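The first three steps can be sketched in plain Python (no fastai/PyTorch; the helper names `cut_head` and `run_with_hooks` are made up for illustration and are not fastai's actual API):

```python
def cut_head(layers):
    """Mimic create_body: drop the head, keep stem and encoder blocks."""
    return layers[:-1]

def run_with_hooks(encoder, x):
    """Run x through the encoder, storing each output whose feature-map
    'size' differs from its input -- the condition the hooks check."""
    hooked = []
    for layer in encoder:
        y = layer(x)
        if y["size"] != x["size"]:   # feature-map size changed
            hooked.append(y)         # keep it for a skip connection
        x = y
    return x, hooked

# Toy encoder: each 'layer' halves the spatial size and tags its name.
def make_layer(name):
    return lambda inp: {"name": name, "size": inp["size"] // 2}

layers = [make_layer(n) for n in ["Stem", "B1", "B2", "B3", "B4", "Head"]]
encoder = cut_head(layers)
out, skips = run_with_hooks(encoder, {"name": "input", "size": 224})
print([s["name"] for s in skips])   # outputs captured for the skip connections
```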
The final model then looks something like this:
From `resnet34`
OutLayer
------------- ^
| | |
| Stem - - | - - - hooked output - > UNetBlock 4
| | | ^
| V | |
| Block 1 - | - - - hooked output - > UNetBlock 3
| | | ^
| V | |
| Block 2 - | - - - hooked output - > UNetBlock 2
| | | ^
| V | |
| Block 3 - | - - - hooked output - > UNetBlock 1
| | | ^
| V | |
| Block 4 - | - - - > DoubleConv - - - -
| |
-------------
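The decoder path in the diagram can be traced in a short sketch (plain Python, illustrative names): each U-Net block combines the upsampled features with the matching hooked encoder output, consumed in reverse order, so the deepest hook feeds the first U-Net block and the stem's hook feeds the last.

```python
def unet_decoder(hooked):
    """Trace the decoder path: DoubleConv on the deepest features, then
    one U-Net block per hooked output (in reverse order), then the out layer."""
    trace = ["DoubleConv"]
    for i, skip in enumerate(reversed(hooked), start=1):
        trace.append(f"UNetBlock {i} <- {skip}")
    trace.append("OutLayer")
    return trace

# Hooked outputs in encoder order, as in the diagram above.
trace = unet_decoder(["Stem", "Block 1", "Block 2", "Block 3"])
for step in trace:
    print(step)
```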