diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 26b147b..d5968f9 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -40,6 +40,7 @@ Others: - Goal-conditioned environments are now characterized by the availability of the ``compute_reward`` method, rather than by their inheritance to ``gym.GoalEnv`` - Replaced ``CartPole-v0`` by ``CartPole-v1`` is tests - Fixed ``tests/test_distributions.py`` type hint +- Fixed ``stable_baselines3/common/type_aliases.py`` type hint Documentation: ^^^^^^^^^^^^^^ diff --git a/setup.cfg b/setup.cfg index 5e41364..cbc4402 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,7 +49,6 @@ exclude = (?x)( | stable_baselines3/common/save_util.py$ | stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$ | stable_baselines3/common/torch_layers.py$ - | stable_baselines3/common/type_aliases.py$ | stable_baselines3/common/utils.py$ | stable_baselines3/common/vec_env/__init__.py$ | stable_baselines3/common/vec_env/base_vec_env.py$ diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 4faad7d..7227667 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -36,7 +36,7 @@ class RolloutBufferSamples(NamedTuple): returns: th.Tensor -class DictRolloutBufferSamples(RolloutBufferSamples): +class DictRolloutBufferSamples(NamedTuple): observations: TensorDict actions: th.Tensor old_values: th.Tensor @@ -53,7 +53,7 @@ class ReplayBufferSamples(NamedTuple): rewards: th.Tensor -class DictReplayBufferSamples(ReplayBufferSamples): +class DictReplayBufferSamples(NamedTuple): observations: TensorDict actions: th.Tensor next_observations: TensorDict