def compute_returns_and_advantage(self, last_value, dones=False, use_gae=True):
    """
    Fill ``self.returns`` and ``self.advantages`` from the stored
    rewards, values and done flags, iterating backwards over the buffer.

    Adapted from Stable-Baselines PPO2.

    :param last_value: (th.Tensor) Value estimate used to bootstrap the
        step that follows the last stored transition.
    :param dones: ([bool]) Whether the step following the last stored
        transition is terminal. A scalar bool, a sequence of bools,
        a numpy array or a tensor are all accepted.
    :param use_gae: (bool) Whether to use Generalized Advantage Estimation
        or normal advantage for advantage computation.
    """
    # Mask and bootstrap value for the step after the end of the buffer,
    # shared by both estimators. ``th.as_tensor`` also handles the scalar
    # default ``dones=False``, which ``th.FloatTensor(1.0 - dones)`` rejects.
    final_non_terminal = 1.0 - th.as_tensor(dones, dtype=th.float32).cpu()
    final_value = last_value.clone().cpu().flatten()
    last_step = self.buffer_size - 1

    if use_gae:
        last_gae_lam = 0
        for step in reversed(range(self.buffer_size)):
            if step == last_step:
                next_non_terminal = final_non_terminal
                next_value = final_value
            else:
                next_non_terminal = 1.0 - self.dones[step + 1]
                next_value = self.values[step + 1]
            # TD residual, then exponentially-weighted sum of residuals
            delta = self.rewards[step] + self.gamma * next_value * next_non_terminal - self.values[step]
            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
        self.returns = self.advantages + self.values
    else:
        # Discounted return with value bootstrap.
        # Note: this is equivalent to GAE computation with gae_lambda = 1.0,
        # which requires the bootstrap value to be discounted by gamma
        # like every other future term.
        last_return = final_value
        next_non_terminal = final_non_terminal
        for step in reversed(range(self.buffer_size)):
            if step != last_step:
                next_non_terminal = 1.0 - self.dones[step + 1]
            last_return = self.rewards[step] + self.gamma * next_non_terminal * last_return
            self.returns[step] = last_return
        self.advantages = self.returns - self.values