mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-28 22:56:53 +00:00
Fix off-by-one GAE computation (#185)
* Fix off-by-one GAE computation * Fix identity test * Revert gae loop
This commit is contained in:
parent
fc6c5d3daa
commit
fc9527157a
4 changed files with 14 additions and 8 deletions
|
|
@ -15,6 +15,7 @@ New Features:
|
|||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fix GAE computation for on-policy algorithms (off-by one for the last value) (thanks @Wovchena)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -299,7 +299,7 @@ class RolloutBuffer(BaseBuffer):
|
|||
self.generator_ready = False
|
||||
super(RolloutBuffer, self).reset()
|
||||
|
||||
def compute_returns_and_advantage(self, last_value: th.Tensor, dones: np.ndarray) -> None:
|
||||
def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
|
||||
"""
|
||||
Post-processing step: compute the returns (sum of discounted rewards)
|
||||
and GAE advantage.
|
||||
|
|
@ -310,22 +310,22 @@ class RolloutBuffer(BaseBuffer):
|
|||
where R is the discounted reward with value bootstrap,
|
||||
set ``gae_lambda=1.0`` during initialization.
|
||||
|
||||
:param last_value:
|
||||
:param last_values:
|
||||
:param dones:
|
||||
|
||||
"""
|
||||
# convert to numpy
|
||||
last_value = last_value.clone().cpu().numpy().flatten()
|
||||
last_values = last_values.clone().cpu().numpy().flatten()
|
||||
|
||||
last_gae_lam = 0
|
||||
for step in reversed(range(self.buffer_size)):
|
||||
if step == self.buffer_size - 1:
|
||||
next_non_terminal = 1.0 - dones
|
||||
next_value = last_value
|
||||
next_values = last_values
|
||||
else:
|
||||
next_non_terminal = 1.0 - self.dones[step + 1]
|
||||
next_value = self.values[step + 1]
|
||||
delta = self.rewards[step] + self.gamma * next_value * next_non_terminal - self.values[step]
|
||||
next_values = self.values[step + 1]
|
||||
delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
|
||||
last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
|
||||
self.advantages[step] = last_gae_lam
|
||||
self.returns = self.advantages + self.values
|
||||
|
|
|
|||
|
|
@ -179,7 +179,12 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
self._last_obs = new_obs
|
||||
self._last_dones = dones
|
||||
|
||||
rollout_buffer.compute_returns_and_advantage(values, dones=dones)
|
||||
with th.no_grad():
|
||||
# Compute value for the last timestep
|
||||
obs_tensor = th.as_tensor(new_obs).to(self.device)
|
||||
_, values, _ = self.policy.forward(obs_tensor)
|
||||
|
||||
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
|
||||
|
||||
callback.on_rollout_end()
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ def test_discrete(model_class, env):
|
|||
if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
|
||||
return
|
||||
|
||||
model = model_class("MlpPolicy", env_, gamma=0.5, seed=1, **kwargs).learn(n_steps)
|
||||
model = model_class("MlpPolicy", env_, gamma=0.4, seed=1, **kwargs).learn(n_steps)
|
||||
|
||||
evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90)
|
||||
obs = env.reset()
|
||||
|
|
|
|||
Loading…
Reference in a new issue