How to use multiple GPUs

DataParallel adds a wrapper around your model, so every key in the saved state_dict gets a new `module.` prefix. There are some indications here advising you to save your model (learn.save('name')) and then strip the prefix when loading:

import torch
from collections import OrderedDict

# original file was saved from a model wrapped in DataParallel
state_dict = torch.load('myfile.pth')

# create a new OrderedDict whose keys do not contain `module.`
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove the `module.` prefix (len('module.') == 7)
    new_state_dict[name] = v

# load the params into the unwrapped model
model.load_state_dict(new_state_dict)
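
If you are saving with plain PyTorch rather than learn.save, you can sidestep the renaming entirely by saving the inner model's state_dict, since DataParallel keeps the original model under its .module attribute. A minimal sketch; the filename and the toy nn.Linear model are placeholders for your own:

import torch
import torch.nn as nn

model = nn.Linear(10, 2)            # placeholder for your real model
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # replicates the model across GPUs

# save the inner model's weights so no `module.` prefix is written
inner = model.module if isinstance(model, nn.DataParallel) else model
torch.save(inner.state_dict(), 'myfile.pth')

# later, load straight into an unwrapped model, no renaming needed
fresh = nn.Linear(10, 2)
fresh.load_state_dict(torch.load('myfile.pth'))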