diff --git a/README.md b/README.md index b4461ef..b487bd6 100644 --- a/README.md +++ b/README.md @@ -126,13 +126,15 @@ env = gym.make("CartPole-v1") model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) -obs = env.reset() +vec_env = model.get_env() +obs = vec_env.reset() for i in range(1000): action, _states = model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) - env.render() - if done: - obs = env.reset() + obs, reward, done, info = vec_env.step(action) + vec_env.render() + # VecEnv resets automatically + # if done: + # obs = env.reset() env.close() ``` diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 646eddf..47bdc40 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -94,11 +94,12 @@ In the following example, we will train, save and load a DQN model on the Lunar mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10) # Enjoy trained agent - obs = env.reset() + vec_env = model.get_env() + obs = vec_env.reset() for i in range(1000): action, _states = model.predict(obs, deterministic=True) - obs, rewards, dones, info = env.step(action) - env.render() + obs, rewards, dones, info = vec_env.step(action) + vec_env.render() Multiprocessing: Unleashing the Power of Vectorized Environments diff --git a/docs/guide/quickstart.rst b/docs/guide/quickstart.rst index 7ad9e0e..5d1055a 100644 --- a/docs/guide/quickstart.rst +++ b/docs/guide/quickstart.rst @@ -19,13 +19,15 @@ Here is a quick example of how to train and run A2C on a CartPole environment: model = A2C("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) - obs = env.reset() + vec_env = model.get_env() + obs = vec_env.reset() for i in range(1000): action, _state = model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) - env.render() - if done: - obs = env.reset() + obs, reward, done, info = vec_env.step(action) + vec_env.render() + # VecEnv resets automatically + # if done: + # obs = vec_env.reset() .. note:: diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index f5cf0b6..9ae202b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -42,7 +42,6 @@ Documentation: ^^^^^^^^^^^^^^ - Updated Hugging Face Integration page (@simoninithomas) - Release 1.6.2 (2022-10-10) -------------------------- @@ -77,7 +76,6 @@ Documentation: ^^^^^^^^^^^^^^ - Extended docstring of the ``wrapper_class`` parameter in ``make_vec_env`` (@AlexPasqua) - Release 1.6.1 (2022-09-29) --------------------------- diff --git a/docs/modules/her.rst b/docs/modules/her.rst index 0b73351..817a991 100644 --- a/docs/modules/her.rst +++ b/docs/modules/her.rst @@ -19,10 +19,12 @@ It creates "virtual" transitions by relabeling transitions (changing the desired but a replay buffer class ``HerReplayBuffer`` that must be passed to an off-policy algorithm when using ``MultiInputPolicy`` (to have Dict observation support). - .. warning:: - HER requires the environment to inherits from `gym.GoalEnv `_ + HER requires the environment to follow the legacy `gym_robotics.GoalEnv interface `_ + In short, the ``gym.Env`` must have: + - a vectorized implementation of ``compute_reward()`` + - a dictionary observation space with three keys: ``observation``, ``achieved_goal`` and ``desired_goal`` .. warning:: diff --git a/setup.py b/setup.py index 44bcb15..b02e9e4 100644 --- a/setup.py +++ b/setup.py @@ -48,13 +48,16 @@ env = gym.make("CartPole-v1") model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) -obs = env.reset() +vec_env = model.get_env() +obs = vec_env.reset() for i in range(1000): action, _states = model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) - env.render() - if done: - obs = env.reset() + obs, reward, done, info = vec_env.step(action) + vec_env.render() + # VecEnv resets automatically + # if done: + # obs = vec_env.reset() + ``` Or just train a model with a one liner if [the environment is registered in Gym](https://www.gymlibrary.ml/content/environment_creation/) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html): diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index a9b2eca..62178a1 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -86,7 +86,7 @@ class EpisodicLifeEnv(gym.Wrapper): # then update lives to handle bonus lives lives = self.env.unwrapped.ale.lives() if 0 < lives < self.lives: - # for Qbert sometimes we stay in lives == 0 condtion for a few frames + # for Qbert sometimes we stay in lives == 0 condition for a few frames # so its important to keep lives > 0, so that we only reset once # the environment advertises done. done = True diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 9351bfb..1309229 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -543,6 +543,7 @@ class BaseAlgorithm(ABC): return set_random_seed(seed, using_cuda=self.device.type == th.device("cuda").type) self.action_space.seed(seed) + # self.env is always a VecEnv if self.env is not None: self.env.seed(seed) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index efc05e3..1b6c3bb 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -1,5 +1,5 @@ import warnings -from typing import Union +from typing import Any, Dict, Union import gym import numpy as np @@ -93,6 +93,62 @@ def _check_nan(env: gym.Env) -> None: _, _, _, _ = vec_env.step(action) +def _is_goal_env(env: gym.Env) -> bool: + """ + Check if the env uses the convention for goal-conditioned envs (previously, the gym.GoalEnv interface) + """ + if isinstance(env, gym.Wrapper): # We need to unwrap the env since gym.Wrapper has the compute_reward method + return _is_goal_env(env.unwrapped) + return hasattr(env, "compute_reward") + + +def _check_goal_env_obs(obs: dict, observation_space: spaces.Dict, method_name: str) -> None: + """ + Check that an environment implementing the `compute_rewards()` method + (previously known as GoalEnv in gym) contains three elements, + namely `observation`, `desired_goal`, and `achieved_goal`. + """ + assert len(observation_space.spaces) == 3, ( + "A goal conditioned env must contain 3 observation keys: `observation`, `desired_goal`, and `achieved_goal`." + f"The current observation contains {len(observation_space.spaces)} keys: {list(observation_space.spaces.keys())}" + ) + + for key in ["observation", "achieved_goal", "desired_goal"]: + if key not in observation_space.spaces: + raise AssertionError( + f"The observation returned by the `{method_name}()` method of a goal-conditioned env requires the '{key}' " + "key to be part of the observation dictionary. " + f"Current keys are {list(observation_space.spaces.keys())}" + ) + + +def _check_goal_env_compute_reward( + obs: Dict[str, Union[np.ndarray, int]], + env: gym.Env, + reward: float, + info: Dict[str, Any], +): + """ + Check that reward is computed with `compute_reward` + and that the implementation is vectorized. + """ + achieved_goal, desired_goal = obs["achieved_goal"], obs["desired_goal"] + assert reward == env.compute_reward( # type: ignore[attr-defined] + achieved_goal, desired_goal, info + ), "The reward was not computed with `compute_reward()`" + + achieved_goal, desired_goal = np.array(achieved_goal), np.array(desired_goal) + batch_achieved_goals = np.array([achieved_goal, achieved_goal]) + batch_desired_goals = np.array([desired_goal, desired_goal]) + if isinstance(achieved_goal, int) or len(achieved_goal.shape) == 0: + batch_achieved_goals = batch_achieved_goals.reshape(2, 1) + batch_desired_goals = batch_desired_goals.reshape(2, 1) + batch_infos = np.array([info, info]) + rewards = env.compute_reward(batch_achieved_goals, batch_desired_goals, batch_infos) # type: ignore[attr-defined] + assert rewards.shape == (2,), f"Unexpected shape for vectorized computation of reward: {rewards.shape} != (2,)" + assert rewards[0] == reward, f"Vectorized computation of reward differs from single computation: {rewards[0]} != {reward}" + + def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spaces.Space, method_name: str) -> None: """ Check that the observation returned by the environment @@ -141,7 +197,11 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action # because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists obs = env.reset() - if isinstance(observation_space, spaces.Dict): + if _is_goal_env(env): + # Make mypy happy, already checked + assert isinstance(observation_space, spaces.Dict) + _check_goal_env_obs(obs, observation_space, "reset") + elif isinstance(observation_space, spaces.Dict): assert isinstance(obs, dict), "The observation returned by `reset()` must be a dictionary" if not obs.keys() == observation_space.spaces.keys(): @@ -167,7 +227,12 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action # Unpack obs, reward, done, info = data - if isinstance(observation_space, spaces.Dict): + if _is_goal_env(env): + # Make mypy happy, already checked + assert isinstance(observation_space, spaces.Dict) + _check_goal_env_obs(obs, observation_space, "step") + _check_goal_env_compute_reward(obs, env, reward, info) + elif isinstance(observation_space, spaces.Dict): assert isinstance(obs, dict), "The observation returned by `step()` must be a dictionary" if not obs.keys() == observation_space.spaces.keys(): @@ -190,15 +255,16 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action assert isinstance(done, bool), "The `done` signal must be a boolean" assert isinstance(info, dict), "The `info` returned by `step()` must be a python dictionary" - if isinstance(env, gym.GoalEnv): - # For a GoalEnv, the keys are checked at reset + # Goal conditioned env + if _is_goal_env(env): assert reward == env.compute_reward(obs["achieved_goal"], obs["desired_goal"], info) def _check_spaces(env: gym.Env) -> None: """ - Check that the observation and action spaces are defined - and inherit from gym.spaces.Space. + Check that the observation and action spaces are defined and inherit from gym.spaces.Space. For + envs that follow the goal-conditioned standard (previously, the gym.GoalEnv interface) we check + the observation space is gym.spaces.Dict """ # Helper to link to the code, because gym has no proper documentation gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/" @@ -209,6 +275,11 @@ def _check_spaces(env: gym.Env) -> None: assert isinstance(env.observation_space, spaces.Space), "The observation space must inherit from gym.spaces" + gym_spaces assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gym.spaces" + gym_spaces + if _is_goal_env(env): + assert isinstance( + env.observation_space, spaces.Dict + ), "Goal conditioned envs (previously gym.GoalEnv) require the observation space to be gym.spaces.Dict" + # Check render cannot be covered by CI def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> None: # pragma: no cover diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index a881b32..d6724c9 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -2,13 +2,13 @@ from collections import OrderedDict from typing import Any, Dict, Optional, Union import numpy as np -from gym import GoalEnv, spaces +from gym import Env, spaces from gym.envs.registration import EnvSpec from stable_baselines3.common.type_aliases import GymStepReturn -class BitFlippingEnv(GoalEnv): +class BitFlippingEnv(Env): """ Simple bit flipping env, useful to test HER. The goal is to flip all the bits to get a vector of ones. diff --git a/stable_baselines3/common/noise.py b/stable_baselines3/common/noise.py index baa72e9..5e8632d 100644 --- a/stable_baselines3/common/noise.py +++ b/stable_baselines3/common/noise.py @@ -10,7 +10,7 @@ class ActionNoise(ABC): The action noise base class """ - def __init__(self): + def __init__(self) -> None: super().__init__() def reset(self) -> None: diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 634d9e9..5e018fc 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -602,7 +602,6 @@ class OffPolicyAlgorithm(BaseAlgorithm): # Log training infos if log_interval is not None and self._episode_num % log_interval == 0: self._dump_logs() - callback.on_rollout_end() return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training) diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index a8bc3ef..97812a9 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -150,7 +150,6 @@ class TD3(OffPolicyAlgorithm): self._update_learning_rate([self.actor.optimizer, self.critic.optimizer]) actor_losses, critic_losses = [], [] - for _ in range(gradient_steps): self._n_updates += 1 diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index f749d5a..a0e20b7 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -102,7 +102,7 @@ def test_callbacks(tmp_path, model_class): def select_env(model_class) -> str: if model_class is DQN: - return "CartPole-v0" + return "CartPole-v1" else: return "Pendulum-v1" diff --git a/tests/test_envs.py b/tests/test_envs.py index 8b8cb8c..1281bb4 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -28,7 +28,7 @@ ENV_CLASSES = [ ] -@pytest.mark.parametrize("env_id", ["CartPole-v0", "Pendulum-v1"]) +@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) def test_env(env_id): """ Check that environmnent integrated in Gym pass the test. diff --git a/tests/test_predict.py b/tests/test_predict.py index 93bbc9d..22ff4fd 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -40,7 +40,7 @@ def test_auto_wrap(model_class): """Test auto wrapping of env into a VecEnv.""" # Use different environment for DQN if model_class is DQN: - env_name = "CartPole-v0" + env_name = "CartPole-v1" else: env_name = "Pendulum-v1" env = gym.make(env_name) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 91b0760..f96b69e 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -30,7 +30,7 @@ def select_env(model_class: BaseAlgorithm) -> gym.Env: if model_class == DQN: return IdentityEnv(10) else: - return IdentityEnvBox(10) + return IdentityEnvBox(-10, 10) @pytest.mark.parametrize("model_class", MODEL_LIST) diff --git a/tests/test_utils.py b/tests/test_utils.py index e74b1d0..83d695a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -248,7 +248,7 @@ def test_evaluate_policy_monitors(vec_env_class): # Also test VecEnvs n_eval_episodes = 3 n_envs = 2 - env_id = "CartPole-v0" + env_id = "CartPole-v1" model = A2C("MlpPolicy", env_id, seed=0) def make_eval_env(with_monitor, wrapper_class=gym.Wrapper): diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index 5ccc33e..0a146a0 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -140,7 +140,7 @@ def test_vec_monitor_ppo(recwarn): # No warnings because using `VecMonitor` evaluate_policy(model, monitor_env) - assert len(recwarn) == 0 + assert len(recwarn) == 0, f"{[str(warning) for warning in recwarn]}" def test_vec_monitor_warn(): diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 07c720f..00af193 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -40,7 +40,7 @@ class DummyRewardEnv(gym.Env): return np.array([self.returned_rewards[self.return_reward_idx]]) -class DummyDictEnv(gym.GoalEnv): +class DummyDictEnv(gym.Env): """ Dummy gym goal env for testing purposes """