mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-25 02:50:59 +00:00
Hotfix to load policies saved with SB3 <= v1.6 (#1234)
* Hotfix to load policies saved with SB3 <= v1.6 * Add warning and test * Update doc
This commit is contained in:
parent
3c028f3d5c
commit
e78ba6ffa4
4 changed files with 38 additions and 4 deletions
|
|
@ -4,9 +4,17 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.7.0a9 (WIP)
|
||||
Release 1.7.0a10 (WIP)
|
||||
--------------------------
|
||||
|
||||
.. note::
|
||||
|
||||
A2C and PPO models saved with SB3 < 1.7.0 will show a warning about
|
||||
missing keys in the state dict when loaded with SB3 >= 1.7.0.
|
||||
To suppress the warning, simply save the model again.
|
||||
You can find more info in `issue #1233 <https://github.com/DLR-RM/stable-baselines3/issues/1233>`_.
|
||||
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import io
|
||||
import pathlib
|
||||
import time
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
|
@ -705,8 +706,25 @@ class BaseAlgorithm(ABC):
|
|||
model.__dict__.update(kwargs)
|
||||
model._setup_model()
|
||||
|
||||
# put state_dicts back in place
|
||||
model.set_parameters(params, exact_match=True, device=device)
|
||||
try:
|
||||
# put state_dicts back in place
|
||||
model.set_parameters(params, exact_match=True, device=device)
|
||||
except RuntimeError as e:
|
||||
# Patch to load Policy saved using SB3 < 1.7.0
|
||||
# the error is probably due to an old policy being loaded
|
||||
# See https://github.com/DLR-RM/stable-baselines3/issues/1233
|
||||
if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e):
|
||||
model.set_parameters(params, exact_match=False, device=device)
|
||||
warnings.warn(
|
||||
"You are probably loading a model saved with SB3 < 1.7.0, "
|
||||
"we deactivated exact_match so you can save the model "
|
||||
"again to avoid issues in the future "
|
||||
"(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). "
|
||||
f"Original error: {e} \n"
|
||||
"Note: the model should still work fine, this only a warning."
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
|
||||
# put other pytorch variables back in place
|
||||
if pytorch_variables is not None:
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.7.0a9
|
||||
1.7.0a10
|
||||
|
|
|
|||
|
|
@ -338,6 +338,14 @@ def test_save_load_env_cnn(tmp_path, model_class):
|
|||
# clear file from os
|
||||
os.remove(tmp_path / "test_save.zip")
|
||||
|
||||
# Check we can load models saved with SB3 < 1.7.0
|
||||
if model_class == A2C:
|
||||
del model.policy.pi_features_extractor
|
||||
model.save(tmp_path / "test_save")
|
||||
with pytest.warns(UserWarning):
|
||||
model_class.load(str(tmp_path / "test_save.zip"), env=env, **kwargs).learn(100)
|
||||
os.remove(tmp_path / "test_save.zip")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
|
||||
def test_save_load_replay_buffer(tmp_path, model_class):
|
||||
|
|
|
|||
Loading…
Reference in a new issue