Currently, in `_load_pretrained_weights`, the weights for every input channel beyond the first three are initialized to zero (see the very last line):
def _load_pretrained_weights(new_layer, previous_layer):
"Load pretrained weights based on number of input channels"
n_in = getattr(new_layer, 'in_channels')
if n_in==1:
# we take the sum
new_layer.weight.data = previous_layer.weight.data.sum(dim=1, keepdim=True)
elif n_in==2:
# we take first 2 channels + 50%
new_layer.weight.data = previous_layer.weight.data[:,:2] * 1.5
else:
# keep 3 channels weights and set others to null
new_layer.weight.data[:,:3] = previous_layer.weight.data
new_layer.weight.data[:,3:].zero_()
A better alternative might be to initialize these weights with `nn.init.kaiming_normal_` instead:
def _load_pretrained_weights(new_layer, previous_layer):
"Load pretrained weights based on number of input channels"
n_in = getattr(new_layer, 'in_channels')
if n_in==1:
# we take the sum
new_layer.weight.data = previous_layer.weight.data.sum(dim=1, keepdim=True)
elif n_in==2:
# we take first 2 channels + 50%
new_layer.weight.data = previous_layer.weight.data[:,:2] * 1.5
else:
# keep 3 channels weights and set others to null
new_layer.weight.data[:,:3] = previous_layer.weight.data
# new_layer.weight.data[:,3:].zero_()
nn.init.kaiming_normal_(new_layer.weight.data[:,3:])
I'm curious to hear others' opinions.