Add squash_output attribute to policy

This commit is contained in:
Antonin Raffin 2020-02-14 11:12:07 +01:00
parent aa8b4eb22a
commit a2b1bf06d3
6 changed files with 53 additions and 37 deletions

View file

@ -86,7 +86,7 @@ class BaseRLModel(ABC):
self.num_timesteps = 0
self.eval_env = None
self.seed = seed
self.action_noise = None # type: ActionNoise
self.action_noise = None # type: Optional[ActionNoise]
self.start_time = None
self.policy = None
self.learning_rate = None
@ -97,8 +97,8 @@ class BaseRLModel(ABC):
# this is used to update the learning rate
self._current_progress = 1
# Buffers for logging
self.ep_info_buffer = None # type: deque
self.ep_success_buffer = None # type: deque
self.ep_info_buffer = None # type: Optional[deque]
self.ep_success_buffer = None # type: Optional[deque]
# Create and wrap the env if needed
if env is not None:
@ -387,13 +387,12 @@ class BaseRLModel(ABC):
actions = actions.cpu().numpy()
# Rescale to proper domain when using squashing
# TODO: should not be used for a Gaussian distribution?
if isinstance(self.action_space, gym.spaces.Box):
if isinstance(self.action_space, gym.spaces.Box) and self.policy.squash_output:
actions = self.unscale_action(actions)
clipped_actions = actions
# Clip the actions to avoid out of bound error when using gaussian distribution
if isinstance(self.action_space, gym.spaces.Box):
if isinstance(self.action_space, gym.spaces.Box) and not self.policy.squash_output:
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
if not vectorized_env:

View file

@ -22,14 +22,14 @@ class BaseCallback(ABC):
"""
def __init__(self, verbose: int = 0):
super(BaseCallback, self).__init__()
self.model = None # type: BaseRLModel
self.model = None # type: Optional[BaseRLModel]
self.training_env = None # type: Union[gym.Env, VecEnv, None]
self.n_calls = 0 # type: int
self.num_timesteps = 0 # type: int
self.verbose = verbose
self.locals = None # type: Dict[str, Any]
self.globals = None # type: Dict[str, Any]
self.logger = None # type: Logger
self.locals = None # type: Optional[Dict[str, Any]]
self.globals = None # type: Optional[Dict[str, Any]]
self.logger = None # type: Optional[Logger]
# Sometimes, for event callback, it is useful
# to have access to the parent object
self.parent = None # type: Optional[BaseCallback]

View file

@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Type, Dict, List, Tuple
from itertools import zip_longest
@ -14,14 +14,24 @@ class BasePolicy(nn.Module):
:param observation_space: (gym.spaces.Space) The observation space of the environment
:param action_space: (gym.spaces.Space) The action space of the environment
:param device: (Union[th.device, str]) Device on which the code should run.
:param squash_output: (bool) For continuous actions, whether the output is squashed
or not using a `tanh()` function.
"""
def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space, device: Union[th.device, str] = 'cpu'):
action_space: gym.spaces.Space,
device: Union[th.device, str] = 'cpu',
squash_output: bool = False):
super(BasePolicy, self).__init__()
self.observation_space = observation_space
self.action_space = action_space
self.device = device
self._squash_output = squash_output
@property
def squash_output(self) -> bool:
""" (bool) Getter for squash_output."""
return self._squash_output
@staticmethod
def init_weights(module: nn.Module, gain: float = 1):
@ -71,21 +81,25 @@ class BasePolicy(nn.Module):
return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()
def create_mlp(input_dim, output_dim, net_arch,
activation_fn=nn.ReLU, squash_out=False):
def create_mlp(input_dim: int,
output_dim: int,
net_arch: List[int],
activation_fn: nn.Module = nn.ReLU,
squash_output: bool = False) -> List[nn.Module]:
"""
Create a multi layer perceptron (MLP), which is
a collection of fully-connected layers each followed by an activation function.
:param input_dim: (int) Dimension of the input vector
:param output_dim: (int)
:param net_arch: ([int]) Architecture of the neural net
:param net_arch: (List[int]) Architecture of the neural net
It represents the number of units per layer.
The length of this list is the number of layers.
:param activation_fn: (th.nn.Module) The activation function
:param activation_fn: (nn.Module) The activation function
to use after each layer.
:param squash_out: (bool) Whether to squash the output using a Tanh
:param squash_output: (bool) Whether to squash the output using a Tanh
activation function
:return: (List[nn.Module])
"""
if len(net_arch) > 0:
@ -99,12 +113,14 @@ def create_mlp(input_dim, output_dim, net_arch,
if output_dim > 0:
modules.append(nn.Linear(net_arch[-1], output_dim))
if squash_out:
if squash_output:
modules.append(nn.Tanh())
return modules
def create_sde_feature_extractor(features_dim, sde_net_arch, activation_fn):
def create_sde_feature_extractor(features_dim: int,
sde_net_arch: List[int],
activation_fn: nn.Module) -> Tuple[nn.Sequential, int]:
"""
Create the neural network that will be used to extract features
for the SDE.
@ -117,7 +133,7 @@ def create_sde_feature_extractor(features_dim, sde_net_arch, activation_fn):
# Special case: when using states as features (i.e. sde_net_arch is an empty list)
# don't use any activation function
sde_activation = activation_fn if len(sde_net_arch) > 0 else None
latent_sde_net = create_mlp(features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_out=False)
latent_sde_net = create_mlp(features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_output=False)
latent_sde_dim = sde_net_arch[-1] if len(sde_net_arch) > 0 else features_dim
sde_feature_extractor = nn.Sequential(*latent_sde_net)
return sde_feature_extractor, latent_sde_dim
@ -131,7 +147,7 @@ class BaseNetwork(nn.Module):
def __init__(self):
super(BaseNetwork, self).__init__()
def load_from_vector(self, vector):
def load_from_vector(self, vector: np.ndarray):
"""
Load parameters from a 1D vector.
@ -140,7 +156,7 @@ class BaseNetwork(nn.Module):
device = next(self.parameters()).device
th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(device), self.parameters())
def parameters_to_vector(self):
def parameters_to_vector(self) -> np.ndarray:
"""
Convert the parameters to a 1D vector.
@ -149,16 +165,16 @@ class BaseNetwork(nn.Module):
return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()
_policy_registry = dict()
_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]]
def get_policy_from_name(base_policy_type, name):
def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]:
"""
returns the registed policy from the base type and name
Returns the registered policy from the base type and name
:param base_policy_type: (BasePolicy) the base policy object
:param base_policy_type: (Type[BasePolicy]) the base policy class
:param name: (str) the policy name
:return: (base_policy_type) the policy
:return: (Type[BasePolicy]) the policy
"""
if base_policy_type not in _policy_registry:
raise ValueError(f"Error: the policy type {base_policy_type} is not registered!")
@ -168,12 +184,13 @@ def get_policy_from_name(base_policy_type, name):
return _policy_registry[base_policy_type][name]
def register_policy(name, policy):
def register_policy(name: str, policy: Type[BasePolicy]) -> None:
"""
returns the registed policy from the base type and name
Register a policy, so it can be called using its name.
e.g. SAC('MlpPolicy', ...) instead of SAC(MlpPolicy, ...)
:param name: (str) the policy name
:param policy: (subclass of BasePolicy) the policy
:param policy: (Type[BasePolicy]) the policy class
"""
sub_class = None
# For building the doc

View file

@ -41,7 +41,7 @@ class PPOPolicy(BasePolicy):
ortho_init=True, use_sde=False,
log_std_init=0.0, full_std=True,
sde_net_arch=None, use_expln=False, squash_output=False):
super(PPOPolicy, self).__init__(observation_space, action_space, device)
super(PPOPolicy, self).__init__(observation_space, action_space, device, squash_output=squash_output)
self.obs_dim = self.observation_space.shape[0]
# Default network architecture, from stable-baselines

View file

@ -200,7 +200,7 @@ class SACPolicy(BasePolicy):
learning_rate, net_arch=None, device='cpu',
activation_fn=nn.ReLU, use_sde=False,
log_std_init=-3, sde_net_arch=None, use_expln=False):
super(SACPolicy, self).__init__(observation_space, action_space, device)
super(SACPolicy, self).__init__(observation_space, action_space, device, squash_output=True)
if net_arch is None:
net_arch = [256, 256]

View file

@ -52,7 +52,7 @@ class Actor(BaseNetwork):
self.sde_feature_extractor = None
if use_sde:
latent_pi_net = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_out=False)
latent_pi_net = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_output=False)
self.latent_pi = nn.Sequential(*latent_pi_net)
latent_sde_dim = net_arch[-1]
learn_features = sde_net_arch is not None
@ -74,7 +74,7 @@ class Actor(BaseNetwork):
self.sde_optimizer = th.optim.Adam([self.log_std], lr=lr_sde)
self.reset_noise()
else:
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True)
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_output=True)
self.mu = nn.Sequential(*actor_net)
def get_std(self) -> torch.Tensor:
@ -134,7 +134,7 @@ class Actor(BaseNetwork):
if self.clip_noise is not None:
noise = th.clamp(noise, -self.clip_noise, self.clip_noise)
# TODO: Replace with squashing -> need to account for that in the sde update
# -> set squash_out=True in the action_dist?
# -> set squash_output=True in the action_dist?
# NOTE: the clipping is done in the rollout for now
return self.mu(latent_pi) + noise
# action, _ = self._get_action_dist_from_latent(latent_pi)
@ -215,7 +215,7 @@ class TD3Policy(BasePolicy):
learning_rate, net_arch=None, device='cpu',
activation_fn=nn.ReLU, use_sde=False, log_std_init=-3,
clip_noise=None, lr_sde=3e-4, sde_net_arch=None, use_expln=False):
super(TD3Policy, self).__init__(observation_space, action_space, device)
super(TD3Policy, self).__init__(observation_space, action_space, device, squash_output=True)
# Default network architecture, from the original paper
if net_arch is None: