diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index da32a8e..7a4bb69 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -15,6 +15,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ +- Fix GAE computation for on-policy algorithms (off-by one for the last value) (thanks @Wovchena) Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 2063d6c..563ddea 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -299,7 +299,7 @@ class RolloutBuffer(BaseBuffer): self.generator_ready = False super(RolloutBuffer, self).reset() - def compute_returns_and_advantage(self, last_value: th.Tensor, dones: np.ndarray) -> None: + def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None: """ Post-processing step: compute the returns (sum of discounted rewards) and GAE advantage. @@ -310,22 +310,22 @@ class RolloutBuffer(BaseBuffer): where R is the discounted reward with value bootstrap, set ``gae_lambda=1.0`` during initialization. - :param last_value: + :param last_values: :param dones: """ # convert to numpy - last_value = last_value.clone().cpu().numpy().flatten() + last_values = last_values.clone().cpu().numpy().flatten() last_gae_lam = 0 for step in reversed(range(self.buffer_size)): if step == self.buffer_size - 1: next_non_terminal = 1.0 - dones - next_value = last_value + next_values = last_values else: next_non_terminal = 1.0 - self.dones[step + 1] - next_value = self.values[step + 1] - delta = self.rewards[step] + self.gamma * next_value * next_non_terminal - self.values[step] + next_values = self.values[step + 1] + delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step] last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam self.advantages[step] = last_gae_lam self.returns = self.advantages + self.values diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 85d08c8..4f2ef63 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -179,7 +179,12 @@ class OnPolicyAlgorithm(BaseAlgorithm): self._last_obs = new_obs self._last_dones = dones - rollout_buffer.compute_returns_and_advantage(values, dones=dones) + with th.no_grad(): + # Compute value for the last timestep + obs_tensor = th.as_tensor(new_obs).to(self.device) + _, values, _ = self.policy.forward(obs_tensor) + + rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) callback.on_rollout_end() diff --git a/tests/test_identity.py b/tests/test_identity.py index 38c6570..678f63c 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -23,7 +23,7 @@ def test_discrete(model_class, env): if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)): return - model = model_class("MlpPolicy", env_, gamma=0.5, seed=1, **kwargs).learn(n_steps) + model = model_class("MlpPolicy", env_, gamma=0.4, seed=1, **kwargs).learn(n_steps) evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90) obs = env.reset()