Hotfix to load policies saved with SB3 <= v1.6 (#1234)

* Hotfix to load policies saved with SB3 <= v1.6

* Add warning and test

* Update doc
This commit is contained in:
Antonin RAFFIN 2022-12-22 23:58:30 +01:00 committed by GitHub
parent 3c028f3d5c
commit e78ba6ffa4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 38 additions and 4 deletions

View file

@ -4,9 +4,17 @@ Changelog
==========
Release 1.7.0a9 (WIP)
Release 1.7.0a10 (WIP)
--------------------------
.. note::
A2C and PPO saved with SB3 < 1.7.0 will show a warning about
missing keys in the state dict when loaded with SB3 >= 1.7.0.
To suppress the warning, simply save the model again.
You can find more info in `issue #1233 <https://github.com/DLR-RM/stable-baselines3/issues/1233>`_
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,

View file

@ -3,6 +3,7 @@
import io
import pathlib
import time
import warnings
from abc import ABC, abstractmethod
from collections import deque
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
@ -705,8 +706,25 @@ class BaseAlgorithm(ABC):
model.__dict__.update(kwargs)
model._setup_model()
# put state_dicts back in place
model.set_parameters(params, exact_match=True, device=device)
try:
# put state_dicts back in place
model.set_parameters(params, exact_match=True, device=device)
except RuntimeError as e:
# Patch to load Policy saved using SB3 < 1.7.0
# the error is probably due to old policy being loaded
# See https://github.com/DLR-RM/stable-baselines3/issues/1233
if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e):
model.set_parameters(params, exact_match=False, device=device)
warnings.warn(
"You are probably loading a model saved with SB3 < 1.7.0, "
"we deactivated exact_match so you can save the model "
"again to avoid issues in the future "
"(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). "
f"Original error: {e} \n"
"Note: the model should still work fine, this only a warning."
)
else:
raise e
# put other pytorch variables back in place
if pytorch_variables is not None:

View file

@ -1 +1 @@
1.7.0a9
1.7.0a10

View file

@ -338,6 +338,14 @@ def test_save_load_env_cnn(tmp_path, model_class):
# clear file from os
os.remove(tmp_path / "test_save.zip")
# Check we can load models saved with SB3 < 1.7.0
if model_class == A2C:
del model.policy.pi_features_extractor
model.save(tmp_path / "test_save")
with pytest.warns(UserWarning):
model_class.load(str(tmp_path / "test_save.zip"), env=env, **kwargs).learn(100)
os.remove(tmp_path / "test_save.zip")
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
def test_save_load_replay_buffer(tmp_path, model_class):