Can anyone tell me the purpose of eval() in the following line?
middle_conv = nn.Sequential(conv_layer(ni, ni*2, **kwargs), conv_layer(ni*2, ni, **kwargs)).eval()
It looks like this is because the model evaluates a dummy input to determine the output sizes of the encoder and the bridge (middle) layer, so that it can create appropriately sized decoding layers. Calling eval() puts layers into evaluation mode, which disables behaviour such as dropout that should only operate in training mode. In particular, I think it is there to stop the batchnorm layers in the two child conv_layer modules from collecting statistics on the dummy inputs.
The mode is reset later, when the training loop calls model.train() before real inputs are put through the model.
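To illustrate the batchnorm point, here is a minimal sketch in plain PyTorch. The conv_layer below is a hypothetical stand-in (conv, ReLU, batchnorm) and not the library's actual implementation, and the channel count and input sizes are arbitrary assumptions. It only checks that a dummy forward pass in eval mode leaves the batchnorm running statistics untouched, and that they start updating again after train() is called.

```python
import torch
import torch.nn as nn

def conv_layer(ni, nf):
    # Hypothetical stand-in for the library's conv_layer: conv -> ReLU -> batchnorm.
    return nn.Sequential(
        nn.Conv2d(ni, nf, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(nf),
    )

ni = 4  # arbitrary channel count for the sketch
middle_conv = nn.Sequential(conv_layer(ni, ni * 2), conv_layer(ni * 2, ni)).eval()

# Snapshot the running mean of the first batchnorm before any forward pass.
bn = middle_conv[0][2]
before = bn.running_mean.clone()

# Dummy forward pass in eval mode, used only to discover the output shape.
dummy = torch.randn(1, ni, 16, 16)
with torch.no_grad():
    out = middle_conv(dummy)
eval_unchanged = torch.equal(before, bn.running_mean)

# Back in training mode, a forward pass updates the running statistics.
middle_conv.train()
middle_conv(torch.randn(2, ni, 16, 16))
train_changed = not torch.equal(before, bn.running_mean)

print(out.shape, eval_unchanged, train_changed)
```

Note that torch.no_grad() alone would not protect the statistics: it disables gradient tracking, while the running mean and variance are updated whenever a batchnorm layer runs in training mode.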
Thank you.