mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-25 22:35:14 +00:00
parent
58a98060f9
commit
7ce4bb8016
8 changed files with 42 additions and 31 deletions
|
|
@ -4,17 +4,17 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.4.1a0 (WIP)
|
||||
Release 1.4.1a1 (WIP)
|
||||
---------------------------
|
||||
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Switched minimum Gym version to 0.21.0.
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Makes the length of keys and values in `HumanOutputFormat` configurable,
|
||||
- Makes the length of keys and values in ``HumanOutputFormat`` configurable,
|
||||
depending on desired maximum width of output.
|
||||
|
||||
SB3-Contrib
|
||||
|
|
@ -30,10 +30,10 @@ Bug Fixes:
|
|||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
- Switched minimum Gym version to 0.21.0.
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Fixed pytest warnings
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ filterwarnings =
|
|||
ignore:Parameters to load are deprecated.:DeprecationWarning
|
||||
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
|
||||
ignore::UserWarning:gym
|
||||
ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning
|
||||
ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning
|
||||
|
||||
[pytype]
|
||||
inputs = stable_baselines3
|
||||
|
|
|
|||
5
setup.py
5
setup.py
|
|
@ -73,7 +73,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
|
||||
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"gym>=0.21", # Remember to also update gym version in "extra" below when this changes
|
||||
"gym==0.21", # Fixed version due to breaking changes in 0.22
|
||||
"numpy",
|
||||
"torch>=1.8.1",
|
||||
# For saving models
|
||||
|
|
@ -116,7 +116,8 @@ setup(
|
|||
# For render
|
||||
"opencv-python",
|
||||
# For atari games,
|
||||
"gym[atari,accept-rom-license]>=0.21",
|
||||
"ale-py~=0.7.4",
|
||||
"autorom[accept-rom-license]~=0.4.2",
|
||||
"pillow",
|
||||
# Tensorboard support
|
||||
"tensorboard>=2.2.0",
|
||||
|
|
|
|||
|
|
@ -66,6 +66,15 @@ class Image(object):
|
|||
|
||||
|
||||
class FormatUnsupportedError(NotImplementedError):
|
||||
"""
|
||||
Custom error to display informative message when
|
||||
a value is not supported by some formats.
|
||||
|
||||
:param unsupported_formats: A sequence of unsupported formats,
|
||||
for instance ``["stdout"]``.
|
||||
:param value_description: Description of the value that cannot be logged by this format.
|
||||
"""
|
||||
|
||||
def __init__(self, unsupported_formats: Sequence[str], value_description: str):
|
||||
if len(unsupported_formats) > 1:
|
||||
format_str = f"formats {', '.join(unsupported_formats)} are"
|
||||
|
|
@ -116,21 +125,18 @@ class SeqWriter(object):
|
|||
class HumanOutputFormat(KVWriter, SeqWriter):
|
||||
"""A human-readable output format producing ASCII tables of key-value pairs.
|
||||
|
||||
Set attribute `max_length` to change the maximum length of keys and values
|
||||
to write to output (or specify it when calling `__init__`).
|
||||
Set attribute ``max_length`` to change the maximum length of keys and values
|
||||
to write to output (or specify it when calling ``__init__``).
|
||||
|
||||
:param filename_or_file: the file to write the log to
|
||||
:param max_length: the maximum length of keys and values to write to output.
|
||||
Outputs longer than this will be truncated. An error will be raised
|
||||
if multiple keys are truncated to the same value. The maximum output
|
||||
width will be ``2*max_length + 7``. The default of 36 produces output
|
||||
no longer than 79 characters wide.
|
||||
"""
|
||||
|
||||
def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36):
|
||||
"""
|
||||
log to a file, in a human readable format
|
||||
|
||||
:param filename_or_file: the file to write the log to
|
||||
:param max_length: the maximum length of keys and values to write to output.
|
||||
Outputs longer than this will be truncated. An error will be raised
|
||||
if multiple keys are truncated to the same value. The maximum output
|
||||
width will be ``2*max_length + 7``. The default of 36 produces output
|
||||
no longer than 79 characters wide.
|
||||
"""
|
||||
self.max_length = max_length
|
||||
if isinstance(filename_or_file, str):
|
||||
self.file = open(filename_or_file, "wt")
|
||||
|
|
@ -174,7 +180,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
|
|||
truncated_key = self._truncate(key)
|
||||
if truncated_key in key2str:
|
||||
raise ValueError(
|
||||
f"Key '{key}' truncated to " f"'{truncated_key}' that already exists. Consider increasing `max_length`."
|
||||
f"Key '{key}' truncated to '{truncated_key}' that already exists. Consider increasing `max_length`."
|
||||
)
|
||||
key2str[truncated_key] = self._truncate(value_str)
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.4.1a0
|
||||
1.4.1a1
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import types
|
||||
import warnings
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -35,7 +36,7 @@ def test_env(env_id):
|
|||
:param env_id: (str)
|
||||
"""
|
||||
env = gym.make(env_id)
|
||||
with pytest.warns(None) as record:
|
||||
with warnings.catch_warnings(record=True) as record:
|
||||
check_env(env)
|
||||
|
||||
# Pendulum-v1 will produce a warning because the action space is
|
||||
|
|
@ -50,7 +51,7 @@ def test_env(env_id):
|
|||
@pytest.mark.parametrize("env_class", ENV_CLASSES)
|
||||
def test_custom_envs(env_class):
|
||||
env = env_class()
|
||||
with pytest.warns(None) as record:
|
||||
with warnings.catch_warnings(record=True) as record:
|
||||
check_env(env)
|
||||
# No warnings for custom envs
|
||||
assert len(record) == 0
|
||||
|
|
@ -68,7 +69,7 @@ def test_custom_envs(env_class):
|
|||
def test_bit_flipping(kwargs):
|
||||
# Additional tests for BitFlippingEnv
|
||||
env = BitFlippingEnv(**kwargs)
|
||||
with pytest.warns(None) as record:
|
||||
with warnings.catch_warnings(record=True) as record:
|
||||
check_env(env)
|
||||
|
||||
# No warnings for custom envs
|
||||
|
|
@ -147,7 +148,7 @@ def test_non_default_spaces(new_obs_space):
|
|||
def test_non_default_action_spaces(new_action_space):
|
||||
env = FakeImageEnv(discrete=False)
|
||||
# Default, should pass the test
|
||||
with pytest.warns(None) as record:
|
||||
with warnings.catch_warnings(record=True) as record:
|
||||
check_env(env)
|
||||
|
||||
# No warnings for custom envs
|
||||
|
|
|
|||
|
|
@ -580,7 +580,7 @@ def test_open_file_str_pathlib(tmp_path, pathtype):
|
|||
with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1:
|
||||
save_to_pkl(fp1, "foo")
|
||||
assert fp1.closed
|
||||
with pytest.warns(None) as record:
|
||||
with warnings.catch_warnings(record=True) as record:
|
||||
assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo"
|
||||
assert not record
|
||||
|
||||
|
|
@ -588,7 +588,7 @@ def test_open_file_str_pathlib(tmp_path, pathtype):
|
|||
with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1:
|
||||
save_to_pkl(fp1, "foo")
|
||||
assert fp1.closed
|
||||
with pytest.warns(None) as record:
|
||||
with warnings.catch_warnings(record=True) as record:
|
||||
assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo"
|
||||
assert not record
|
||||
|
||||
|
|
@ -596,7 +596,7 @@ def test_open_file_str_pathlib(tmp_path, pathtype):
|
|||
with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1:
|
||||
save_to_pkl(fp1, "foo")
|
||||
assert fp1.closed
|
||||
with pytest.warns(None) as record:
|
||||
with warnings.catch_warnings(record=True) as record:
|
||||
assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo"
|
||||
assert not record
|
||||
|
||||
|
|
@ -604,11 +604,11 @@ def test_open_file_str_pathlib(tmp_path, pathtype):
|
|||
with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1:
|
||||
save_to_pkl(fp1, "foo")
|
||||
assert fp1.closed
|
||||
with pytest.warns(None) as record:
|
||||
with warnings.catch_warnings(record=True) as record:
|
||||
assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo"
|
||||
assert len(record) == 0
|
||||
|
||||
with pytest.warns(None) as record:
|
||||
with warnings.catch_warnings(record=True) as record:
|
||||
assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo"
|
||||
assert len(record) == 1
|
||||
|
||||
|
|
@ -616,7 +616,7 @@ def test_open_file_str_pathlib(tmp_path, pathtype):
|
|||
fp.write("rubbish")
|
||||
fp.close()
|
||||
# test that a warning is only raised when verbose = 0
|
||||
with pytest.warns(None) as record:
|
||||
with warnings.catch_warnings(record=True) as record:
|
||||
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close()
|
||||
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close()
|
||||
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import operator
|
||||
import warnings
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -120,7 +121,7 @@ def make_dict_env():
|
|||
def test_deprecation():
|
||||
venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
|
||||
venv = VecNormalize(venv)
|
||||
with pytest.warns(None) as record:
|
||||
with warnings.catch_warnings(record=True) as record:
|
||||
assert np.allclose(venv.ret, venv.returns)
|
||||
# Deprecation warning when using .ret
|
||||
assert len(record) == 1
|
||||
|
|
|
|||
Loading…
Reference in a new issue