mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-09 00:31:14 +00:00
Add rollout_buffer_class parameter to on-policy algorithms (#1720)
* Add rollout_buffer_class and rollout_buffer_kwargs parameters to OnPolicyAlgorithm
* Add rollout_buffer_class and rollout_buffer_kwargs to PPO.
* Add rollout_buffer_class and rollout_buffer_kwargs to A2C.
* Make use of the rollout buffer kwargs.
* Update version
* Add test and update doc
---------
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
This commit is contained in:
parent
f56ddeda10
commit
69afefc91d
7 changed files with 50 additions and 8 deletions
|
|
@ -90,7 +90,7 @@ SB3 VecEnv API is actually close to Gym 0.21 API but differs to Gym 0.26+ API:
|
|||
Note that if ``render_mode != "rgb_array"``, you can only call ``vec_env.render()`` (without argument or with ``mode=env.render_mode``).
|
||||
|
||||
- the ``reset()`` method doesn't take any parameter. If you want to seed the pseudo-random generator or pass options,
|
||||
you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)`` and ``obs = vec_env.reset()`` afterward (seed and options are discared after each call to ``reset()``).
|
||||
you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)`` and ``obs = vec_env.reset()`` afterward (seed and options are discarded after each call to ``reset()``).
|
||||
|
||||
- methods and attributes of the underlying Gym envs can be accessed, called and set using ``vec_env.get_attr("attribute_name")``,
|
||||
``vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)`` and ``vec_env.set_attr("attribute_name", new_value)``.
|
||||
|
|
|
|||
|
|
@ -3,8 +3,9 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.2.0a8 (WIP)
|
||||
Release 2.2.0a9 (WIP)
|
||||
--------------------------
|
||||
**Support for options at reset, bug fixes and better error messages**
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
|
@ -16,6 +17,8 @@ New Features:
|
|||
- Improved error message of the ``env_checker`` for env wrongly detected as GoalEnv (``compute_reward()`` is defined)
|
||||
- Improved error message when mixing Gym API with VecEnv API (see GH#1694)
|
||||
- Add support for setting ``options`` at reset with VecEnv via the ``set_options()`` method. As with seeds, options are reset at the end of an episode (@ReHoss)
|
||||
- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to on-policy algorithms (A2C and PPO)
|
||||
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
@ -36,9 +39,9 @@ Bug Fixes:
|
|||
`RL Zoo`_
|
||||
^^^^^^^^^
|
||||
|
||||
`SBX`_
|
||||
^^^^^^^^^
|
||||
- Added ``DDPG`` and ``TD3``
|
||||
`SBX`_ (SB3 + Jax)
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
- Added ``DDPG`` and ``TD3`` algorithms
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import torch as th
|
|||
from gymnasium import spaces
|
||||
from torch.nn import functional as F
|
||||
|
||||
from stable_baselines3.common.buffers import RolloutBuffer
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
|
|
@ -41,6 +42,8 @@ class A2C(OnPolicyAlgorithm):
|
|||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
|
||||
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
|
||||
:param normalize_advantage: Whether or not to normalize the advantage
|
||||
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
|
||||
the reported success rate, mean episode length, and mean reward over
|
||||
|
|
@ -75,6 +78,8 @@ class A2C(OnPolicyAlgorithm):
|
|||
use_rms_prop: bool = True,
|
||||
use_sde: bool = False,
|
||||
sde_sample_freq: int = -1,
|
||||
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
|
||||
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
normalize_advantage: bool = False,
|
||||
stats_window_size: int = 100,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
|
|
@ -96,6 +101,8 @@ class A2C(OnPolicyAlgorithm):
|
|||
max_grad_norm=max_grad_norm,
|
||||
use_sde=use_sde,
|
||||
sde_sample_freq=sde_sample_freq,
|
||||
rollout_buffer_class=rollout_buffer_class,
|
||||
rollout_buffer_kwargs=rollout_buffer_kwargs,
|
||||
stats_window_size=stats_window_size,
|
||||
tensorboard_log=tensorboard_log,
|
||||
policy_kwargs=policy_kwargs,
|
||||
|
|
|
|||
|
|
@ -37,6 +37,8 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
|
||||
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
|
||||
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
|
||||
the reported success rate, mean episode length, and mean reward over
|
||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||
|
|
@ -68,6 +70,8 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
max_grad_norm: float,
|
||||
use_sde: bool,
|
||||
sde_sample_freq: int,
|
||||
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
|
||||
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
stats_window_size: int = 100,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
monitor_wrapper: bool = True,
|
||||
|
|
@ -100,6 +104,8 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
self.ent_coef = ent_coef
|
||||
self.vf_coef = vf_coef
|
||||
self.max_grad_norm = max_grad_norm
|
||||
self.rollout_buffer_class = rollout_buffer_class
|
||||
self.rollout_buffer_kwargs = rollout_buffer_kwargs or {}
|
||||
|
||||
if _init_setup_model:
|
||||
self._setup_model()
|
||||
|
|
@ -108,9 +114,13 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
self._setup_lr_schedule()
|
||||
self.set_random_seed(self.seed)
|
||||
|
||||
buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer
|
||||
if self.rollout_buffer_class is None:
|
||||
if isinstance(self.observation_space, spaces.Dict):
|
||||
self.rollout_buffer_class = DictRolloutBuffer
|
||||
else:
|
||||
self.rollout_buffer_class = RolloutBuffer
|
||||
|
||||
self.rollout_buffer = buffer_cls(
|
||||
self.rollout_buffer = self.rollout_buffer_class(
|
||||
self.n_steps,
|
||||
self.observation_space, # type: ignore[arg-type]
|
||||
self.action_space,
|
||||
|
|
@ -118,6 +128,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
gamma=self.gamma,
|
||||
gae_lambda=self.gae_lambda,
|
||||
n_envs=self.n_envs,
|
||||
**self.rollout_buffer_kwargs,
|
||||
)
|
||||
self.policy = self.policy_class( # type: ignore[assignment]
|
||||
self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import torch as th
|
|||
from gymnasium import spaces
|
||||
from torch.nn import functional as F
|
||||
|
||||
from stable_baselines3.common.buffers import RolloutBuffer
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
|
|
@ -52,6 +53,8 @@ class PPO(OnPolicyAlgorithm):
|
|||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
|
||||
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
|
||||
:param target_kl: Limit the KL divergence between updates,
|
||||
because the clipping is not enough to prevent large updates
|
||||
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
|
||||
|
|
@ -92,6 +95,8 @@ class PPO(OnPolicyAlgorithm):
|
|||
max_grad_norm: float = 0.5,
|
||||
use_sde: bool = False,
|
||||
sde_sample_freq: int = -1,
|
||||
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
|
||||
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
target_kl: Optional[float] = None,
|
||||
stats_window_size: int = 100,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
|
|
@ -113,6 +118,8 @@ class PPO(OnPolicyAlgorithm):
|
|||
max_grad_norm=max_grad_norm,
|
||||
use_sde=use_sde,
|
||||
sde_sample_freq=sde_sample_freq,
|
||||
rollout_buffer_class=rollout_buffer_class,
|
||||
rollout_buffer_kwargs=rollout_buffer_kwargs,
|
||||
stats_window_size=stats_window_size,
|
||||
tensorboard_log=tensorboard_log,
|
||||
policy_kwargs=policy_kwargs,
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.2.0a8
|
||||
2.2.0a9
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import pytest
|
|||
import torch as th
|
||||
from gymnasium import spaces
|
||||
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
|
|
@ -150,3 +151,16 @@ def test_device_buffer(replay_buffer_cls, device):
|
|||
assert value[key].device.type == desired_device
|
||||
elif isinstance(value, th.Tensor):
|
||||
assert value.device.type == desired_device
|
||||
|
||||
|
||||
def test_custom_rollout_buffer():
|
||||
A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict())
|
||||
|
||||
with pytest.raises(TypeError, match="unexpected keyword argument 'wrong_keyword'"):
|
||||
A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict(wrong_keyword=1))
|
||||
|
||||
with pytest.raises(TypeError, match="got multiple values for keyword argument 'gamma'"):
|
||||
A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict(gamma=1))
|
||||
|
||||
with pytest.raises(AssertionError, match="DictRolloutBuffer must be used with Dict obs space only"):
|
||||
A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=DictRolloutBuffer)
|
||||
|
|
|
|||
Loading…
Reference in a new issue