From 6902fac5e7bacb3efddba94d6cb59167555bd3b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 29 Nov 2022 12:26:16 +0100 Subject: [PATCH] Fix `stable_baselines3/common/type_aliases.py` type hint (#1189) --- docs/misc/changelog.rst | 1 + setup.cfg | 1 - stable_baselines3/common/type_aliases.py | 4 ++-- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 26b147b..d5968f9 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -40,6 +40,7 @@ Others: - Goal-conditioned environments are now characterized by the availability of the ``compute_reward`` method, rather than by their inheritance to ``gym.GoalEnv`` - Replaced ``CartPole-v0`` by ``CartPole-v1`` is tests - Fixed ``tests/test_distributions.py`` type hint +- Fixed ``stable_baselines3/common/type_aliases.py`` type hint Documentation: ^^^^^^^^^^^^^^ diff --git a/setup.cfg b/setup.cfg index 5e41364..cbc4402 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,7 +49,6 @@ exclude = (?x)( | stable_baselines3/common/save_util.py$ | stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$ | stable_baselines3/common/torch_layers.py$ - | stable_baselines3/common/type_aliases.py$ | stable_baselines3/common/utils.py$ | stable_baselines3/common/vec_env/__init__.py$ | stable_baselines3/common/vec_env/base_vec_env.py$ diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 4faad7d..7227667 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -36,7 +36,7 @@ class RolloutBufferSamples(NamedTuple): returns: th.Tensor -class DictRolloutBufferSamples(RolloutBufferSamples): +class DictRolloutBufferSamples(NamedTuple): observations: TensorDict actions: th.Tensor old_values: th.Tensor @@ -53,7 +53,7 @@ class ReplayBufferSamples(NamedTuple): rewards: th.Tensor -class DictReplayBufferSamples(ReplayBufferSamples): +class DictReplayBufferSamples(NamedTuple): observations: TensorDict actions: th.Tensor next_observations: TensorDict