From 254bb10c42e8f892e43af9da25aefc7c604c317c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gr=C3=A9goire=20Passault?= Date: Fri, 8 Apr 2022 15:21:53 -0400 Subject: [PATCH] Replacing the policy registry with policy "aliases" (#842) * Replacing the policy registry with policy "aliases" * Fixing import order and SAC * Changing arg. order to be sure policy_aliases is a kwarg * Import orders * Removing pytype error check * Reformat * Fix alias import * Not using mutable {} as default for policy_aliases * Empty aliases initialization * Using static attributes for policy_aliases * Fixing isort * Fixing back bad merge * Running isort * Fixing aliases for A2C and PPO * Using f-string * Moving policy_aliases definition position * Adding change in the changelog * Update version Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 4 +- stable_baselines3/a2c/a2c.py | 8 ++- stable_baselines3/a2c/policies.py | 11 +--- stable_baselines3/common/base_class.py | 29 +++++++-- .../common/off_policy_algorithm.py | 3 - .../common/on_policy_algorithm.py | 5 +- stable_baselines3/common/policies.py | 65 ------------------- stable_baselines3/dqn/dqn.py | 10 ++- stable_baselines3/dqn/policies.py | 7 +- stable_baselines3/ppo/policies.py | 11 +--- stable_baselines3/ppo/ppo.py | 8 ++- stable_baselines3/sac/policies.py | 7 +- stable_baselines3/sac/sac.py | 10 ++- stable_baselines3/td3/policies.py | 7 +- stable_baselines3/td3/td3.py | 10 ++- stable_baselines3/version.txt | 2 +- 16 files changed, 71 insertions(+), 126 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e17c3df..b209f16 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,11 +4,13 @@ Changelog ========== -Release 1.5.1a0 (WIP) +Release 1.5.1a1 (WIP) --------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ +- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former + ``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar) New Features: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index 837ec42..eeeb670 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -5,7 +5,7 @@ from gym import spaces from torch.nn import functional as F from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance @@ -51,6 +51,12 @@ class A2C(OnPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": ActorCriticPolicy, + "CnnPolicy": ActorCriticCnnPolicy, + "MultiInputPolicy": MultiInputActorCriticPolicy, + } + def __init__( self, policy: Union[str, Type[ActorCriticPolicy]], diff --git a/stable_baselines3/a2c/policies.py b/stable_baselines3/a2c/policies.py index 79c85f8..7299b34 100644 --- a/stable_baselines3/a2c/policies.py +++ b/stable_baselines3/a2c/policies.py @@ -1,16 +1,7 @@ # This file is here just to define MlpPolicy/CnnPolicy # that work for A2C -from stable_baselines3.common.policies import ( - ActorCriticCnnPolicy, - ActorCriticPolicy, - MultiInputActorCriticPolicy, - register_policy, -) +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy MlpPolicy = ActorCriticPolicy CnnPolicy = ActorCriticCnnPolicy MultiInputPolicy = MultiInputActorCriticPolicy - -register_policy("MlpPolicy", ActorCriticPolicy) -register_policy("CnnPolicy", ActorCriticCnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 25c2638..14570be 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -17,7 +17,7 @@ from stable_baselines3.common.env_util import is_wrapped from stable_baselines3.common.logger import Logger from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.noise import ActionNoise -from stable_baselines3.common.policies import BasePolicy, get_policy_from_name +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule @@ -60,7 +60,6 @@ class BaseAlgorithm(ABC): :param policy: Policy object :param env: The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models) - :param policy_base: The base policy used by this method :param learning_rate: learning rate for the optimizer, it can be a function of the current progress remaining (from 1 to 0) :param policy_kwargs: Additional arguments to be passed to the policy on creation @@ -83,11 +82,13 @@ class BaseAlgorithm(ABC): :param supported_action_spaces: The action spaces supported by the algorithm. """ + # Policy aliases (see _get_policy_from_name()) + policy_aliases: Dict[str, Type[BasePolicy]] = {} + def __init__( self, policy: Type[BasePolicy], env: Union[GymEnv, str, None], - policy_base: Type[BasePolicy], learning_rate: Union[float, Schedule], policy_kwargs: Optional[Dict[str, Any]] = None, tensorboard_log: Optional[str] = None, @@ -101,9 +102,8 @@ class BaseAlgorithm(ABC): sde_sample_freq: int = -1, supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None, ): - - if isinstance(policy, str) and policy_base is not None: - self.policy_class = get_policy_from_name(policy_base, policy) + if isinstance(policy, str): + self.policy_class = self._get_policy_from_name(policy) else: self.policy_class = policy @@ -325,6 +325,23 @@ class BaseAlgorithm(ABC): "_custom_logger", ] + def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]: + """ + Get a policy class from its name representation. + + The goal here is to standardize policy naming, e.g. + all algorithms can call upon "MlpPolicy" or "CnnPolicy", + and they receive respective policies that work for them. + + :param policy_name: Alias of the policy + :return: A policy class (type) + """ + + if policy_name in self.policy_aliases: + return self.policy_aliases[policy_name] + else: + raise ValueError(f"Policy {policy_name} unknown") + def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: """ Get the name of the torch variables that will be saved with diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 27e8bdd..5905dee 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -28,7 +28,6 @@ class OffPolicyAlgorithm(BaseAlgorithm): :param policy: Policy object :param env: The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models) - :param policy_base: The base policy used by this method :param learning_rate: learning rate for the optimizer, it can be a function of the current progress remaining (from 1 to 0) :param buffer_size: size of the replay buffer @@ -76,7 +75,6 @@ class OffPolicyAlgorithm(BaseAlgorithm): self, policy: Type[BasePolicy], env: Union[GymEnv, str], - policy_base: Type[BasePolicy], learning_rate: Union[float, Schedule], buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, @@ -107,7 +105,6 @@ class OffPolicyAlgorithm(BaseAlgorithm): super(OffPolicyAlgorithm, self).__init__( policy=policy, env=env, - policy_base=policy_base, learning_rate=learning_rate, policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 48cb365..281758c 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -8,7 +8,7 @@ import torch as th from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer from stable_baselines3.common.callbacks import BaseCallback -from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy +from stable_baselines3.common.policies import ActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import obs_as_tensor, safe_mean from stable_baselines3.common.vec_env import VecEnv @@ -34,7 +34,6 @@ class OnPolicyAlgorithm(BaseAlgorithm): instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) - :param policy_base: The base policy used by this method :param tensorboard_log: the log location for tensorboard (if None, no logging) :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) @@ -62,7 +61,6 @@ class OnPolicyAlgorithm(BaseAlgorithm): max_grad_norm: float, use_sde: bool, sde_sample_freq: int, - policy_base: Type[BasePolicy] = ActorCriticPolicy, tensorboard_log: Optional[str] = None, create_eval_env: bool = False, monitor_wrapper: bool = True, @@ -77,7 +75,6 @@ class OnPolicyAlgorithm(BaseAlgorithm): super(OnPolicyAlgorithm, self).__init__( policy=policy, env=env, - policy_base=policy_base, learning_rate=learning_rate, policy_kwargs=policy_kwargs, verbose=verbose, diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 33918b7..c322dc6 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -894,68 +894,3 @@ class ContinuousCritic(BaseModel): with th.no_grad(): features = self.extract_features(obs) return self.q_networks[0](th.cat([features, actions], dim=1)) - - -_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]] - - -def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]: - """ - Returns the registered policy from the base type and name. - See `register_policy` for registering policies and explanation. - - :param base_policy_type: the base policy class - :param name: the policy name - :return: the policy - """ - if base_policy_type not in _policy_registry: - raise KeyError(f"Error: the policy type {base_policy_type} is not registered!") - if name not in _policy_registry[base_policy_type]: - raise KeyError( - f"Error: unknown policy type {name}," - f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!" - ) - return _policy_registry[base_policy_type][name] - - -def register_policy(name: str, policy: Type[BasePolicy]) -> None: - """ - Register a policy, so it can be called using its name. - e.g. SAC('MlpPolicy', ...) instead of SAC(MlpPolicy, ...). - - The goal here is to standardize policy naming, e.g. - all algorithms can call upon "MlpPolicy" or "CnnPolicy", - and they receive respective policies that work for them. - Consider following: - - OnlinePolicy - -- OnlineMlpPolicy ("MlpPolicy") - -- OnlineCnnPolicy ("CnnPolicy") - OfflinePolicy - -- OfflineMlpPolicy ("MlpPolicy") - -- OfflineCnnPolicy ("CnnPolicy") - - Two policies have name "MlpPolicy" and two have "CnnPolicy". - In `get_policy_from_name`, the parent class (e.g. OnlinePolicy) - is given and used to select and return the correct policy. - - :param name: the policy name - :param policy: the policy class - """ - sub_class = None - for cls in BasePolicy.__subclasses__(): - if issubclass(policy, cls): - sub_class = cls - break - if sub_class is None: - raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!") - - if sub_class not in _policy_registry: - _policy_registry[sub_class] = {} - if name in _policy_registry[sub_class]: - # Check if the registered policy is same - # we try to register. If not so, - # do not override and complain. - if _policy_registry[sub_class][name] != policy: - raise ValueError(f"Error: the name {name} is already registered for a different policy, will not override.") - _policy_registry[sub_class][name] = policy diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index a7aec6b..ed6073b 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -8,10 +8,11 @@ from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import maybe_transpose from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update -from stable_baselines3.dqn.policies import DQNPolicy +from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy class DQN(OffPolicyAlgorithm): @@ -59,6 +60,12 @@ class DQN(OffPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + def __init__( self, policy: Union[str, Type[DQNPolicy]], @@ -91,7 +98,6 @@ class DQN(OffPolicyAlgorithm): super(DQN, self).__init__( policy, env, - DQNPolicy, learning_rate, buffer_size, learning_starts, diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index 099a4e3..ea00b5c 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -4,7 +4,7 @@ import gym import torch as th from torch import nn -from stable_baselines3.common.policies import BasePolicy, register_policy +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, CombinedExtractor, @@ -296,8 +296,3 @@ class MultiInputPolicy(DQNPolicy): optimizer_class, optimizer_kwargs, ) - - -register_policy("MlpPolicy", MlpPolicy) -register_policy("CnnPolicy", CnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/ppo/policies.py b/stable_baselines3/ppo/policies.py index 7427cfc..fb7afae 100644 --- a/stable_baselines3/ppo/policies.py +++ b/stable_baselines3/ppo/policies.py @@ -1,16 +1,7 @@ # This file is here just to define MlpPolicy/CnnPolicy # that work for PPO -from stable_baselines3.common.policies import ( - ActorCriticCnnPolicy, - ActorCriticPolicy, - MultiInputActorCriticPolicy, - register_policy, -) +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy MlpPolicy = ActorCriticPolicy CnnPolicy = ActorCriticCnnPolicy MultiInputPolicy = MultiInputActorCriticPolicy - -register_policy("MlpPolicy", ActorCriticPolicy) -register_policy("CnnPolicy", ActorCriticCnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 088bab3..0d05b4c 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -7,7 +7,7 @@ from gym import spaces from torch.nn import functional as F from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn @@ -65,6 +65,12 @@ class PPO(OnPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": ActorCriticPolicy, + "CnnPolicy": ActorCriticCnnPolicy, + "MultiInputPolicy": MultiInputActorCriticPolicy, + } + def __init__( self, policy: Union[str, Type[ActorCriticPolicy]], diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 0bd1382..cb6a61c 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -6,7 +6,7 @@ import torch as th from torch import nn from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution -from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy +from stable_baselines3.common.policies import BasePolicy, ContinuousCritic from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, @@ -514,8 +514,3 @@ class MultiInputPolicy(SACPolicy): n_critics, share_features_extractor, ) - - -register_policy("MlpPolicy", MlpPolicy) -register_policy("CnnPolicy", CnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 5f3a833..3703b73 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -8,9 +8,10 @@ from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import polyak_update -from stable_baselines3.sac.policies import SACPolicy +from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy class SAC(OffPolicyAlgorithm): @@ -72,6 +73,12 @@ class SAC(OffPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + def __init__( self, policy: Union[str, Type[SACPolicy]], @@ -106,7 +113,6 @@ class SAC(OffPolicyAlgorithm): super(SAC, self).__init__( policy, env, - SACPolicy, learning_rate, buffer_size, learning_starts, diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index 264c760..ce91a0f 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -4,7 +4,7 @@ import gym import torch as th from torch import nn -from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy +from stable_baselines3.common.policies import BasePolicy, ContinuousCritic from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, @@ -351,8 +351,3 @@ class MultiInputPolicy(TD3Policy): n_critics, share_features_extractor, ) - - -register_policy("MlpPolicy", MlpPolicy) -register_policy("CnnPolicy", CnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index eb257a6..d31720b 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -8,9 +8,10 @@ from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import polyak_update -from stable_baselines3.td3.policies import TD3Policy +from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy class TD3(OffPolicyAlgorithm): @@ -60,6 +61,12 @@ class TD3(OffPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + def __init__( self, policy: Union[str, Type[TD3Policy]], @@ -91,7 +98,6 @@ class TD3(OffPolicyAlgorithm): super(TD3, self).__init__( policy, env, - TD3Policy, learning_rate, buffer_size, learning_starts, diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 33271c4..1110517 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a0 +1.5.1a1