Add rollout_buffer_class parameter to on-policy algorithms (#1720)

* Add rollout_buffer_class and rollout_buffer_kwargs parameters to OnPolicyAlgorithm

* Add rollout_buffer_class and rollout_buffer_kwargs to PPO.

* Add rollout_buffer_class and rollout_buffer_kwargs to A2C.

* Make use of the rollout buffer kwargs.

* Update version

* Add test and update doc

---------

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
This commit is contained in:
M. Ernestus 2023-10-27 17:36:24 +02:00 committed by GitHub
parent f56ddeda10
commit 69afefc91d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 50 additions and 8 deletions

View file

@ -90,7 +90,7 @@ SB3 VecEnv API is actually close to Gym 0.21 API but differs to Gym 0.26+ API:
Note that if ``render_mode != "rgb_array"``, you can only call ``vec_env.render()`` (without argument or with ``mode=env.render_mode``).
- the ``reset()`` method doesn't take any parameter. If you want to seed the pseudo-random generator or pass options,
you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)`` and ``obs = vec_env.reset()`` afterward (seed and options are discared after each call to ``reset()``).
you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)`` and ``obs = vec_env.reset()`` afterward (seed and options are discarded after each call to ``reset()``).
- methods and attributes of the underlying Gym envs can be accessed, called and set using ``vec_env.get_attr("attribute_name")``,
``vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)`` and ``vec_env.set_attr("attribute_name", new_value)``.

View file

@ -3,8 +3,9 @@
Changelog
==========
Release 2.2.0a8 (WIP)
Release 2.2.0a9 (WIP)
--------------------------
**Support for options at reset, bug fixes and better error messages**
Breaking Changes:
^^^^^^^^^^^^^^^^^
@ -16,6 +17,8 @@ New Features:
- Improved error message of the ``env_checker`` for env wrongly detected as GoalEnv (``compute_reward()`` is defined)
- Improved error message when mixing Gym API with VecEnv API (see GH#1694)
- Add support for setting ``options`` at reset with VecEnv via the ``set_options()`` method. Same as seeds logic, options are reset at the end of an episode (@ReHoss)
- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to on-policy algorithms (A2C and PPO)
Bug Fixes:
^^^^^^^^^^
@ -36,9 +39,9 @@ Bug Fixes:
`RL Zoo`_
^^^^^^^^^
`SBX`_
^^^^^^^^^
- Added ``DDPG`` and ``TD3``
`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^
- Added ``DDPG`` and ``TD3`` algorithms
Deprecations:
^^^^^^^^^^^^^

View file

@ -4,6 +4,7 @@ import torch as th
from gymnasium import spaces
from torch.nn import functional as F
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@ -41,6 +42,8 @@ class A2C(OnPolicyAlgorithm):
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
:param normalize_advantage: Whether to normalize or not the advantage
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
@ -75,6 +78,8 @@ class A2C(OnPolicyAlgorithm):
use_rms_prop: bool = True,
use_sde: bool = False,
sde_sample_freq: int = -1,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
normalize_advantage: bool = False,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
@ -96,6 +101,8 @@ class A2C(OnPolicyAlgorithm):
max_grad_norm=max_grad_norm,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
rollout_buffer_class=rollout_buffer_class,
rollout_buffer_kwargs=rollout_buffer_kwargs,
stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log,
policy_kwargs=policy_kwargs,

View file

@ -37,6 +37,8 @@ class OnPolicyAlgorithm(BaseAlgorithm):
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
@ -68,6 +70,8 @@ class OnPolicyAlgorithm(BaseAlgorithm):
max_grad_norm: float,
use_sde: bool,
sde_sample_freq: int,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
monitor_wrapper: bool = True,
@ -100,6 +104,8 @@ class OnPolicyAlgorithm(BaseAlgorithm):
self.ent_coef = ent_coef
self.vf_coef = vf_coef
self.max_grad_norm = max_grad_norm
self.rollout_buffer_class = rollout_buffer_class
self.rollout_buffer_kwargs = rollout_buffer_kwargs or {}
if _init_setup_model:
self._setup_model()
@ -108,9 +114,13 @@ class OnPolicyAlgorithm(BaseAlgorithm):
self._setup_lr_schedule()
self.set_random_seed(self.seed)
buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer
if self.rollout_buffer_class is None:
if isinstance(self.observation_space, spaces.Dict):
self.rollout_buffer_class = DictRolloutBuffer
else:
self.rollout_buffer_class = RolloutBuffer
self.rollout_buffer = buffer_cls(
self.rollout_buffer = self.rollout_buffer_class(
self.n_steps,
self.observation_space, # type: ignore[arg-type]
self.action_space,
@ -118,6 +128,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
gamma=self.gamma,
gae_lambda=self.gae_lambda,
n_envs=self.n_envs,
**self.rollout_buffer_kwargs,
)
self.policy = self.policy_class( # type: ignore[assignment]
self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs

View file

@ -6,6 +6,7 @@ import torch as th
from gymnasium import spaces
from torch.nn import functional as F
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@ -52,6 +53,8 @@ class PPO(OnPolicyAlgorithm):
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
:param target_kl: Limit the KL divergence between updates,
because the clipping is not enough to prevent large update
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
@ -92,6 +95,8 @@ class PPO(OnPolicyAlgorithm):
max_grad_norm: float = 0.5,
use_sde: bool = False,
sde_sample_freq: int = -1,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
target_kl: Optional[float] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
@ -113,6 +118,8 @@ class PPO(OnPolicyAlgorithm):
max_grad_norm=max_grad_norm,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
rollout_buffer_class=rollout_buffer_class,
rollout_buffer_kwargs=rollout_buffer_kwargs,
stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log,
policy_kwargs=policy_kwargs,

View file

@ -1 +1 @@
2.2.0a8
2.2.0a9

View file

@ -4,6 +4,7 @@ import pytest
import torch as th
from gymnasium import spaces
from stable_baselines3 import A2C
from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
@ -150,3 +151,16 @@ def test_device_buffer(replay_buffer_cls, device):
assert value[key].device.type == desired_device
elif isinstance(value, th.Tensor):
assert value.device.type == desired_device
def test_custom_rollout_buffer():
A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict())
with pytest.raises(TypeError, match="unexpected keyword argument 'wrong_keyword'"):
A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict(wrong_keyword=1))
with pytest.raises(TypeError, match="got multiple values for keyword argument 'gamma'"):
A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict(gamma=1))
with pytest.raises(AssertionError, match="DictRolloutBuffer must be used with Dict obs space only"):
A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=DictRolloutBuffer)