From 5ef10c8e69b52e1376e6c2c636737d6dd528dda1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 17 Oct 2022 10:16:20 +0200 Subject: [PATCH] Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm`` (#1120) --- docs/misc/changelog.rst | 1 + stable_baselines3/common/base_class.py | 4 ++-- stable_baselines3/common/off_policy_algorithm.py | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index ce5e3c8..ed9834f 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -23,6 +23,7 @@ SB3-Contrib Bug Fixes: ^^^^^^^^^^ - Fix return type of ``evaluate_actions`` in ``ActorCritcPolicy`` to reflect that entropy is an optional tensor (@Rocamonde) +- Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm`` Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 05a8cf4..d993f9b 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -60,7 +60,7 @@ class BaseAlgorithm(ABC): """ The base of RL algorithms - :param policy: Policy object + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) :param env: The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models) :param learning_rate: learning rate for the optimizer, @@ -89,7 +89,7 @@ class BaseAlgorithm(ABC): def __init__( self, - policy: Type[BasePolicy], + policy: Union[str, Type[BasePolicy]], env: Union[GymEnv, str, None], learning_rate: Union[float, Schedule], policy_kwargs: Optional[Dict[str, Any]] = None, diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index d309322..9c2bda2 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -28,7 +28,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): """ The base for Off-Policy algorithms (ex: SAC/TD3) - :param policy: Policy object + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) :param env: The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models) :param learning_rate: learning rate for the optimizer, @@ -75,7 +75,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): def __init__( self, - policy: Type[BasePolicy], + policy: Union[str, Type[BasePolicy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule], buffer_size: int = 1_000_000, # 1e6