mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-31 23:28:05 +00:00
Fix type annotation of `policy` in `BaseAlgorithm` and `OffPolicyAlgorithm` (#1120)
This commit is contained in:
parent
cdcdd32c51
commit
5ef10c8e69
3 changed files with 5 additions and 4 deletions
|
|
@ -23,6 +23,7 @@ SB3-Contrib
|
|||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fix return type of ``evaluate_actions`` in ``ActorCriticPolicy`` to reflect that entropy is an optional tensor (@Rocamonde)
|
||||
- Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm``
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class BaseAlgorithm(ABC):
|
|||
"""
|
||||
The base of RL algorithms
|
||||
|
||||
:param policy: Policy object
|
||||
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||
:param env: The environment to learn from
|
||||
(if registered in Gym, can be str. Can be None for loading trained models)
|
||||
:param learning_rate: learning rate for the optimizer,
|
||||
|
|
@ -89,7 +89,7 @@ class BaseAlgorithm(ABC):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
policy: Type[BasePolicy],
|
||||
policy: Union[str, Type[BasePolicy]],
|
||||
env: Union[GymEnv, str, None],
|
||||
learning_rate: Union[float, Schedule],
|
||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
"""
|
||||
The base for Off-Policy algorithms (ex: SAC/TD3)
|
||||
|
||||
:param policy: Policy object
|
||||
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||
:param env: The environment to learn from
|
||||
(if registered in Gym, can be str. Can be None for loading trained models)
|
||||
:param learning_rate: learning rate for the optimizer,
|
||||
|
|
@ -75,7 +75,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
policy: Type[BasePolicy],
|
||||
policy: Union[str, Type[BasePolicy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, Schedule],
|
||||
buffer_size: int = 1_000_000, # 1e6
|
||||
|
|
|
|||
Loading…
Reference in a new issue