mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
Sample n matrices for A2C/PPO when using SDE
This commit is contained in:
parent
7a6a500398
commit
a2a8bbdf11
2 changed files with 11 additions and 5 deletions
|
|
@ -73,11 +73,13 @@ class PPOPolicy(BasePolicy):
|
|||
|
||||
self._build(learning_rate)
|
||||
|
||||
def reset_noise_net(self):
|
||||
def reset_noise_net(self, n_envs=1):
|
||||
"""
|
||||
Sample new weights for the exploration matrix.
|
||||
|
||||
:param n_envs: (int)
|
||||
"""
|
||||
self.action_dist.sample_weights(self.log_std)
|
||||
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
|
||||
|
||||
def _build(self, learning_rate):
|
||||
self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch,
|
||||
|
|
@ -85,15 +87,19 @@ class PPOPolicy(BasePolicy):
|
|||
|
||||
# Separate feature extractor for SDE
|
||||
if self.sde_net_arch is not None:
|
||||
# Special case: when using states as features (i.e. sde_net_arch is an empty list)
|
||||
# don't use any activation function
|
||||
sde_activation = self.activation_fn if len(self.sde_net_arch) > 0 else None
|
||||
latent_sde = create_mlp(self.features_dim, -1, self.sde_net_arch,
|
||||
activation_fn=self.activation_fn, squash_out=False)
|
||||
activation_fn=sde_activation, squash_out=False)
|
||||
self.sde_feature_extractor = nn.Sequential(*latent_sde)
|
||||
latent_sde_dim = self.sde_net_arch[-1] if len(self.sde_net_arch) > 0 else self.features_dim
|
||||
|
||||
if isinstance(self.action_dist, DiagGaussianDistribution):
|
||||
self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi,
|
||||
log_std_init=self.log_std_init)
|
||||
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
|
||||
latent_sde_dim = self.mlp_extractor.latent_dim_pi if self.sde_net_arch is None else self.sde_net_arch[-1]
|
||||
latent_sde_dim = self.mlp_extractor.latent_dim_pi if self.sde_net_arch is None else latent_sde_dim
|
||||
self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi,
|
||||
latent_sde_dim=latent_sde_dim,
|
||||
log_std_init=self.log_std_init)
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ class PPO(BaseRLModel):
|
|||
# Sample new weights for the state dependent exploration
|
||||
# TODO: ensure episodic setting?
|
||||
if self.use_sde:
|
||||
self.policy.reset_noise_net()
|
||||
self.policy.reset_noise_net(env.num_envs)
|
||||
|
||||
while n_steps < n_rollout_steps:
|
||||
with th.no_grad():
|
||||
|
|
|
|||
Loading…
Reference in a new issue