Fix type annotation of `policy in BaseAlgorithm and OffPolicyAlgorithm` (#1120)

This commit is contained in:
Quentin Gallouédec 2022-10-17 10:16:20 +02:00 committed by GitHub
parent cdcdd32c51
commit 5ef10c8e69
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 5 additions and 4 deletions

View file

@ -23,6 +23,7 @@ SB3-Contrib
Bug Fixes:
^^^^^^^^^^
- Fix return type of ``evaluate_actions`` in ``ActorCritcPolicy`` to reflect that entropy is an optional tensor (@Rocamonde)
- Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm``
Deprecations:
^^^^^^^^^^^^^

View file

@ -60,7 +60,7 @@ class BaseAlgorithm(ABC):
"""
The base of RL algorithms
:param policy: Policy object
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param learning_rate: learning rate for the optimizer,
@ -89,7 +89,7 @@ class BaseAlgorithm(ABC):
def __init__(
self,
policy: Type[BasePolicy],
policy: Union[str, Type[BasePolicy]],
env: Union[GymEnv, str, None],
learning_rate: Union[float, Schedule],
policy_kwargs: Optional[Dict[str, Any]] = None,

View file

@ -28,7 +28,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
"""
The base for Off-Policy algorithms (ex: SAC/TD3)
:param policy: Policy object
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param learning_rate: learning rate for the optimizer,
@ -75,7 +75,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
def __init__(
self,
policy: Type[BasePolicy],
policy: Union[str, Type[BasePolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule],
buffer_size: int = 1_000_000, # 1e6