From fdca786f0999ea402ad5da98e3d69aa68bcbf635 Mon Sep 17 00:00:00 2001 From: Juan Rocamonde Date: Fri, 2 Sep 2022 05:10:01 +0200 Subject: [PATCH] Fix replay_buffer_class type annotation (#1042) * Fix replay_buffer_class type annotation * Update changelog * Further replacement of same type annotation issue * Formatting * Rolled back formatting changes for consistency --- docs/misc/changelog.rst | 1 + stable_baselines3/common/off_policy_algorithm.py | 2 +- stable_baselines3/ddpg/ddpg.py | 2 +- stable_baselines3/dqn/dqn.py | 2 +- stable_baselines3/sac/sac.py | 2 +- stable_baselines3/td3/td3.py | 2 +- 6 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9cefae4..48f187e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -27,6 +27,7 @@ Bug Fixes: - Added multidimensional action space support (@qgallouedec) - Fixed missing verbose parameter passing in the ``EvalCallback`` constructor (@burakdmb) - Fixed the issue that when updating the target network in DQN, SAC, TD3, the ``running_mean`` and ``running_var`` properties of batch norm layers are not updated (@honglu2875) +- Fixed incorrect type annotation of the replay_buffer_class argument in ``common.OffPolicyAlgorithm`` initializer, where an instance instead of a class was required (@Rocamonde) Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index b841eb0..ebac818 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -85,7 +85,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): train_freq: Union[int, Tuple[int, str]] = (1, "step"), gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[ReplayBuffer] = None, + replay_buffer_class: Optional[Type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[Dict[str, Any]] = None, optimize_memory_usage: bool = False, policy_kwargs: Optional[Dict[str, Any]] = None, diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index 53d3fb6..d208a00 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -66,7 +66,7 @@ class DDPG(TD3): train_freq: Union[int, Tuple[int, str]] = (1, "episode"), gradient_steps: int = -1, action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[ReplayBuffer] = None, + replay_buffer_class: Optional[Type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[Dict[str, Any]] = None, optimize_memory_usage: bool = False, tensorboard_log: Optional[str] = None, diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 839fe33..4830e7c 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -78,7 +78,7 @@ class DQN(OffPolicyAlgorithm): gamma: float = 0.99, train_freq: Union[int, Tuple[int, str]] = 4, gradient_steps: int = 1, - replay_buffer_class: Optional[ReplayBuffer] = None, + replay_buffer_class: Optional[Type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[Dict[str, Any]] = None, optimize_memory_usage: bool = False, target_update_interval: int = 10000, diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index ba27998..f967b38 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -92,7 +92,7 @@ class SAC(OffPolicyAlgorithm): train_freq: Union[int, Tuple[int, str]] = 1, gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[ReplayBuffer] = None, + replay_buffer_class: Optional[Type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[Dict[str, Any]] = None, optimize_memory_usage: bool = False, ent_coef: Union[str, float] = "auto", diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index f440b73..c8376f5 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -80,7 +80,7 @@ class TD3(OffPolicyAlgorithm): train_freq: Union[int, Tuple[int, str]] = (1, "episode"), gradient_steps: int = -1, action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[ReplayBuffer] = None, + replay_buffer_class: Optional[Type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[Dict[str, Any]] = None, optimize_memory_usage: bool = False, policy_delay: int = 2,