From a2a8bbdf11983ed3244c804797a14e83c06796dc Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 2 Dec 2019 11:48:34 +0100 Subject: [PATCH] Sample n matrices for A2C/PPO when using SDE --- torchy_baselines/ppo/policies.py | 14 ++++++++++---- torchy_baselines/ppo/ppo.py | 2 +- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index a0abfb4..50dbeea 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -73,11 +73,13 @@ class PPOPolicy(BasePolicy): self._build(learning_rate) - def reset_noise_net(self): + def reset_noise_net(self, n_envs=1): """ Sample new weights for the exploration matrix. + + :param n_envs: (int) """ - self.action_dist.sample_weights(self.log_std) + self.action_dist.sample_weights(self.log_std, batch_size=n_envs) def _build(self, learning_rate): self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch, @@ -85,15 +87,19 @@ class PPOPolicy(BasePolicy): # Separate feature extractor for SDE if self.sde_net_arch is not None: + # Special case: when using states as features (i.e. sde_net_arch is an empty list) + # don't use any activation function + sde_activation = self.activation_fn if len(self.sde_net_arch) > 0 else None latent_sde = create_mlp(self.features_dim, -1, self.sde_net_arch, - activation_fn=self.activation_fn, squash_out=False) + activation_fn=sde_activation, squash_out=False) self.sde_feature_extractor = nn.Sequential(*latent_sde) + latent_sde_dim = self.sde_net_arch[-1] if len(self.sde_net_arch) > 0 else self.features_dim if isinstance(self.action_dist, DiagGaussianDistribution): self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi, log_std_init=self.log_std_init) elif isinstance(self.action_dist, StateDependentNoiseDistribution): - latent_sde_dim = self.mlp_extractor.latent_dim_pi if self.sde_net_arch is None else self.sde_net_arch[-1] + latent_sde_dim = self.mlp_extractor.latent_dim_pi if self.sde_net_arch is None else latent_sde_dim self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi, latent_sde_dim=latent_sde_dim, log_std_init=self.log_std_init) diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 4257c1e..99ed1b4 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -157,7 +157,7 @@ class PPO(BaseRLModel): # Sample new weights for the state dependent exploration # TODO: ensure episodic setting? if self.use_sde: - self.policy.reset_noise_net() + self.policy.reset_noise_net(env.num_envs) while n_steps < n_rollout_steps: with th.no_grad():