Dynamic U-Net source code question

Can anyone tell me the purpose of eval() in the following line?
middle_conv = nn.Sequential(conv_layer(ni, ni*2, **kwargs), conv_layer(ni*2, ni, **kwargs)).eval()

It looks like this is because the constructor passes a dummy input through the network to determine the output sizes of the encoder and the bridge (the middle layer), so that it can create appropriately sized decoder layers. Calling eval() puts the layers into evaluation mode, which disables behaviour that should only happen during training, such as dropout. In particular, I think the point is to stop the batchnorm layers inside the two child conv_layers from collecting statistics (running mean and variance) on the dummy input.
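A minimal PyTorch sketch of that behaviour. It assumes a hypothetical conv_bn_relu helper as a stand-in for fastai's conv_layer (conv, batchnorm, ReLU) and an assumed 8-channel, 4x4 dummy encoder output; it is not the fastai source. The dummy input goes through the middle block in eval mode, the output shape is read off, and the first batchnorm's running mean is checked to be unchanged:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for fastai's conv_layer: conv -> batchnorm -> ReLU.
def conv_bn_relu(ni, nf):
    return nn.Sequential(
        nn.Conv2d(ni, nf, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(nf),
        nn.ReLU(inplace=True),
    )

ni = 8  # assumed number of channels coming out of the encoder
middle_conv = nn.Sequential(conv_bn_relu(ni, ni * 2), conv_bn_relu(ni * 2, ni)).eval()

bn = middle_conv[0][1]  # the first BatchNorm2d layer
stats_before = bn.running_mean.clone()

# Dummy input standing in for the encoder output (shape is an assumption).
x = torch.randn(1, ni, 4, 4)
with torch.no_grad():
    out = middle_conv(x)

print(out.shape)  # torch.Size([1, 8, 4, 4]), the size a decoder would be built from
print(torch.equal(bn.running_mean, stats_before))  # True: eval mode collects no stats
```

In training mode the same forward pass would update running_mean and running_var using the dummy batch, which is exactly what the eval() call avoids.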

Evaluation mode is undone later: the training loop calls model.train() before real inputs are passed through the model, which switches these layers back to training mode.
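A second sketch, using an assumed toy model rather than the DynamicUnet itself, showing that eval() and train() set the training flag on every submodule, and that batchnorm starts updating its statistics again once a real batch passes through in training mode:

```python
import torch
import torch.nn as nn

# Toy model (an assumption, not the DynamicUnet) with the two kinds of layers
# whose behaviour depends on the training flag.
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1), nn.BatchNorm2d(4), nn.Dropout(0.5))

model.eval()  # sets .training = False on the module and every submodule
print(all(not m.training for m in model.modules()))  # True

# Roughly what a training loop does before feeding real data.
model.train()  # sets .training = True again, recursively
print(all(m.training for m in model.modules()))  # True

bn = model[1]
batch = torch.randn(2, 3, 5, 5)  # a "real" batch of 2 images (shape assumed)
model(batch)
print(bn.num_batches_tracked.item())  # 1: batchnorm is collecting statistics again
```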


Thank you.