Add classic advantage computation

This commit is contained in:
Antonin Raffin 2019-10-29 12:36:40 +01:00
parent c15b4bda1e
commit 69a348276e

View file

@ -113,22 +113,42 @@ class RolloutBuffer(BaseBuffer):
self.generator_ready = False
super(RolloutBuffer, self).reset()
def compute_returns_and_advantage(self, last_value, dones=False, use_gae=True):
    """
    Fill ``self.returns`` and ``self.advantages`` from the stored rollout,
    walking the buffer backwards from the last stored step.

    Adapted from Stable-Baselines PPO2.

    :param last_value: (th.Tensor) Value estimate used to bootstrap
        after the last stored step.
    :param dones: ([bool]) Whether the episode(s) ended right after the last
        stored step. A plain bool, a sequence of bools, an array or a tensor
        are all accepted.
    :param use_gae: (bool) Whether to use Generalized Advantage Estimation
        or normal advantage for advantage computation.
    """
    # Bootstrap quantities for the transition that follows the last stored
    # step; they are identical for both estimators, so compute them once.
    final_value = last_value.clone().cpu().flatten()
    # ``th.as_tensor`` (unlike ``th.FloatTensor``) also accepts a scalar,
    # so the default ``dones=False`` works.
    final_non_terminal = 1.0 - th.as_tensor(dones, dtype=th.float32)
    last_step = self.buffer_size - 1

    if use_gae:
        last_gae_lam = 0
        for step in reversed(range(self.buffer_size)):
            if step == last_step:
                next_non_terminal = final_non_terminal
                next_value = final_value
            else:
                next_non_terminal = 1.0 - self.dones[step + 1]
                next_value = self.values[step + 1]
            # One-step TD error; the bootstrap term is masked out when the
            # next step starts a new episode.
            delta = self.rewards[step] + self.gamma * next_value * next_non_terminal - self.values[step]
            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
        self.returns = self.advantages + self.values
    else:
        # Discounted return with value bootstrap
        # Note: this is equivalent to GAE computation
        # with gae_lambda = 1.0
        last_return = 0.0
        for step in reversed(range(self.buffer_size)):
            if step == last_step:
                # The bootstrap value lies one step in the future, so it is
                # discounted by gamma like any other future reward.
                last_return = self.rewards[step] + self.gamma * final_non_terminal * final_value
            else:
                next_non_terminal = 1.0 - self.dones[step + 1]
                last_return = self.rewards[step] + self.gamma * last_return * next_non_terminal
            self.returns[step] = last_return
        self.advantages = self.returns - self.values
def add(self, obs, action, reward, done, value, log_prob):
if len(log_prob.shape) == 0: