mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-30 03:38:13 +00:00
Add sticky actions for Atari games (#1286)
* repeat_action_probability * Add test * Undo atari wrapper doc change since CI fails * remove action_repeat_probability from make_atari_env * Add sticky action wrapper and improve documentation * Update changelog * handle the case noop_max=0 * Update tests * Comply to ALE implementation * Reorder doc * Add doc warning and don't wrap with sticky action when not needed * fix docstring and reorder * Move `action_repeat_probability` args at the last position * Add ref * Update doc and wrap with frameskip only if needed * Update changelog Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
637988c9cc
commit
5ee9009535
4 changed files with 110 additions and 41 deletions
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.8.0a2 (WIP)
|
||||
Release 1.8.0a3 (WIP)
|
||||
--------------------------
|
||||
|
||||
|
||||
|
|
@ -14,6 +14,8 @@ Breaking Changes:
|
|||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Added ``repeat_action_probability`` argument in ``AtariWrapper``.
|
||||
- Only use ``NoopResetEnv`` and ``MaxAndSkipEnv`` when needed in ``AtariWrapper``
|
||||
|
||||
`SB3-Contrib`_
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -12,13 +12,39 @@ except ImportError:
|
|||
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
|
||||
|
||||
|
||||
class StickyActionEnv(gym.Wrapper):
|
||||
"""
|
||||
Sticky action.
|
||||
|
||||
Paper: https://arxiv.org/abs/1709.06009
|
||||
Official implementation: https://github.com/mgbellemare/Arcade-Learning-Environment
|
||||
|
||||
:param env: Environment to wrap
|
||||
:param action_repeat_probability: Probability of repeating the last action
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, action_repeat_probability: float) -> None:
|
||||
super().__init__(env)
|
||||
self.action_repeat_probability = action_repeat_probability
|
||||
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
|
||||
|
||||
def reset(self, **kwargs) -> GymObs:
|
||||
self._sticky_action = 0 # NOOP
|
||||
return self.env.reset(**kwargs)
|
||||
|
||||
def step(self, action: int) -> GymStepReturn:
|
||||
if self.np_random.random() >= self.action_repeat_probability:
|
||||
self._sticky_action = action
|
||||
return self.env.step(self._sticky_action)
|
||||
|
||||
|
||||
class NoopResetEnv(gym.Wrapper):
|
||||
"""
|
||||
Sample initial states by taking random number of no-ops on reset.
|
||||
No-op is assumed to be action 0.
|
||||
|
||||
:param env: the environment to wrap
|
||||
:param noop_max: the maximum value of no-ops to run
|
||||
:param env: Environment to wrap
|
||||
:param noop_max: Maximum value of no-ops to run
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
|
||||
|
|
@ -47,7 +73,7 @@ class FireResetEnv(gym.Wrapper):
|
|||
"""
|
||||
Take action on reset for environments that are fixed until firing.
|
||||
|
||||
:param env: the environment to wrap
|
||||
:param env: Environment to wrap
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env) -> None:
|
||||
|
|
@ -71,7 +97,7 @@ class EpisodicLifeEnv(gym.Wrapper):
|
|||
Make end-of-life == end-of-episode, but only reset on true game over.
|
||||
Done by DeepMind for the DQN and co. since it helps value estimation.
|
||||
|
||||
:param env: the environment to wrap
|
||||
:param env: Environment to wrap
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env) -> None:
|
||||
|
|
@ -120,9 +146,11 @@ class EpisodicLifeEnv(gym.Wrapper):
|
|||
class MaxAndSkipEnv(gym.Wrapper):
|
||||
"""
|
||||
Return only every ``skip``-th frame (frameskipping)
|
||||
and return the max between the two last frames.
|
||||
|
||||
:param env: the environment
|
||||
:param skip: number of ``skip``-th frame
|
||||
:param env: Environment to wrap
|
||||
:param skip: Number of ``skip``-th frame
|
||||
The same action will be taken ``skip`` times.
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, skip: int = 4) -> None:
|
||||
|
|
@ -159,9 +187,9 @@ class MaxAndSkipEnv(gym.Wrapper):
|
|||
|
||||
class ClipRewardEnv(gym.RewardWrapper):
|
||||
"""
|
||||
Clips the reward to {+1, 0, -1} by its sign.
|
||||
Clip the reward to {+1, 0, -1} by its sign.
|
||||
|
||||
:param env: the environment
|
||||
:param env: Environment to wrap
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env) -> None:
|
||||
|
|
@ -182,9 +210,9 @@ class WarpFrame(gym.ObservationWrapper):
|
|||
Convert to grayscale and warp frames to 84x84 (default)
|
||||
as done in the Nature paper and later work.
|
||||
|
||||
:param env: the environment
|
||||
:param width:
|
||||
:param height:
|
||||
:param env: Environment to wrap
|
||||
:param width: New frame width
|
||||
:param height: New frame height
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None:
|
||||
|
|
@ -213,20 +241,29 @@ class AtariWrapper(gym.Wrapper):
|
|||
|
||||
Specifically:
|
||||
|
||||
* NoopReset: obtain initial state by taking random number of no-ops on reset.
|
||||
* Noop reset: obtain initial state by taking random number of no-ops on reset.
|
||||
* Frame skipping: 4 by default
|
||||
* Max-pooling: most recent two observations
|
||||
* Termination signal when a life is lost.
|
||||
* Resize to a square image: 84x84 by default
|
||||
* Grayscale observation
|
||||
* Clip reward to {-1, 0, 1}
|
||||
* Sticky actions: disabled by default
|
||||
|
||||
:param env: gym environment
|
||||
:param noop_max: max number of no-ops
|
||||
:param frame_skip: the frequency at which the agent experiences the game.
|
||||
:param screen_size: resize Atari frame
|
||||
:param terminal_on_life_loss: if True, then step() returns done=True whenever a life is lost.
|
||||
See https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/
|
||||
for a visual explanation.
|
||||
|
||||
.. warning::
|
||||
Use this wrapper only with Atari v4 without frame skip: ``env_id = "*NoFrameskip-v4"``.
|
||||
|
||||
:param env: Environment to wrap
|
||||
:param noop_max: Max number of no-ops
|
||||
:param frame_skip: Frequency at which the agent experiences the game.
|
||||
This correspond to repeating the action ``frame_skip`` times.
|
||||
:param screen_size: Resize Atari frame
|
||||
:param terminal_on_life_loss: If True, then step() returns done=True whenever a life is lost.
|
||||
:param clip_reward: If True (default), the reward is clip to {-1, 0, 1} depending on its sign.
|
||||
:param action_repeat_probability: Probability of repeating the last action
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -237,9 +274,15 @@ class AtariWrapper(gym.Wrapper):
|
|||
screen_size: int = 84,
|
||||
terminal_on_life_loss: bool = True,
|
||||
clip_reward: bool = True,
|
||||
action_repeat_probability: float = 0.0,
|
||||
) -> None:
|
||||
env = NoopResetEnv(env, noop_max=noop_max)
|
||||
env = MaxAndSkipEnv(env, skip=frame_skip)
|
||||
if action_repeat_probability > 0.0:
|
||||
env = StickyActionEnv(env, action_repeat_probability)
|
||||
if noop_max > 0:
|
||||
env = NoopResetEnv(env, noop_max=noop_max)
|
||||
# frame_skip=1 is the same as no frame-skip (action repeat)
|
||||
if frame_skip > 1:
|
||||
env = MaxAndSkipEnv(env, skip=frame_skip)
|
||||
if terminal_on_life_loss:
|
||||
env = EpisodicLifeEnv(env)
|
||||
if "FIRE" in env.unwrapped.get_action_meanings():
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.8.0a2
|
||||
1.8.0a3
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from gym import spaces
|
|||
|
||||
import stable_baselines3 as sb3
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3.common.atari_wrappers import ClipRewardEnv, MaxAndSkipEnv
|
||||
from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
|
||||
from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
|
|
@ -55,30 +55,54 @@ def test_make_vec_env_func_checker():
|
|||
env.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4"])
|
||||
@pytest.mark.parametrize("n_envs", [1, 2])
|
||||
@pytest.mark.parametrize("wrapper_kwargs", [None, dict(clip_reward=False, screen_size=60)])
|
||||
def test_make_atari_env(env_id, n_envs, wrapper_kwargs):
|
||||
env = make_atari_env(env_id, n_envs, wrapper_kwargs=wrapper_kwargs, monitor_dir=None, seed=0)
|
||||
# Use Asterix as it does not requires fire reset
|
||||
@pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4", "AsterixNoFrameskip-v4"])
|
||||
@pytest.mark.parametrize("noop_max", [0, 10])
|
||||
@pytest.mark.parametrize("action_repeat_probability", [0.0, 0.25])
|
||||
@pytest.mark.parametrize("frame_skip", [1, 4])
|
||||
@pytest.mark.parametrize("screen_size", [60])
|
||||
@pytest.mark.parametrize("terminal_on_life_loss", [True, False])
|
||||
@pytest.mark.parametrize("clip_reward", [True])
|
||||
def test_make_atari_env(
|
||||
env_id, noop_max, action_repeat_probability, frame_skip, screen_size, terminal_on_life_loss, clip_reward
|
||||
):
|
||||
n_envs = 2
|
||||
wrapper_kwargs = {
|
||||
"noop_max": noop_max,
|
||||
"action_repeat_probability": action_repeat_probability,
|
||||
"frame_skip": frame_skip,
|
||||
"screen_size": screen_size,
|
||||
"terminal_on_life_loss": terminal_on_life_loss,
|
||||
"clip_reward": clip_reward,
|
||||
}
|
||||
venv = make_atari_env(
|
||||
env_id,
|
||||
n_envs=2,
|
||||
wrapper_kwargs=wrapper_kwargs,
|
||||
monitor_dir=None,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert env.num_envs == n_envs
|
||||
assert venv.num_envs == n_envs
|
||||
|
||||
obs = env.reset()
|
||||
needs_fire_reset = env_id == "BreakoutNoFrameskip-v4"
|
||||
expected_frame_number_low = frame_skip * 2 if needs_fire_reset else 0 # FIRE - UP on reset
|
||||
expected_frame_number_high = expected_frame_number_low + noop_max
|
||||
expected_shape = (n_envs, screen_size, screen_size, 1)
|
||||
|
||||
new_obs, reward, _, _ = env.step([env.action_space.sample() for _ in range(n_envs)])
|
||||
obs = venv.reset()
|
||||
frame_numbers = [env.unwrapped.ale.getEpisodeFrameNumber() for env in venv.envs]
|
||||
for frame_number in frame_numbers:
|
||||
assert expected_frame_number_low <= frame_number <= expected_frame_number_high
|
||||
assert obs.shape == expected_shape
|
||||
|
||||
assert obs.shape == new_obs.shape
|
||||
new_obs, reward, _, _ = venv.step([venv.action_space.sample() for _ in range(n_envs)])
|
||||
|
||||
# Wrapped into DummyVecEnv
|
||||
wrapped_atari_env = env.envs[0]
|
||||
if wrapper_kwargs is not None:
|
||||
assert obs.shape == (n_envs, 60, 60, 1)
|
||||
assert wrapped_atari_env.observation_space.shape == (60, 60, 1)
|
||||
assert not isinstance(wrapped_atari_env.env, ClipRewardEnv)
|
||||
else:
|
||||
assert obs.shape == (n_envs, 84, 84, 1)
|
||||
assert wrapped_atari_env.observation_space.shape == (84, 84, 1)
|
||||
assert isinstance(wrapped_atari_env.env, ClipRewardEnv)
|
||||
new_frame_numbers = [env.unwrapped.ale.getEpisodeFrameNumber() for env in venv.envs]
|
||||
for frame_number, new_frame_number in zip(frame_numbers, new_frame_numbers):
|
||||
assert new_frame_number - frame_number == frame_skip
|
||||
assert new_obs.shape == expected_shape
|
||||
if clip_reward:
|
||||
assert np.max(np.abs(reward)) < 1.0
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue