From 3b68dc731219f112ccc2a6745f216bca701080bb Mon Sep 17 00:00:00 2001 From: Shyamal H Anadkat Date: Thu, 25 Nov 2021 04:53:42 -0500 Subject: [PATCH] Update GAE computation docstring (#655) * Fix typo in buffers.py * Revert "Fix typo in buffers.py" This reverts commit ca643d5e3a509ae1b8a65bf0de98f4609ca9d8da. * Ignore pytype errors * Update GAE computation docstring Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 1 + setup.cfg | 1 + stable_baselines3/common/buffers.py | 9 ++++----- stable_baselines3/common/vec_env/stacked_observations.py | 1 + 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 4eb697a..2de4979 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -34,6 +34,7 @@ Documentation: - Add highway-env to projects page (@eleurent) - Add tactile-gym to projects page (@ac-93) - Fix indentation in the RL tips page (@cove9988) +- Update GAE computation docstring Release 1.3.0 (2021-10-23) diff --git a/setup.cfg b/setup.cfg index 7bfd321..73ae3db 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,6 +16,7 @@ filterwarnings = [pytype] inputs = stable_baselines3 +disable = pyi-error [flake8] ignore = W503,W504,E203,E231 # line breaks before and after binary operators diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 7530d47..77d6c36 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -352,9 +352,9 @@ class RolloutBuffer(BaseBuffer): and GAE(lambda) advantage. Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) - to compute the advantage. To obtain vanilla advantage (A(s) = R - V(S)) - where R is the discounted reward with value bootstrap, - set ``gae_lambda=1.0`` during initialization. + to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S)) + where R is the sum of discounted reward with value bootstrap + (because we don't always have full episode), set ``gae_lambda=1.0`` during initialization. The TD(lambda) estimator has also two special cases: - TD(1) is Monte-Carlo estimate (sum of discounted rewards) @@ -364,7 +364,6 @@ class RolloutBuffer(BaseBuffer): :param last_values: state value estimation for the last step (one for each env) :param dones: if the last step was a terminal step (one bool for each env). - """ # Convert to numpy last_values = last_values.clone().cpu().numpy().flatten() @@ -623,7 +622,7 @@ class DictRolloutBuffer(RolloutBuffer): :param action_space: Action space :param device: :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator - Equivalent to classic advantage when set to 1. + Equivalent to Monte-Carlo advantage estimate when set to 1. :param gamma: Discount factor :param n_envs: Number of parallel environments """ diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index 513d84a..956231f 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -18,6 +18,7 @@ class StackedObservations(object): :param num_envs: number of environments :param n_stack: Number of frames to stack + :param observation_space: Environment observation space. :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. If None, automatically detect channel to stack over in case of image observation or default to "last" (default). """