diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index 6169943..cd54497 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -181,8 +181,8 @@ class SAC(BaseRLModel): # is lost and we cannot backpropagate through again # anyway, we need to sample because `log_std` may have changed between two gradient steps if self.use_sde: - # self.actor.reset_noise(batch_size=batch_size) - self.actor.reset_noise() + self.actor.reset_noise(batch_size=batch_size) + # self.actor.reset_noise() # Action by the current actor for the sampled state action_pi, log_prob = self.actor.action_log_prob(obs)