diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index bc2c7d8..66f459d 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -56,7 +56,7 @@ class PPO(BaseRLModel): :param use_sde: (bool) Whether to use State Dependent Exploration (SDE) instead of action noise exploration (default: False) :param sde_sample_freq: (int) Sample a new noise matrix every n steps when using SDE - Default: -1 (only sample at the beginning of the rollout) + Default: -1 (only sample at the beginning of the rollout) :param target_kl: (float) Limit the KL divergence between updates, because the clipping is not enough to prevent large update see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) @@ -210,6 +210,12 @@ class PPO(BaseRLModel): # Convert discrete action for float to long action = action.long().flatten() + # Re-sample the noise matrix because the log_std has changed + # TODO: investigate why there is no issue with the gradient + # if that line is commented (as in SAC) + if self.use_sde: + self.policy.reset_noise(batch_size) + values, log_prob, entropy = self.policy.evaluate_actions(obs, action) values = values.flatten() # Normalize advantage