diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index bab86a4..692aa7f 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,9 +4,17 @@ Changelog ========== -Release 1.7.0a9 (WIP) +Release 1.7.0a10 (WIP) -------------------------- +.. note:: + + A2C and PPO saved with SB3 < 1.7.0 will show a warning about + missing keys in the state dict when loaded with SB3 >= 1.7.0. + To suppress the warning, simply save the model again. + You can find more info in `issue #1233 <https://github.com/DLR-RM/stable-baselines3/issues/1233>`_ + + Breaking Changes: ^^^^^^^^^^^^^^^^^ - Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters, diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 1309229..132f314 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -3,6 +3,7 @@ import io import pathlib import time +import warnings from abc import ABC, abstractmethod from collections import deque from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union @@ -705,8 +706,25 @@ class BaseAlgorithm(ABC): model.__dict__.update(kwargs) model._setup_model() - # put state_dicts back in place - model.set_parameters(params, exact_match=True, device=device) + try: + # put state_dicts back in place + model.set_parameters(params, exact_match=True, device=device) + except RuntimeError as e: + # Patch to load Policy saved using SB3 < 1.7.0 + # the error is probably due to old policy being loaded + # See https://github.com/DLR-RM/stable-baselines3/issues/1233 + if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e): + model.set_parameters(params, exact_match=False, device=device) + warnings.warn( + "You are probably loading a model saved with SB3 < 1.7.0, " + "we deactivated exact_match so you can save the model " + "again to avoid issues in the future " + "(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for 
more info). " + f"Original error: {e} \n" + "Note: the model should still work fine, this only a warning." + ) + else: + raise e # put other pytorch variables back in place if pytorch_variables is not None: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 82d01f9..89e17c2 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.7.0a9 +1.7.0a10 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index f96b69e..2c35e43 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -338,6 +338,14 @@ def test_save_load_env_cnn(tmp_path, model_class): # clear file from os os.remove(tmp_path / "test_save.zip") + # Check we can load models saved with SB3 < 1.7.0 + if model_class == A2C: + del model.policy.pi_features_extractor + model.save(tmp_path / "test_save") + with pytest.warns(UserWarning): + model_class.load(str(tmp_path / "test_save.zip"), env=env, **kwargs).learn(100) + os.remove(tmp_path / "test_save.zip") + @pytest.mark.parametrize("model_class", [SAC, TD3, DQN]) def test_save_load_replay_buffer(tmp_path, model_class):