Why is `state_dict` the recommended method for saving PyTorch models?

There are two ways of saving a PyTorch model:

torch.save(model, PATH)

which essentially pickles the entire model object, and

torch.save(model.state_dict(), PATH)

which stores only the model’s learned parameters (and registered buffers) as a Python ordered dictionary.
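
For concreteness, here is a minimal sketch of the two approaches (the toy Net class and file names are made up for illustration):

import torch
import torch.nn as nn

class Net(nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

model = Net()

# Approach 1: pickle the entire model object.
torch.save(model, "model_full.pt")
model_full = torch.load("model_full.pt")  # needs the Net class importable at load time

# Approach 2 (recommended): save only the parameters.
torch.save(model.state_dict(), "model_sd.pt")
restored = Net()  # rebuild the architecture in code first
restored.load_state_dict(torch.load("model_sd.pt"))
restored.eval()   # set dropout/batch-norm layers to inference mode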

In the PyTorch documentation they mention that:

When saving a model for inference, it is only necessary to save the trained model’s learned parameters. Saving the model’s state_dict with the torch.save() function will give you the most flexibility for restoring the model later, which is why it is the recommended method for saving models.

But why is that? What do they mean by flexibility? Could you provide some examples?

It’s to avoid relying on the exact Python code. If you pickle the whole model, the pickle stores a reference to the model’s class, so you need the same class definition importable at the same location when you load it. If the code has changed, for example the model definition was moved to another module, loading can fail. If you save only the state dictionary, you can easily load it back into the model, reuse a portion of it in another model, and so on, and, especially, you don’t depend on the exact code version that was used to save it (for example, the PyTorch version).
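
To make “reuse a portion of it” concrete, here is a hedged sketch: since the state_dict is just a dictionary of tensors, you can filter it and load only the entries that match a new model, ignoring the rest (continuing with the hypothetical Net and file name from above):

# Load the saved parameters and keep only the entries whose name
# and shape match the new model, then load them non-strictly.
pretrained = torch.load("model_sd.pt")  # an ordered dict of tensors
new_model = Net()
target = new_model.state_dict()
compatible = {k: v for k, v in pretrained.items()
              if k in target and v.shape == target[k].shape}
new_model.load_state_dict(compatible, strict=False)  # strict=False tolerates missing keys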
