Replacing the policy registry with policy "aliases" (#842)

* Replacing the policy registry with policy "aliases"

* Fixing import order and SAC

* Changing arg. order to be sure policy_aliases is a kwarg

* Import orders

* Removing pytype error check

* Reformat

* Fix alias import

* Not using mutable {} as default for policy_aliases

* Empty aliases initialization

* Using static attributes for policy_aliases

* Fixing isort

* Fixing back bad merge

* Running isort

* Fixing aliases for A2C and PPO

* Using f-string

* Moving policy_aliases definition position

* Adding change in the changelog

* Update version

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Grégoire Passault 2022-04-08 15:21:53 -04:00 committed by GitHub
parent 44e53ff811
commit 254bb10c42
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 71 additions and 126 deletions

View file

@ -4,11 +4,13 @@ Changelog
==========
Release 1.5.1a0 (WIP)
Release 1.5.1a1 (WIP)
---------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former
``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar)
New Features:
^^^^^^^^^^^^^

View file

@ -5,7 +5,7 @@ from gym import spaces
from torch.nn import functional as F
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance
@ -51,6 +51,12 @@ class A2C(OnPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": ActorCriticPolicy,
"CnnPolicy": ActorCriticCnnPolicy,
"MultiInputPolicy": MultiInputActorCriticPolicy,
}
def __init__(
self,
policy: Union[str, Type[ActorCriticPolicy]],

View file

@ -1,16 +1,7 @@
# This file is here just to define MlpPolicy/CnnPolicy
# that work for A2C
from stable_baselines3.common.policies import (
ActorCriticCnnPolicy,
ActorCriticPolicy,
MultiInputActorCriticPolicy,
register_policy,
)
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy
register_policy("MlpPolicy", ActorCriticPolicy)
register_policy("CnnPolicy", ActorCriticCnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)

View file

@ -17,7 +17,7 @@ from stable_baselines3.common.env_util import is_wrapped
from stable_baselines3.common.logger import Logger
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@ -60,7 +60,6 @@ class BaseAlgorithm(ABC):
:param policy: Policy object
:param env: The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param policy_base: The base policy used by this method
:param learning_rate: learning rate for the optimizer,
it can be a function of the current progress remaining (from 1 to 0)
:param policy_kwargs: Additional arguments to be passed to the policy on creation
@ -83,11 +82,13 @@ class BaseAlgorithm(ABC):
:param supported_action_spaces: The action spaces supported by the algorithm.
"""
# Policy aliases (see _get_policy_from_name())
policy_aliases: Dict[str, Type[BasePolicy]] = {}
def __init__(
self,
policy: Type[BasePolicy],
env: Union[GymEnv, str, None],
policy_base: Type[BasePolicy],
learning_rate: Union[float, Schedule],
policy_kwargs: Optional[Dict[str, Any]] = None,
tensorboard_log: Optional[str] = None,
@ -101,9 +102,8 @@ class BaseAlgorithm(ABC):
sde_sample_freq: int = -1,
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
):
if isinstance(policy, str) and policy_base is not None:
self.policy_class = get_policy_from_name(policy_base, policy)
if isinstance(policy, str):
self.policy_class = self._get_policy_from_name(policy)
else:
self.policy_class = policy
@ -325,6 +325,23 @@ class BaseAlgorithm(ABC):
"_custom_logger",
]
def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
"""
Get a policy class from its name representation.
The goal here is to standardize policy naming, e.g.
all algorithms can call upon "MlpPolicy" or "CnnPolicy",
and they receive respective policies that work for them.
:param policy_name: Alias of the policy
:return: A policy class (type)
"""
if policy_name in self.policy_aliases:
return self.policy_aliases[policy_name]
else:
raise ValueError(f"Policy {policy_name} unknown")
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
"""
Get the name of the torch variables that will be saved with

View file

@ -28,7 +28,6 @@ class OffPolicyAlgorithm(BaseAlgorithm):
:param policy: Policy object
:param env: The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param policy_base: The base policy used by this method
:param learning_rate: learning rate for the optimizer,
it can be a function of the current progress remaining (from 1 to 0)
:param buffer_size: size of the replay buffer
@ -76,7 +75,6 @@ class OffPolicyAlgorithm(BaseAlgorithm):
self,
policy: Type[BasePolicy],
env: Union[GymEnv, str],
policy_base: Type[BasePolicy],
learning_rate: Union[float, Schedule],
buffer_size: int = 1_000_000, # 1e6
learning_starts: int = 100,
@ -107,7 +105,6 @@ class OffPolicyAlgorithm(BaseAlgorithm):
super(OffPolicyAlgorithm, self).__init__(
policy=policy,
env=env,
policy_base=policy_base,
learning_rate=learning_rate,
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,

View file

@ -8,7 +8,7 @@ import torch as th
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
@ -34,7 +34,6 @@ class OnPolicyAlgorithm(BaseAlgorithm):
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param policy_base: The base policy used by this method
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
@ -62,7 +61,6 @@ class OnPolicyAlgorithm(BaseAlgorithm):
max_grad_norm: float,
use_sde: bool,
sde_sample_freq: int,
policy_base: Type[BasePolicy] = ActorCriticPolicy,
tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
monitor_wrapper: bool = True,
@ -77,7 +75,6 @@ class OnPolicyAlgorithm(BaseAlgorithm):
super(OnPolicyAlgorithm, self).__init__(
policy=policy,
env=env,
policy_base=policy_base,
learning_rate=learning_rate,
policy_kwargs=policy_kwargs,
verbose=verbose,

View file

@ -894,68 +894,3 @@ class ContinuousCritic(BaseModel):
with th.no_grad():
features = self.extract_features(obs)
return self.q_networks[0](th.cat([features, actions], dim=1))
_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]]
def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]:
"""
Returns the registered policy from the base type and name.
See `register_policy` for registering policies and explanation.
:param base_policy_type: the base policy class
:param name: the policy name
:return: the policy
"""
if base_policy_type not in _policy_registry:
raise KeyError(f"Error: the policy type {base_policy_type} is not registered!")
if name not in _policy_registry[base_policy_type]:
raise KeyError(
f"Error: unknown policy type {name},"
f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!"
)
return _policy_registry[base_policy_type][name]
def register_policy(name: str, policy: Type[BasePolicy]) -> None:
"""
Register a policy, so it can be called using its name.
e.g. SAC('MlpPolicy', ...) instead of SAC(MlpPolicy, ...).
The goal here is to standardize policy naming, e.g.
all algorithms can call upon "MlpPolicy" or "CnnPolicy",
and they receive respective policies that work for them.
Consider following:
OnlinePolicy
-- OnlineMlpPolicy ("MlpPolicy")
-- OnlineCnnPolicy ("CnnPolicy")
OfflinePolicy
-- OfflineMlpPolicy ("MlpPolicy")
-- OfflineCnnPolicy ("CnnPolicy")
Two policies have name "MlpPolicy" and two have "CnnPolicy".
In `get_policy_from_name`, the parent class (e.g. OnlinePolicy)
is given and used to select and return the correct policy.
:param name: the policy name
:param policy: the policy class
"""
sub_class = None
for cls in BasePolicy.__subclasses__():
if issubclass(policy, cls):
sub_class = cls
break
if sub_class is None:
raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!")
if sub_class not in _policy_registry:
_policy_registry[sub_class] = {}
if name in _policy_registry[sub_class]:
# Check if the registered policy is same
# we try to register. If not so,
# do not override and complain.
if _policy_registry[sub_class][name] != policy:
raise ValueError(f"Error: the name {name} is already registered for a different policy, will not override.")
_policy_registry[sub_class][name] = policy

View file

@ -8,10 +8,11 @@ from torch.nn import functional as F
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
from stable_baselines3.dqn.policies import DQNPolicy
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy
class DQN(OffPolicyAlgorithm):
@ -59,6 +60,12 @@ class DQN(OffPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": MlpPolicy,
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
}
def __init__(
self,
policy: Union[str, Type[DQNPolicy]],
@ -91,7 +98,6 @@ class DQN(OffPolicyAlgorithm):
super(DQN, self).__init__(
policy,
env,
DQNPolicy,
learning_rate,
buffer_size,
learning_starts,

View file

@ -4,7 +4,7 @@ import gym
import torch as th
from torch import nn
from stable_baselines3.common.policies import BasePolicy, register_policy
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
@ -296,8 +296,3 @@ class MultiInputPolicy(DQNPolicy):
optimizer_class,
optimizer_kwargs,
)
register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)

View file

@ -1,16 +1,7 @@
# This file is here just to define MlpPolicy/CnnPolicy
# that work for PPO
from stable_baselines3.common.policies import (
ActorCriticCnnPolicy,
ActorCriticPolicy,
MultiInputActorCriticPolicy,
register_policy,
)
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy
register_policy("MlpPolicy", ActorCriticPolicy)
register_policy("CnnPolicy", ActorCriticCnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)

View file

@ -7,7 +7,7 @@ from gym import spaces
from torch.nn import functional as F
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
@ -65,6 +65,12 @@ class PPO(OnPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": ActorCriticPolicy,
"CnnPolicy": ActorCriticCnnPolicy,
"MultiInputPolicy": MultiInputActorCriticPolicy,
}
def __init__(
self,
policy: Union[str, Type[ActorCriticPolicy]],

View file

@ -6,7 +6,7 @@ import torch as th
from torch import nn
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
@ -514,8 +514,3 @@ class MultiInputPolicy(SACPolicy):
n_critics,
share_features_extractor,
)
register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)

View file

@ -8,9 +8,10 @@ from torch.nn import functional as F
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.sac.policies import SACPolicy
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
class SAC(OffPolicyAlgorithm):
@ -72,6 +73,12 @@ class SAC(OffPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": MlpPolicy,
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
}
def __init__(
self,
policy: Union[str, Type[SACPolicy]],
@ -106,7 +113,6 @@ class SAC(OffPolicyAlgorithm):
super(SAC, self).__init__(
policy,
env,
SACPolicy,
learning_rate,
buffer_size,
learning_starts,

View file

@ -4,7 +4,7 @@ import gym
import torch as th
from torch import nn
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
@ -351,8 +351,3 @@ class MultiInputPolicy(TD3Policy):
n_critics,
share_features_extractor,
)
register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)

View file

@ -8,9 +8,10 @@ from torch.nn import functional as F
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.td3.policies import TD3Policy
from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy
class TD3(OffPolicyAlgorithm):
@ -60,6 +61,12 @@ class TD3(OffPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": MlpPolicy,
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
}
def __init__(
self,
policy: Union[str, Type[TD3Policy]],
@ -91,7 +98,6 @@ class TD3(OffPolicyAlgorithm):
super(TD3, self).__init__(
policy,
env,
TD3Policy,
learning_rate,
buffer_size,
learning_starts,

View file

@ -1 +1 @@
1.5.1a0
1.5.1a1