diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 1f5052d..9a875bf 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,17 +4,17 @@ Changelog ========== -Release 1.4.1a0 (WIP) +Release 1.4.1a1 (WIP) --------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ - +- Switched minimum Gym version to 0.21.0. New Features: ^^^^^^^^^^^^^ -- Makes the length of keys and values in `HumanOutputFormat` configurable, +- Makes the length of keys and values in ``HumanOutputFormat`` configurable, depending on desired maximum width of output. SB3-Contrib @@ -30,10 +30,10 @@ Bug Fixes: Deprecations: ^^^^^^^^^^^^^ -- Switched minimum Gym version to 0.21.0. Others: ^^^^^^^ +- Fixed pytest warnings Documentation: ^^^^^^^^^^^^^^ diff --git a/setup.cfg b/setup.cfg index 73ae3db..e23ad45 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,6 +13,8 @@ filterwarnings = ignore:Parameters to load are deprecated.:DeprecationWarning ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning ignore::UserWarning:gym + ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning + ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning [pytype] inputs = stable_baselines3 diff --git a/setup.py b/setup.py index eabf30c..de615a7 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ setup( packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gym>=0.21", # Remember to also update gym version in "extra" below when this changes + "gym==0.21", # Fixed version due to breaking changes in 0.22 "numpy", "torch>=1.8.1", # For saving models @@ -116,7 +116,8 @@ setup( # For render "opencv-python", # For atari games, - "gym[atari,accept-rom-license]>=0.21", + "ale-py~=0.7.4", + "autorom[accept-rom-license]~=0.4.2", "pillow", # Tensorboard support "tensorboard>=2.2.0", diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 26e5a6e..6493a3e 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -66,6 +66,15 @@ class Image(object): class FormatUnsupportedError(NotImplementedError): + """ + Custom error to display informative message when + a value is not supported by some formats. + + :param unsupported_formats: A sequence of unsupported formats, + for instance ``["stdout"]``. + :param value_description: Description of the value that cannot be logged by this format. + """ + def __init__(self, unsupported_formats: Sequence[str], value_description: str): if len(unsupported_formats) > 1: format_str = f"formats {', '.join(unsupported_formats)} are" @@ -116,21 +125,18 @@ class SeqWriter(object): class HumanOutputFormat(KVWriter, SeqWriter): """A human-readable output format producing ASCII tables of key-value pairs. - Set attribute `max_length` to change the maximum length of keys and values - to write to output (or specify it when calling `__init__`). + Set attribute ``max_length`` to change the maximum length of keys and values + to write to output (or specify it when calling ``__init__``). + + :param filename_or_file: the file to write the log to + :param max_length: the maximum length of keys and values to write to output. + Outputs longer than this will be truncated. An error will be raised + if multiple keys are truncated to the same value. The maximum output + width will be ``2*max_length + 7``. The default of 36 produces output + no longer than 79 characters wide. """ def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36): - """ - log to a file, in a human readable format - - :param filename_or_file: the file to write the log to - :param max_length: the maximum length of keys and values to write to output. - Outputs longer than this will be truncated. An error will be raised - if multiple keys are truncated to the same value. The maximum output - width will be ``2*max_length + 7``. The default of 36 produces output - no longer than 79 characters wide. - """ self.max_length = max_length if isinstance(filename_or_file, str): self.file = open(filename_or_file, "wt") @@ -174,7 +180,7 @@ class HumanOutputFormat(KVWriter, SeqWriter): truncated_key = self._truncate(key) if truncated_key in key2str: raise ValueError( - f"Key '{key}' truncated to " f"'{truncated_key}' that already exists. Consider increasing `max_length`." + f"Key '{key}' truncated to '{truncated_key}' that already exists. Consider increasing `max_length`." ) key2str[truncated_key] = self._truncate(value_str) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 97ec5cc..d012e1c 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.4.1a0 +1.4.1a1 diff --git a/tests/test_envs.py b/tests/test_envs.py index b859ed7..671e2a5 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -1,4 +1,5 @@ import types +import warnings import gym import numpy as np @@ -35,7 +36,7 @@ def test_env(env_id): :param env_id: (str) """ env = gym.make(env_id) - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_env(env) # Pendulum-v1 will produce a warning because the action space is @@ -50,7 +51,7 @@ def test_env(env_id): @pytest.mark.parametrize("env_class", ENV_CLASSES) def test_custom_envs(env_class): env = env_class() - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_env(env) # No warnings for custom envs assert len(record) == 0 @@ -68,7 +69,7 @@ def test_custom_envs(env_class): def test_bit_flipping(kwargs): # Additional tests for BitFlippingEnv env = BitFlippingEnv(**kwargs) - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_env(env) # No warnings for custom envs @@ -147,7 +148,7 @@ def test_non_default_spaces(new_obs_space): def test_non_default_action_spaces(new_action_space): env = FakeImageEnv(discrete=False) # Default, should pass the test - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_env(env) # No warnings for custom envs diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 7d810c7..19b2c90 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -580,7 +580,7 @@ def test_open_file_str_pathlib(tmp_path, pathtype): with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1: save_to_pkl(fp1, "foo") assert fp1.closed - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo" assert not record @@ -588,7 +588,7 @@ def test_open_file_str_pathlib(tmp_path, pathtype): with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1: save_to_pkl(fp1, "foo") assert fp1.closed - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo" assert not record @@ -596,7 +596,7 @@ def test_open_file_str_pathlib(tmp_path, pathtype): with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1: save_to_pkl(fp1, "foo") assert fp1.closed - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo" assert not record @@ -604,11 +604,11 @@ def test_open_file_str_pathlib(tmp_path, pathtype): with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1: save_to_pkl(fp1, "foo") assert fp1.closed - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo" assert len(record) == 0 - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo" assert len(record) == 1 @@ -616,7 +616,7 @@ def test_open_file_str_pathlib(tmp_path, pathtype): fp.write("rubbish") fp.close() # test that a warning is only raised when verbose = 0 - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close() open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close() open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close() diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index c3d1d30..8134340 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -1,4 +1,5 @@ import operator +import warnings import gym import numpy as np @@ -120,7 +121,7 @@ def make_dict_env(): def test_deprecation(): venv = DummyVecEnv([lambda: gym.make("CartPole-v1")]) venv = VecNormalize(venv) - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: assert np.allclose(venv.ret, venv.returns) # Deprecation warning when using .ret assert len(record) == 1