PyTorch: What's the purpose of saving the optimizer state?

PyTorch is capable of saving and loading the state of an optimizer. An example is shown in the PyTorch tutorial. I'm currently just saving and loading the model state but not the optimizer. So what's the point of saving and loading the optimizer state besides not having to remember the optimizer's params such as the learning rate? And what's contained in the optimizer state?


Solution 1:

You should save the optimizer state if you want to resume model training later, especially if Adam is your optimizer. Adam is an adaptive learning-rate method: it keeps per-parameter running averages of the gradients (first and second moments) and uses them to compute an individual learning rate for each parameter, so if you discard that state, resumed training starts the optimizer from scratch.
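As a rough illustration of what's actually in there, here's a minimal sketch (the toy Linear model and the single training step are just placeholders to populate the state):

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # take one step so Adam actually populates its per-parameter buffers
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()

    state = optimizer.state_dict()
    print(state['param_groups'])     # hyperparameters: lr, betas, eps, weight_decay, ...
    print(state['state'][0].keys())  # per-parameter buffers: 'step', 'exp_avg', 'exp_avg_sq'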

It is not required if you only want to use the saved model for inference.

However, it's best practice to save both the model state and the optimizer state. You can also save the loss history and other running metrics if you want to plot them later.

I'd do it like this:

    torch.save({
            'epoch': epochs,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss_history': loss_history,
            }, PATH)
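
To resume later, you'd load it back along these lines (assuming model and optimizer have already been constructed the same way as before saving):

    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    loss_history = checkpoint['train_loss_history']

    model.train()  # back to training mode before resuming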

Solution 2:

I believe that saving the optimizer's state is an important aspect of logging and reproducibility. It stores many details about the optimizer's settings, including the kind of optimizer used, the learning rate, weight decay, the type of scheduler used (I find this very useful personally), etc. Moreover, just as you load pre-trained weights into your current model via .load_state_dict(), you can pass a stored optimizer setting/configuration into your current optimizer with the same method: optimizer.load_state_dict(some_good_optimizer.state_dict()).
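
As a small sketch of that last point (some_good_optimizer is just a placeholder name for an optimizer whose settings you want to reuse, not a real API):

    import torch

    model = torch.nn.Linear(10, 1)

    # optimizer whose configuration we want to carry over
    some_good_optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-5)

    # fresh optimizer for the current model; it picks up lr, betas, weight decay, etc.
    optimizer = torch.optim.Adam(model.parameters())
    optimizer.load_state_dict(some_good_optimizer.state_dict())

    print(optimizer.param_groups[0]['lr'])            # 0.0003
    print(optimizer.param_groups[0]['weight_decay'])  # 1e-05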