Add an optional, separately learned feature extractor for the SDE noise (sde_net_arch).

* tests/test_sde.py: parametrize test_state_dependent_noise over sde_net_arch.
  A case whose last width ([32, 16]) differs from the other one is included so
  that a noise-feature size different from the policy latent size is exercised.
* a2c/a2c.py: comment the gradient clipping step.
* common/distributions.py: StateDependentNoiseDistribution gains `learn_features`
  and tracks `latent_sde_dim` (defaults to `latent_dim` when not given); the
  noise features are detached unless `learn_features` is set.
* ppo/policies.py: PPOPolicy gains `sde_net_arch`; `_get_latent` now returns
  (latent_pi, latent_vf, latent_sde).

Review fix: the first distribution branch in PPOPolicy used to test for
(DiagGaussianDistribution, StateDependentNoiseDistribution), which made the new
`elif StateDependentNoiseDistribution` branch unreachable, so `latent_sde_dim`
was never forwarded to `proba_distribution_net`. The first branch now matches
DiagGaussianDistribution only.

NOTE: whitespace (indentation, blank context lines) was reconstructed to agree
with the hunk headers; the `index` hashes are carried over from the original
patch and no longer describe the modified post-images of tests/test_sde.py and
torchy_baselines/ppo/policies.py.

diff --git a/tests/test_sde.py b/tests/test_sde.py
index 09b48e8..7143e9c 100644
--- a/tests/test_sde.py
+++ b/tests/test_sde.py
@@ -49,14 +49,15 @@ def test_state_dependent_exploration():
 
 
 @pytest.mark.parametrize("model_class", [A2C])
-def test_state_dependent_noise(model_class):
+@pytest.mark.parametrize("sde_net_arch", [None, [64, 64], [32, 16]])
+def test_state_dependent_noise(model_class, sde_net_arch):
     env_id = 'MountainCarContinuous-v0'
 
     env = VecNormalize(DummyVecEnv([lambda: Monitor(gym.make(env_id))]), norm_reward=True)
     eval_env = VecNormalize(DummyVecEnv([lambda: Monitor(gym.make(env_id))]), training=False, norm_reward=False)
 
     model = model_class('MlpPolicy', env, n_steps=200, use_sde=True, ent_coef=0.00, verbose=1, learning_rate=3e-4,
-                        policy_kwargs=dict(log_std_init=0.0, ortho_init=False), seed=None)
+                        policy_kwargs=dict(log_std_init=0.0, ortho_init=False, sde_net_arch=sde_net_arch), seed=None)
     model.learn(total_timesteps=int(1000), log_interval=5, eval_freq=500, eval_env=eval_env)
 
 
diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py
index 68a1c44..4aa2dbe 100644
--- a/torchy_baselines/a2c/a2c.py
+++ b/torchy_baselines/a2c/a2c.py
@@ -113,6 +113,7 @@ class A2C(PPO):
             # Optimization step
             self.policy.optimizer.zero_grad()
             loss.backward()
+            # Clip grad norm
             th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
             self.policy.optimizer.step()
 
diff --git a/torchy_baselines/common/distributions.py b/torchy_baselines/common/distributions.py
index f90883a..b0752e2 100644
--- a/torchy_baselines/common/distributions.py
+++ b/torchy_baselines/common/distributions.py
@@ -241,14 +241,17 @@ class StateDependentNoiseDistribution(Distribution):
         above zero and prevent it from growing too fast. In practice, `exp()` is usually enough.
     :param squash_output: (bool) Whether to squash the output using a tanh function,
         this allows to ensure boundaries.
+    :param learn_features: (bool) Whether to learn features for SDE or not.
+        This will enable gradients to be backpropagated through the features
+        `latent_sde` in the code.
     :param epsilon: (float) small value to avoid NaN due to numerical imprecision.
     """
     def __init__(self, action_dim, full_std=True, use_expln=False,
-                 squash_output=False, epsilon=1e-6):
+                 squash_output=False, learn_features=False, epsilon=1e-6):
         super(StateDependentNoiseDistribution, self).__init__()
         self.distribution = None
         self.action_dim = action_dim
-        self.latent_dim = None
+        self.latent_sde_dim = None
         self.mean_actions = None
         self.log_std = None
         self.weights_dist = None
@@ -256,6 +259,7 @@ class StateDependentNoiseDistribution(Distribution):
         self.use_expln = use_expln
         self.full_std = full_std
         self.epsilon = epsilon
+        self.learn_features = learn_features
         if squash_output:
             print("== Using TanhBijector ===")
             self.bijector = TanhBijector(epsilon)
@@ -284,7 +288,7 @@ class StateDependentNoiseDistribution(Distribution):
         if self.full_std:
             return std
         # Reduce the number of parameters:
-        return th.ones((self.latent_dim, self.action_dim)).to(log_std.device) * std
+        return th.ones((self.latent_sde_dim, self.action_dim)).to(log_std.device) * std
 
     def sample_weights(self, log_std):
         """
@@ -297,29 +301,32 @@ class StateDependentNoiseDistribution(Distribution):
         self.weights_dist = Normal(th.zeros_like(std), std)
         self.exploration_mat = self.weights_dist.rsample()
 
-    def proba_distribution_net(self, latent_dim, log_std_init=-2.0):
+    def proba_distribution_net(self, latent_dim, log_std_init=-2.0, latent_sde_dim=None):
         """
         Create the layers and parameter that represent the distribution:
         one output will be the deterministic action, the other parameter will be the
         standard deviation of the distribution that control the weights of the noise matrix.
 
-        :param latent_dim: (int) Dimension og the last layer of the policy (before the action layer)
+        :param latent_dim: (int) Dimension of the last layer of the policy (before the action layer)
         :param log_std_init: (float) Initial value for the log standard deviation
+        :param latent_sde_dim: (int) Dimension of the last layer of the feature extractor
+            for SDE. By default, it is shared with the policy network.
         :return: (nn.Linear, nn.Parameter)
         """
         # Network for the deterministic action, it represents the mean of the distribution
         mean_actions_net = nn.Linear(latent_dim, self.action_dim)
-
-        self.latent_dim = latent_dim
+        # When we learn features for the noise, the feature dimension
+        # can be different between the policy and the noise network
+        self.latent_sde_dim = latent_dim if latent_sde_dim is None else latent_sde_dim
         # Reduce the number of parameters if needed
-        log_std = th.ones(latent_dim, self.action_dim) if self.full_std else th.ones(latent_dim, 1)
+        log_std = th.ones(self.latent_sde_dim, self.action_dim) if self.full_std else th.ones(self.latent_sde_dim, 1)
         # Transform it to a parameter so it can be optimized
         log_std = nn.Parameter(log_std * log_std_init)
         # Sample an exploration matrix
         self.sample_weights(log_std)
         return mean_actions_net, log_std
 
-    def proba_distribution(self, mean_actions, log_std, latent_pi, deterministic=False):
+    def proba_distribution(self, mean_actions, log_std, latent_sde, deterministic=False):
         """
         Create and sample for the distribution given its parameters (mean, std)
 
@@ -328,13 +335,15 @@ class StateDependentNoiseDistribution(Distribution):
         :param deterministic: (bool)
         :return: (th.Tensor)
         """
-        variance = th.mm(latent_pi.detach() ** 2, self.get_std(log_std) ** 2)
+        # Stop gradient if we don't want to influence the features
+        latent_sde = latent_sde if self.learn_features else latent_sde.detach()
+        variance = th.mm(latent_sde ** 2, self.get_std(log_std) ** 2)
         self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon))
 
         if deterministic:
             action = self.mode()
         else:
-            action = self.sample(latent_pi)
+            action = self.sample(latent_sde)
         return action, self
 
     def mode(self):
@@ -343,11 +352,12 @@ class StateDependentNoiseDistribution(Distribution):
             return self.bijector.forward(action)
         return action
 
-    def get_noise(self, latent_pi):
-        return th.mm(latent_pi.detach(), self.exploration_mat)
+    def get_noise(self, latent_sde):
+        latent_sde = latent_sde if self.learn_features else latent_sde.detach()
+        return th.mm(latent_sde, self.exploration_mat)
 
-    def sample(self, latent_pi):
-        noise = self.get_noise(latent_pi)
+    def sample(self, latent_sde):
+        noise = self.get_noise(latent_sde)
         action = self.distribution.mean + noise
         if self.bijector is not None:
             return self.bijector.forward(action)
@@ -357,8 +367,8 @@ class StateDependentNoiseDistribution(Distribution):
         # TODO: account for the squashing?
         return self.distribution.entropy()
 
-    def log_prob_from_params(self, mean_actions, log_std, latent_pi):
-        action, _ = self.proba_distribution(mean_actions, log_std, latent_pi)
+    def log_prob_from_params(self, mean_actions, log_std, latent_sde):
+        action, _ = self.proba_distribution(mean_actions, log_std, latent_sde)
         log_prob = self.log_prob(action)
         return action, log_prob
 
diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py
index 46fe75c..ed5c493 100644
--- a/torchy_baselines/ppo/policies.py
+++ b/torchy_baselines/ppo/policies.py
@@ -4,7 +4,7 @@ import torch as th
 import torch.nn as nn
 import numpy as np
 
-from torchy_baselines.common.policies import BasePolicy, register_policy, MlpExtractor
+from torchy_baselines.common.policies import BasePolicy, register_policy, MlpExtractor, create_mlp
 from torchy_baselines.common.distributions import make_proba_distribution,\
     DiagGaussianDistribution, CategoricalDistribution, StateDependentNoiseDistribution
 
@@ -30,7 +30,7 @@ class PPOPolicy(BasePolicy):
                  learning_rate, net_arch=None, device='cpu',
                  activation_fn=nn.Tanh, adam_epsilon=1e-5,
                  ortho_init=True, use_sde=False,
-                 log_std_init=0.0, full_std=True):
+                 log_std_init=0.0, full_std=True, sde_net_arch=None):
         super(PPOPolicy, self).__init__(observation_space, action_space, device)
         self.obs_dim = self.observation_space.shape[0]
 
@@ -61,9 +61,13 @@ class PPOPolicy(BasePolicy):
             dist_kwargs = {
                 'full_std': full_std,
                 'squash_output': False,
-                'use_expln': False
+                'use_expln': False,
+                'learn_features': sde_net_arch is not None
             }
 
+        self.sde_feature_extractor = None
+        self.sde_net_arch = sde_net_arch
+
         # Action distribution
         self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
 
@@ -79,9 +83,20 @@ class PPOPolicy(BasePolicy):
         self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch,
                                           activation_fn=self.activation_fn, device=self.device)
 
+        # Separate feature extractor for SDE
+        if self.sde_net_arch is not None:
+            latent_sde = create_mlp(self.features_dim, -1, self.sde_net_arch,
+                                    activation_fn=self.activation_fn, squash_out=False)
+            self.sde_feature_extractor = nn.Sequential(*latent_sde)
+
-        if isinstance(self.action_dist, (DiagGaussianDistribution, StateDependentNoiseDistribution)):
+        if isinstance(self.action_dist, DiagGaussianDistribution):
             self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi,
                                                                                     log_std_init=self.log_std_init)
+        elif isinstance(self.action_dist, StateDependentNoiseDistribution):
+            latent_sde_dim = self.mlp_extractor.latent_dim_pi if self.sde_net_arch is None else self.sde_net_arch[-1]
+            self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi,
+                                                                                    latent_sde_dim=latent_sde_dim,
+                                                                                    log_std_init=self.log_std_init)
         elif isinstance(self.action_dist, CategoricalDistribution):
             self.action_net = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi)
 
@@ -102,16 +117,23 @@ class PPOPolicy(BasePolicy):
     def forward(self, obs, deterministic=False):
         if not isinstance(obs, th.Tensor):
             obs = th.FloatTensor(obs).to(self.device)
-        latent_pi, latent_vf = self._get_latent(obs)
+        latent_pi, latent_vf, latent_sde = self._get_latent(obs)
         value = self.value_net(latent_vf)
-        action, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
+        action, action_distribution = self._get_action_dist_from_latent(latent_pi, latent_sde=latent_sde,
+                                                                        deterministic=deterministic)
         log_prob = action_distribution.log_prob(action)
         return action, value, log_prob
 
     def _get_latent(self, obs):
-        return self.mlp_extractor(self.features_extractor(obs))
+        features = self.features_extractor(obs)
+        latent_pi, latent_vf = self.mlp_extractor(features)
+        # Features for sde
+        latent_sde = latent_pi
+        if self.sde_feature_extractor is not None:
+            latent_sde = self.sde_feature_extractor(features)
+        return latent_pi, latent_vf, latent_sde
 
-    def _get_action_dist_from_latent(self, latent_pi, deterministic=False):
+    def _get_action_dist_from_latent(self, latent_pi, latent_sde=None, deterministic=False):
         mean_actions = self.action_net(latent_pi)
 
         if isinstance(self.action_dist, DiagGaussianDistribution):
@@ -121,11 +143,11 @@ class PPOPolicy(BasePolicy):
             return self.action_dist.proba_distribution(mean_actions, deterministic=deterministic)
 
         elif isinstance(self.action_dist, StateDependentNoiseDistribution):
-            return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi, deterministic=deterministic)
+            return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde, deterministic=deterministic)
 
     def actor_forward(self, obs, deterministic=False):
-        latent_pi, _ = self._get_latent(obs)
-        action, _ = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
+        latent_pi, _, latent_sde = self._get_latent(obs)
+        action, _ = self._get_action_dist_from_latent(latent_pi, latent_sde, deterministic=deterministic)
         return action.detach().cpu().numpy()
 
     def evaluate_actions(self, obs, action, deterministic=False):
@@ -139,14 +161,14 @@ class PPOPolicy(BasePolicy):
         :return: (th.Tensor, th.Tensor, th.Tensor) estimated value, log likelihood of taking those actions
             and entropy of the action distribution.
         """
-        latent_pi, latent_vf = self._get_latent(obs)
-        _, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
+        latent_pi, latent_vf, latent_sde = self._get_latent(obs)
+        _, action_distribution = self._get_action_dist_from_latent(latent_pi, latent_sde, deterministic=deterministic)
         log_prob = action_distribution.log_prob(action)
         value = self.value_net(latent_vf)
         return value, log_prob, action_distribution.entropy()
 
     def value_forward(self, obs):
-        _, latent_vf = self._get_latent(obs)
+        _, latent_vf, _ = self._get_latent(obs)
         return self.value_net(latent_vf)
 
 