Fix Python – How do I save a trained model in PyTorch?

Question

Asked By – Wasi Ahmad

How do I save a trained model in PyTorch? I have read that:

  1. torch.save()/torch.load() is for saving/loading a serializable object.
  2. model.state_dict()/model.load_state_dict() is for saving/loading model state.

Now we will see the solution for this issue: How do I save a trained model in PyTorch?


Answer

Found this page on their GitHub repo:

Recommended approach for saving a model

There are two main approaches for serializing and restoring a model.

The first (recommended) saves and loads only the model parameters:

torch.save(the_model.state_dict(), PATH)

Then later:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
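
As a minimal, self-contained sketch of this first approach (the TinyNet class, its layer sizes, and the "model.pt" path are invented here purely for illustration):

import torch
import torch.nn as nn

# A small example model; any nn.Module subclass works the same way.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

model = TinyNet()
# ... train the model here ...

# Save only the learned parameters (the recommended approach).
torch.save(model.state_dict(), "model.pt")

# Later: recreate the architecture first, then load the saved parameters.
model = TinyNet()
model.load_state_dict(torch.load("model.pt"))
model.eval()  # switch to evaluation mode before running inference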

The second saves and loads the entire model:

torch.save(the_model, PATH)

Then later:

the_model = torch.load(PATH)

However, in this case the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects or after some serious refactors.
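
A comparable sketch of the second approach, reusing the hypothetical TinyNet and file name from above. Because the whole object is pickled, the TinyNet class definition must still be importable from the same module path when the file is loaded; on newer PyTorch versions you may also need to pass weights_only=False to torch.load() when restoring a full model object.

# Save the entire model object (this pickles a reference to its class).
torch.save(model, "model_full.pt")

# Later: TinyNet's class definition must be importable at load time,
# otherwise unpickling fails.
model = torch.load("model_full.pt")
model.eval()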


See also: Save and Load the Model section from the official PyTorch tutorials.

This question is answered by – dontloo

This answer is collected from Stack Overflow, reviewed by FixPython community admins, and is licensed under CC BY-SA 2.5, CC BY-SA 3.0, and CC BY-SA 4.0.