From a2b1bf06d36bbf2dd9101e7910abb4d4d5d1090e Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 14 Feb 2020 11:12:07 +0100 Subject: [PATCH] Add `squash_output` attribute to policy --- torchy_baselines/common/base_class.py | 11 +++-- torchy_baselines/common/callbacks.py | 8 ++-- torchy_baselines/common/policies.py | 59 +++++++++++++++++---------- torchy_baselines/ppo/policies.py | 2 +- torchy_baselines/sac/policies.py | 2 +- torchy_baselines/td3/policies.py | 8 ++-- 6 files changed, 53 insertions(+), 37 deletions(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index c0951cb..eabc7f1 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -86,7 +86,7 @@ class BaseRLModel(ABC): self.num_timesteps = 0 self.eval_env = None self.seed = seed - self.action_noise = None # type: ActionNoise + self.action_noise = None # type: Optional[ActionNoise] self.start_time = None self.policy = None self.learning_rate = None @@ -97,8 +97,8 @@ class BaseRLModel(ABC): # this is used to update the learning rate self._current_progress = 1 # Buffers for logging - self.ep_info_buffer = None # type: deque - self.ep_success_buffer = None # type: deque + self.ep_info_buffer = None # type: Optional[deque] + self.ep_success_buffer = None # type: Optional[deque] # Create and wrap the env if needed if env is not None: @@ -387,13 +387,12 @@ class BaseRLModel(ABC): actions = actions.cpu().numpy() # Rescale to proper domain when using squashing - # TODO: should not be used for a Gaussian distribution? - if isinstance(self.action_space, gym.spaces.Box): + if isinstance(self.action_space, gym.spaces.Box) and self.policy.squash_output: actions = self.unscale_action(actions) clipped_actions = actions # Clip the actions to avoid out of bound error when using gaussian distribution - if isinstance(self.action_space, gym.spaces.Box): + if isinstance(self.action_space, gym.spaces.Box) and not self.policy.squash_output: clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) if not vectorized_env: diff --git a/torchy_baselines/common/callbacks.py b/torchy_baselines/common/callbacks.py index 8c0d108..392716c 100644 --- a/torchy_baselines/common/callbacks.py +++ b/torchy_baselines/common/callbacks.py @@ -22,14 +22,14 @@ class BaseCallback(ABC): """ def __init__(self, verbose: int = 0): super(BaseCallback, self).__init__() - self.model = None # type: BaseRLModel + self.model = None # type: Optional[BaseRLModel] self.training_env = None # type: Union[gym.Env, VecEnv, None] self.n_calls = 0 # type: int self.num_timesteps = 0 # type: int self.verbose = verbose - self.locals = None # type: Dict[str, Any] - self.globals = None # type: Dict[str, Any] - self.logger = None # type: Logger + self.locals = None # type: Optional[Dict[str, Any]] + self.globals = None # type: Optional[Dict[str, Any]] + self.logger = None # type: Optional[Logger] # Sometimes, for event callback, it is useful # to have access to the parent object self.parent = None # type: Optional[BaseCallback] diff --git a/torchy_baselines/common/policies.py b/torchy_baselines/common/policies.py index 1e0064d..abab7d8 100644 --- a/torchy_baselines/common/policies.py +++ b/torchy_baselines/common/policies.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, Type, Dict, List, Tuple from itertools import zip_longest @@ -14,14 +14,24 @@ class BasePolicy(nn.Module): :param observation_space: (gym.spaces.Space) The observation space of the environment :param action_space: (gym.spaces.Space) The action space of the environment + :param device: (Union[th.device, str]) Device on which the code should run. + :param squash_output: (bool) For continuous actions, whether the output is squashed + or not using a `tanh()` function. """ - def __init__(self, observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, device: Union[th.device, str] = 'cpu'): + action_space: gym.spaces.Space, + device: Union[th.device, str] = 'cpu', + squash_output: bool = False): super(BasePolicy, self).__init__() self.observation_space = observation_space self.action_space = action_space self.device = device + self._squash_output = squash_output + + @property + def squash_output(self) -> bool: + """ (bool) Getter for squash_output.""" + return self._squash_output @staticmethod def init_weights(module: nn.Module, gain: float = 1): @@ -71,21 +81,25 @@ class BasePolicy(nn.Module): return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy() -def create_mlp(input_dim, output_dim, net_arch, - activation_fn=nn.ReLU, squash_out=False): +def create_mlp(input_dim: int, + output_dim: int, + net_arch: List[int], + activation_fn: nn.Module = nn.ReLU, + squash_output: bool = False) -> List[nn.Module]: """ Create a multi layer perceptron (MLP), which is a collection of fully-connected layers each followed by an activation function. :param input_dim: (int) Dimension of the input vector :param output_dim: (int) - :param net_arch: ([int]) Architecture of the neural net + :param net_arch: (List[int]) Architecture of the neural net It represents the number of units per layer. The length of this list is the number of layers. - :param activation_fn: (th.nn.Module) The activation function + :param activation_fn: (nn.Module) The activation function to use after each layer. - :param squash_out: (bool) Whether to squash the output using a Tanh + :param squash_output: (bool) Whether to squash the output using a Tanh activation function + :return: (List[nn.Module]) """ if len(net_arch) > 0: @@ -99,12 +113,14 @@ def create_mlp(input_dim, output_dim, net_arch, if output_dim > 0: modules.append(nn.Linear(net_arch[-1], output_dim)) - if squash_out: + if squash_output: modules.append(nn.Tanh()) return modules -def create_sde_feature_extractor(features_dim, sde_net_arch, activation_fn): +def create_sde_feature_extractor(features_dim: int, + sde_net_arch: List[int], + activation_fn: nn.Module) -> Tuple[nn.Sequential, int]: """ Create the neural network that will be used to extract features for the SDE. @@ -117,7 +133,7 @@ def create_sde_feature_extractor(features_dim, sde_net_arch, activation_fn): # Special case: when using states as features (i.e. sde_net_arch is an empty list) # don't use any activation function sde_activation = activation_fn if len(sde_net_arch) > 0 else None - latent_sde_net = create_mlp(features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_out=False) + latent_sde_net = create_mlp(features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_output=False) latent_sde_dim = sde_net_arch[-1] if len(sde_net_arch) > 0 else features_dim sde_feature_extractor = nn.Sequential(*latent_sde_net) return sde_feature_extractor, latent_sde_dim @@ -131,7 +147,7 @@ class BaseNetwork(nn.Module): def __init__(self): super(BaseNetwork, self).__init__() - def load_from_vector(self, vector): + def load_from_vector(self, vector: np.ndarray): """ Load parameters from a 1D vector. @@ -140,7 +156,7 @@ class BaseNetwork(nn.Module): device = next(self.parameters()).device th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(device), self.parameters()) - def parameters_to_vector(self): + def parameters_to_vector(self) -> np.ndarray: """ Convert the parameters to a 1D vector. @@ -149,16 +165,16 @@ class BaseNetwork(nn.Module): return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy() -_policy_registry = dict() +_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]] -def get_policy_from_name(base_policy_type, name): +def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]: """ - returns the registed policy from the base type and name + Returns the registered policy from the base type and name - :param base_policy_type: (BasePolicy) the base policy object + :param base_policy_type: (Type[BasePolicy]) the base policy class :param name: (str) the policy name - :return: (base_policy_type) the policy + :return: (Type[BasePolicy]) the policy """ if base_policy_type not in _policy_registry: raise ValueError(f"Error: the policy type {base_policy_type} is not registered!") @@ -168,12 +184,13 @@ def get_policy_from_name(base_policy_type, name): return _policy_registry[base_policy_type][name] -def register_policy(name, policy): +def register_policy(name: str, policy: Type[BasePolicy]) -> None: """ - returns the registed policy from the base type and name + Register a policy, so it can be called using its name. + e.g. SAC('MlpPolicy', ...) instead of SAC(MlpPolicy, ...) :param name: (str) the policy name - :param policy: (subclass of BasePolicy) the policy + :param policy: (Type[BasePolicy]) the policy class """ sub_class = None # For building the doc diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index d303a66..3e47375 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -41,7 +41,7 @@ class PPOPolicy(BasePolicy): ortho_init=True, use_sde=False, log_std_init=0.0, full_std=True, sde_net_arch=None, use_expln=False, squash_output=False): - super(PPOPolicy, self).__init__(observation_space, action_space, device) + super(PPOPolicy, self).__init__(observation_space, action_space, device, squash_output=squash_output) self.obs_dim = self.observation_space.shape[0] # Default network architecture, from stable-baselines diff --git a/torchy_baselines/sac/policies.py b/torchy_baselines/sac/policies.py index 3fbea93..4f96738 100644 --- a/torchy_baselines/sac/policies.py +++ b/torchy_baselines/sac/policies.py @@ -200,7 +200,7 @@ class SACPolicy(BasePolicy): learning_rate, net_arch=None, device='cpu', activation_fn=nn.ReLU, use_sde=False, log_std_init=-3, sde_net_arch=None, use_expln=False): - super(SACPolicy, self).__init__(observation_space, action_space, device) + super(SACPolicy, self).__init__(observation_space, action_space, device, squash_output=True) if net_arch is None: net_arch = [256, 256] diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py index fa19952..f75cdea 100644 --- a/torchy_baselines/td3/policies.py +++ b/torchy_baselines/td3/policies.py @@ -52,7 +52,7 @@ class Actor(BaseNetwork): self.sde_feature_extractor = None if use_sde: - latent_pi_net = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_out=False) + latent_pi_net = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_output=False) self.latent_pi = nn.Sequential(*latent_pi_net) latent_sde_dim = net_arch[-1] learn_features = sde_net_arch is not None @@ -74,7 +74,7 @@ class Actor(BaseNetwork): self.sde_optimizer = th.optim.Adam([self.log_std], lr=lr_sde) self.reset_noise() else: - actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True) + actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_output=True) self.mu = nn.Sequential(*actor_net) def get_std(self) -> torch.Tensor: @@ -134,7 +134,7 @@ class Actor(BaseNetwork): if self.clip_noise is not None: noise = th.clamp(noise, -self.clip_noise, self.clip_noise) # TODO: Replace with squashing -> need to account for that in the sde update - # -> set squash_out=True in the action_dist? + # -> set squash_output=True in the action_dist? # NOTE: the clipping is done in the rollout for now return self.mu(latent_pi) + noise # action, _ = self._get_action_dist_from_latent(latent_pi) @@ -215,7 +215,7 @@ class TD3Policy(BasePolicy): learning_rate, net_arch=None, device='cpu', activation_fn=nn.ReLU, use_sde=False, log_std_init=-3, clip_noise=None, lr_sde=3e-4, sde_net_arch=None, use_expln=False): - super(TD3Policy, self).__init__(observation_space, action_space, device) + super(TD3Policy, self).__init__(observation_space, action_space, device, squash_output=True) # Default network architecture, from the original paper if net_arch is None: