Sample n matrices for A2C/PPO when using SDE

This commit is contained in:
Antonin Raffin 2019-12-02 11:48:34 +01:00
parent 7a6a500398
commit a2a8bbdf11
2 changed files with 11 additions and 5 deletions

View file

@ -73,11 +73,13 @@ class PPOPolicy(BasePolicy):
self._build(learning_rate)
def reset_noise_net(self):
def reset_noise_net(self, n_envs=1):
"""
Sample new weights for the exploration matrix.
:param n_envs: (int) Number of exploration matrices to sample; forwarded to ``sample_weights`` as ``batch_size``
"""
self.action_dist.sample_weights(self.log_std)
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
def _build(self, learning_rate):
self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch,
@ -85,15 +87,19 @@ class PPOPolicy(BasePolicy):
# Separate feature extractor for SDE
if self.sde_net_arch is not None:
# Special case: when using states as features (i.e. sde_net_arch is an empty list)
# don't use any activation function
sde_activation = self.activation_fn if len(self.sde_net_arch) > 0 else None
latent_sde = create_mlp(self.features_dim, -1, self.sde_net_arch,
activation_fn=self.activation_fn, squash_out=False)
activation_fn=sde_activation, squash_out=False)
self.sde_feature_extractor = nn.Sequential(*latent_sde)
latent_sde_dim = self.sde_net_arch[-1] if len(self.sde_net_arch) > 0 else self.features_dim
if isinstance(self.action_dist, DiagGaussianDistribution):
self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi,
log_std_init=self.log_std_init)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
latent_sde_dim = self.mlp_extractor.latent_dim_pi if self.sde_net_arch is None else self.sde_net_arch[-1]
latent_sde_dim = self.mlp_extractor.latent_dim_pi if self.sde_net_arch is None else latent_sde_dim
self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi,
latent_sde_dim=latent_sde_dim,
log_std_init=self.log_std_init)

View file

@ -157,7 +157,7 @@ class PPO(BaseRLModel):
# Sample new weights for the state dependent exploration
# TODO: ensure episodic setting?
if self.use_sde:
self.policy.reset_noise_net()
self.policy.reset_noise_net(env.num_envs)
while n_steps < n_rollout_steps:
with th.no_grad():