Pin gym version (#782)

* Pin gym version

* Cleanup warnings

* Reformat
This commit is contained in:
Antonin RAFFIN 2022-02-21 23:12:54 +01:00 committed by GitHub
parent 58a98060f9
commit 7ce4bb8016
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 42 additions and 31 deletions

View file

@ -4,17 +4,17 @@ Changelog
==========
Release 1.4.1a0 (WIP)
Release 1.4.1a1 (WIP)
---------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Switched minimum Gym version to 0.21.0.
New Features:
^^^^^^^^^^^^^
- Makes the length of keys and values in `HumanOutputFormat` configurable,
- Makes the length of keys and values in ``HumanOutputFormat`` configurable,
depending on desired maximum width of output.
SB3-Contrib
@ -30,10 +30,10 @@ Bug Fixes:
Deprecations:
^^^^^^^^^^^^^
- Switched minimum Gym version to 0.21.0.
Others:
^^^^^^^
- Fixed pytest warnings
Documentation:
^^^^^^^^^^^^^^

View file

@ -13,6 +13,8 @@ filterwarnings =
ignore:Parameters to load are deprecated.:DeprecationWarning
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
ignore::UserWarning:gym
ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning
ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning
[pytype]
inputs = stable_baselines3

View file

@ -73,7 +73,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
install_requires=[
"gym>=0.21", # Remember to also update gym version in "extra" below when this changes
"gym==0.21", # Fixed version due to breaking changes in 0.22
"numpy",
"torch>=1.8.1",
# For saving models
@ -116,7 +116,8 @@ setup(
# For render
"opencv-python",
# For atari games,
"gym[atari,accept-rom-license]>=0.21",
"ale-py~=0.7.4",
"autorom[accept-rom-license]~=0.4.2",
"pillow",
# Tensorboard support
"tensorboard>=2.2.0",

View file

@ -66,6 +66,15 @@ class Image(object):
class FormatUnsupportedError(NotImplementedError):
"""
Custom error to display informative message when
a value is not supported by some formats.
:param unsupported_formats: A sequence of unsupported formats,
for instance ``["stdout"]``.
:param value_description: Description of the value that cannot be logged by this format.
"""
def __init__(self, unsupported_formats: Sequence[str], value_description: str):
if len(unsupported_formats) > 1:
format_str = f"formats {', '.join(unsupported_formats)} are"
@ -116,21 +125,18 @@ class SeqWriter(object):
class HumanOutputFormat(KVWriter, SeqWriter):
"""A human-readable output format producing ASCII tables of key-value pairs.
Set attribute `max_length` to change the maximum length of keys and values
to write to output (or specify it when calling `__init__`).
Set attribute ``max_length`` to change the maximum length of keys and values
to write to output (or specify it when calling ``__init__``).
:param filename_or_file: the file to write the log to
:param max_length: the maximum length of keys and values to write to output.
Outputs longer than this will be truncated. An error will be raised
if multiple keys are truncated to the same value. The maximum output
width will be ``2*max_length + 7``. The default of 36 produces output
no longer than 79 characters wide.
"""
def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36):
"""
log to a file, in a human readable format
:param filename_or_file: the file to write the log to
:param max_length: the maximum length of keys and values to write to output.
Outputs longer than this will be truncated. An error will be raised
if multiple keys are truncated to the same value. The maximum output
width will be ``2*max_length + 7``. The default of 36 produces output
no longer than 79 characters wide.
"""
self.max_length = max_length
if isinstance(filename_or_file, str):
self.file = open(filename_or_file, "wt")
@ -174,7 +180,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
truncated_key = self._truncate(key)
if truncated_key in key2str:
raise ValueError(
f"Key '{key}' truncated to " f"'{truncated_key}' that already exists. Consider increasing `max_length`."
f"Key '{key}' truncated to '{truncated_key}' that already exists. Consider increasing `max_length`."
)
key2str[truncated_key] = self._truncate(value_str)

View file

@ -1 +1 @@
1.4.1a0
1.4.1a1

View file

@ -1,4 +1,5 @@
import types
import warnings
import gym
import numpy as np
@ -35,7 +36,7 @@ def test_env(env_id):
:param env_id: (str)
"""
env = gym.make(env_id)
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_env(env)
# Pendulum-v1 will produce a warning because the action space is
@ -50,7 +51,7 @@ def test_env(env_id):
@pytest.mark.parametrize("env_class", ENV_CLASSES)
def test_custom_envs(env_class):
env = env_class()
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_env(env)
# No warnings for custom envs
assert len(record) == 0
@ -68,7 +69,7 @@ def test_custom_envs(env_class):
def test_bit_flipping(kwargs):
# Additional tests for BitFlippingEnv
env = BitFlippingEnv(**kwargs)
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_env(env)
# No warnings for custom envs
@ -147,7 +148,7 @@ def test_non_default_spaces(new_obs_space):
def test_non_default_action_spaces(new_action_space):
env = FakeImageEnv(discrete=False)
# Default, should pass the test
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_env(env)
# No warnings for custom envs

View file

@ -580,7 +580,7 @@ def test_open_file_str_pathlib(tmp_path, pathtype):
with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo"
assert not record
@ -588,7 +588,7 @@ def test_open_file_str_pathlib(tmp_path, pathtype):
with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo"
assert not record
@ -596,7 +596,7 @@ def test_open_file_str_pathlib(tmp_path, pathtype):
with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo"
assert not record
@ -604,11 +604,11 @@ def test_open_file_str_pathlib(tmp_path, pathtype):
with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo"
assert len(record) == 0
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo"
assert len(record) == 1
@ -616,7 +616,7 @@ def test_open_file_str_pathlib(tmp_path, pathtype):
fp.write("rubbish")
fp.close()
# test that a warning is only raised when verbose = 0
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close()
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close()
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close()

View file

@ -1,4 +1,5 @@
import operator
import warnings
import gym
import numpy as np
@ -120,7 +121,7 @@ def make_dict_env():
def test_deprecation():
venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
venv = VecNormalize(venv)
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
assert np.allclose(venv.ret, venv.returns)
# Deprecation warning when using .ret
assert len(record) == 1