Error with DenseNet 121 finetuning

The DenseNet implementation in PyTorch has repeated modules. The layers are stored in an OrderedDict, and the repeated module key names cause the optimizer to throw this error:

ValueError: some parameters appear in more than one parameter group

I found this PyTorch issue helpful in understanding the problem.
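
As far as I understand it, the optimizer raises this error when the same parameter object ends up listed in more than one parameter group. Below is a minimal sketch for finding such parameters, assuming the groups are plain lists of parameter objects. The helper name `shared_params` is hypothetical, and plain Python objects stand in for tensors here.

```python
def shared_params(groups):
    """Return objects that appear in more than one group (compared by identity)."""
    first_seen = {}  # id(param) -> index of the first group that contains it
    shared = []
    for idx, group in enumerate(groups):
        for p in group:
            owner = first_seen.setdefault(id(p), idx)
            # Report a parameter once if a later group lists it again.
            if owner != idx and all(p is not s for s in shared):
                shared.append(p)
    return shared


# Plain objects stand in for parameters in this illustration.
a, b, c = object(), object(), object()
print(len(shared_params([[a, b], [b, c]])))  # b is listed in both groups
```

With a real model, each inner list would be the parameters you intend to pass as one group to the optimizer.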

I did the following to fix the error:

  1. Modified the torchvision.models.densenet code to append the block numbers to the layer names.
  2. Copied the weights from torchvision.models.densenet models to the models with updated layer names.
  3. Saved the state_dict for the updated model.
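
Step 2 can be sketched as a small helper that pairs values with new keys by position. This is only an illustration: the name `remap_by_position` and the key names in the toy example are made up, and it assumes both models register their layers in the same order.

```python
from collections import OrderedDict


def remap_by_position(src_state, dst_keys):
    """Pair each source value with the destination key at the same position."""
    dst_keys = list(dst_keys)
    if len(src_state) != len(dst_keys):
        raise ValueError(
            f"key count mismatch: {len(src_state)} source vs {len(dst_keys)} destination"
        )
    return OrderedDict(zip(dst_keys, src_state.values()))


# Toy example: strings stand in for weight tensors, key names are made up.
old = OrderedDict([("block.layer1.weight", "w0"), ("block.layer1.bias", "b0")])
new_keys = ["block1_layer1.weight", "block1_layer1.bias"]
print(remap_by_position(old, new_keys))
```

With real models, `src_state` would be the original model's `state_dict()` and `dst_keys` the keys of the modified model's `state_dict()`. The length check guards against silently dropping entries, which a bare `zip` would do.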

With these changes I am able to load and train all of the DenseNet models.

Here is the code I used to transfer the model weights:

import torch
from densenet import *
import torchvision
from collections import OrderedDict
from tqdm import tqdm

dn_models = {
    'densenet121': densenet121,
    'densenet169': densenet169,
    'densenet201': densenet201,
    'densenet161': densenet161,
}

torch_models = {
    'densenet121': torchvision.models.densenet121,
    'densenet169': torchvision.models.densenet169,
    'densenet201': torchvision.models.densenet201,
    'densenet161': torchvision.models.densenet161,
}

# where to save each fixed state_dict (hypothetical paths, adjust as needed)
model_locs = {name: f"{name}_fixed.pth" for name in dn_models}

for m in tqdm(dn_models.keys()):
    print(f"Fixing {m}")
    # densenet with layer names fixed
    dnetm = dn_models[m]()
    # original densenet with pretrained weights
    dnet = torch_models[m](pretrained=True).eval()

    # get the state dict of the original densenet
    dnet_sdict = dnet.state_dict()
    d_keys = dnet_sdict.keys()
    dm_keys = dnetm.state_dict().keys() # modified densenet keys

    # keys are paired by position, so both models must define their layers in the same order
    dnetm.load_state_dict(OrderedDict(zip(dm_keys, dnet_sdict.values())))
    dnetm.eval()
    dnetm_sdict = dnetm.state_dict()

    for k1, k2 in zip(d_keys, dm_keys):
        assert torch.equal(dnet_sdict[k1], dnetm_sdict[k2]), f"{k1}!={k2}"

    torch.save(dnetm.state_dict(), model_locs[m])
    print(f"Saving to {model_locs[m]}\n")

print("Done!")

Modified DenseNet code
Fixed DenseNet weights

@jeremy Does this look like a valid solution? Or is there a better way of fixing this issue?
