mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
Add classic advantage computation
This commit is contained in:
parent
c15b4bda1e
commit
69a348276e
1 changed file with 34 additions and 14 deletions
|
|
@@ -113,22 +113,42 @@ class RolloutBuffer(BaseBuffer):
|
|||
self.generator_ready = False
|
||||
super(RolloutBuffer, self).reset()
|
||||
|
||||
def compute_returns_and_advantage(self, last_value, dones=False, use_gae=True):
    """
    Fill ``self.returns`` and ``self.advantages`` from the stored
    rewards, values and dones, walking the buffer backwards.

    Adapted from Stable-Baselines PPO2.

    :param last_value: (th.Tensor) Value estimate used to bootstrap
        after the last stored step.
    :param dones: (bool, [bool], np.ndarray or th.Tensor) Termination
        flag(s) for the step that follows the last stored one; where true,
        ``last_value`` is not used for bootstrapping.
    :param use_gae: (bool) Whether to use Generalized Advantage Estimation
        or normal advantage for advantage computation.
    """
    # Quantities for the step that follows the last stored transition.
    # `th.as_tensor` accepts a plain bool (the default), a sequence,
    # an array or a tensor, unlike the legacy `th.FloatTensor` constructor.
    last_non_terminal = 1.0 - th.as_tensor(dones, dtype=th.float32).cpu()
    bootstrap_value = last_value.clone().cpu().flatten()
    last_step = self.buffer_size - 1

    if use_gae:
        last_gae_lam = 0
        for step in reversed(range(self.buffer_size)):
            if step == last_step:
                next_non_terminal = last_non_terminal
                next_value = bootstrap_value
            else:
                next_non_terminal = 1.0 - self.dones[step + 1]
                next_value = self.values[step + 1]
            # TD residual, then the exponentially weighted sum of residuals
            delta = self.rewards[step] + self.gamma * next_value * next_non_terminal - self.values[step]
            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
        self.returns = self.advantages + self.values
    else:
        # Discounted return with value bootstrap
        # Note: this is equivalent to GAE computation
        # with gae_lambda = 1.0
        last_return = 0.0
        for step in reversed(range(self.buffer_size)):
            if step == last_step:
                # The bootstrap value lies one step in the future, so it
                # must be discounted like any other future return.
                last_return = self.rewards[step] + self.gamma * last_non_terminal * bootstrap_value
            else:
                next_non_terminal = 1.0 - self.dones[step + 1]
                last_return = self.rewards[step] + self.gamma * last_return * next_non_terminal
            self.returns[step] = last_return
        self.advantages = self.returns - self.values
|
||||
|
||||
def add(self, obs, action, reward, done, value, log_prob):
|
||||
if len(log_prob.shape) == 0:
|
||||
|
|
|
|||
Loading…
Reference in a new issue