diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst index f3af499..10bba85 100644 --- a/docs/guide/vec_envs.rst +++ b/docs/guide/vec_envs.rst @@ -90,7 +90,7 @@ SB3 VecEnv API is actually close to Gym 0.21 API but differs to Gym 0.26+ API: Note that if ``render_mode != "rgb_array"``, you can only call ``vec_env.render()`` (without argument or with ``mode=env.render_mode``). - the ``reset()`` method doesn't take any parameter. If you want to seed the pseudo-random generator or pass options, - you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)`` and ``obs = vec_env.reset()`` afterward (seed and options are discared after each call to ``reset()``). + you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)`` and ``obs = vec_env.reset()`` afterward (seed and options are discarded after each call to ``reset()``). - methods and attributes of the underlying Gym envs can be accessed, called and set using ``vec_env.get_attr("attribute_name")``, ``vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)`` and ``vec_env.set_attr("attribute_name", new_value)``. diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index ffe1d7e..19aeeec 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,8 +3,9 @@ Changelog ========== -Release 2.2.0a8 (WIP) +Release 2.2.0a9 (WIP) -------------------------- +**Support for options at reset, bug fixes and better error messages** Breaking Changes: ^^^^^^^^^^^^^^^^^ @@ -16,6 +17,8 @@ New Features: - Improved error message of the ``env_checker`` for env wrongly detected as GoalEnv (``compute_reward()`` is defined) - Improved error message when mixing Gym API with VecEnv API (see GH#1694) - Add support for setting ``options`` at reset with VecEnv via the ``set_options()`` method. 
Same as seeds logic, options are reset at the end of an episode (@ReHoss) +- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to on-policy algorithms (A2C and PPO) + Bug Fixes: ^^^^^^^^^^ @@ -36,9 +39,9 @@ Bug Fixes: `RL Zoo`_ ^^^^^^^^^ -`SBX`_ -^^^^^^^^^ -- Added ``DDPG`` and ``TD3`` +`SBX`_ (SB3 + Jax) +^^^^^^^^^^^^^^^^^^ +- Added ``DDPG`` and ``TD3`` algorithms Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index fda20c9..718571f 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -4,6 +4,7 @@ import torch as th from gymnasium import spaces from torch.nn import functional as F +from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule @@ -41,6 +42,8 @@ class A2C(OnPolicyAlgorithm): instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) + :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected. + :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation. 
:param normalize_advantage: Whether to normalize or not the advantage :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over @@ -75,6 +78,8 @@ class A2C(OnPolicyAlgorithm): use_rms_prop: bool = True, use_sde: bool = False, sde_sample_freq: int = -1, + rollout_buffer_class: Optional[Type[RolloutBuffer]] = None, + rollout_buffer_kwargs: Optional[Dict[str, Any]] = None, normalize_advantage: bool = False, stats_window_size: int = 100, tensorboard_log: Optional[str] = None, @@ -96,6 +101,8 @@ class A2C(OnPolicyAlgorithm): max_grad_norm=max_grad_norm, use_sde=use_sde, sde_sample_freq=sde_sample_freq, + rollout_buffer_class=rollout_buffer_class, + rollout_buffer_kwargs=rollout_buffer_kwargs, stats_window_size=stats_window_size, tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs, diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 4f9bb08..ddd0f8d 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -37,6 +37,8 @@ class OnPolicyAlgorithm(BaseAlgorithm): instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) + :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected. + :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation. 
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) @@ -68,6 +70,8 @@ class OnPolicyAlgorithm(BaseAlgorithm): max_grad_norm: float, use_sde: bool, sde_sample_freq: int, + rollout_buffer_class: Optional[Type[RolloutBuffer]] = None, + rollout_buffer_kwargs: Optional[Dict[str, Any]] = None, stats_window_size: int = 100, tensorboard_log: Optional[str] = None, monitor_wrapper: bool = True, @@ -100,6 +104,8 @@ class OnPolicyAlgorithm(BaseAlgorithm): self.ent_coef = ent_coef self.vf_coef = vf_coef self.max_grad_norm = max_grad_norm + self.rollout_buffer_class = rollout_buffer_class + self.rollout_buffer_kwargs = rollout_buffer_kwargs or {} if _init_setup_model: self._setup_model() @@ -108,9 +114,13 @@ class OnPolicyAlgorithm(BaseAlgorithm): self._setup_lr_schedule() self.set_random_seed(self.seed) - buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer + if self.rollout_buffer_class is None: + if isinstance(self.observation_space, spaces.Dict): + self.rollout_buffer_class = DictRolloutBuffer + else: + self.rollout_buffer_class = RolloutBuffer - self.rollout_buffer = buffer_cls( + self.rollout_buffer = self.rollout_buffer_class( self.n_steps, self.observation_space, # type: ignore[arg-type] self.action_space, @@ -118,6 +128,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs, + **self.rollout_buffer_kwargs, ) self.policy = self.policy_class( # type: ignore[assignment] self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 9bd83b0..ea7cf5e 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -6,6 +6,7 @@ import torch 
as th from gymnasium import spaces from torch.nn import functional as F +from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule @@ -52,6 +53,8 @@ class PPO(OnPolicyAlgorithm): instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) + :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected. + :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation. :param target_kl: Limit the KL divergence between updates, because the clipping is not enough to prevent large update see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) @@ -92,6 +95,8 @@ class PPO(OnPolicyAlgorithm): max_grad_norm: float = 0.5, use_sde: bool = False, sde_sample_freq: int = -1, + rollout_buffer_class: Optional[Type[RolloutBuffer]] = None, + rollout_buffer_kwargs: Optional[Dict[str, Any]] = None, target_kl: Optional[float] = None, stats_window_size: int = 100, tensorboard_log: Optional[str] = None, @@ -113,6 +118,8 @@ class PPO(OnPolicyAlgorithm): max_grad_norm=max_grad_norm, use_sde=use_sde, sde_sample_freq=sde_sample_freq, + rollout_buffer_class=rollout_buffer_class, + rollout_buffer_kwargs=rollout_buffer_kwargs, stats_window_size=stats_window_size, tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs, diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index f1f23b3..b7120ad 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.2.0a8 +2.2.0a9 diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 
e7d4a1c..2ea366a 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -4,6 +4,7 @@ import pytest import torch as th from gymnasium import spaces +from stable_baselines3 import A2C from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env @@ -150,3 +151,16 @@ def test_device_buffer(replay_buffer_cls, device): assert value[key].device.type == desired_device elif isinstance(value, th.Tensor): assert value.device.type == desired_device + + +def test_custom_rollout_buffer(): + A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict()) + + with pytest.raises(TypeError, match="unexpected keyword argument 'wrong_keyword'"): + A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict(wrong_keyword=1)) + + with pytest.raises(TypeError, match="got multiple values for keyword argument 'gamma'"): + A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict(gamma=1)) + + with pytest.raises(AssertionError, match="DictRolloutBuffer must be used with Dict obs space only"): + A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=DictRolloutBuffer)