mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-02 03:55:39 +00:00
* Various changes from #780
* Fix env_checker for goal_env detection
This commit is contained in:
parent
cd630a3121
commit
e3b24829a5
20 changed files with 120 additions and 42 deletions
12
README.md
12
README.md
|
|
@ -126,13 +126,15 @@ env = gym.make("CartPole-v1")
|
|||
model = PPO("MlpPolicy", env, verbose=1)
|
||||
model.learn(total_timesteps=10_000)
|
||||
|
||||
obs = env.reset()
|
||||
vec_env = model.get_env()
|
||||
obs = vec_env.reset()
|
||||
for i in range(1000):
|
||||
action, _states = model.predict(obs, deterministic=True)
|
||||
obs, reward, done, info = env.step(action)
|
||||
env.render()
|
||||
if done:
|
||||
obs = env.reset()
|
||||
obs, reward, done, info = vec_env.step(action)
|
||||
vec_env.render()
|
||||
# VecEnv resets automatically
|
||||
# if done:
|
||||
# obs = vec_env.reset()
|
||||
|
||||
env.close()
|
||||
```
|
||||
|
|
|
|||
|
|
@ -94,11 +94,12 @@ In the following example, we will train, save and load a DQN model on the Lunar
|
|||
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
|
||||
|
||||
# Enjoy trained agent
|
||||
obs = env.reset()
|
||||
vec_env = model.get_env()
|
||||
obs = vec_env.reset()
|
||||
for i in range(1000):
|
||||
action, _states = model.predict(obs, deterministic=True)
|
||||
obs, rewards, dones, info = env.step(action)
|
||||
env.render()
|
||||
obs, rewards, dones, info = vec_env.step(action)
|
||||
vec_env.render()
|
||||
|
||||
|
||||
Multiprocessing: Unleashing the Power of Vectorized Environments
|
||||
|
|
|
|||
|
|
@ -19,13 +19,15 @@ Here is a quick example of how to train and run A2C on a CartPole environment:
|
|||
model = A2C("MlpPolicy", env, verbose=1)
|
||||
model.learn(total_timesteps=10_000)
|
||||
|
||||
obs = env.reset()
|
||||
vec_env = model.get_env()
|
||||
obs = vec_env.reset()
|
||||
for i in range(1000):
|
||||
action, _state = model.predict(obs, deterministic=True)
|
||||
obs, reward, done, info = env.step(action)
|
||||
env.render()
|
||||
if done:
|
||||
obs = env.reset()
|
||||
obs, reward, done, info = vec_env.step(action)
|
||||
vec_env.render()
|
||||
# VecEnv resets automatically
|
||||
# if done:
|
||||
# obs = vec_env.reset()
|
||||
|
||||
.. note::
|
||||
|
||||
|
|
|
|||
|
|
@ -42,7 +42,6 @@ Documentation:
|
|||
^^^^^^^^^^^^^^
|
||||
- Updated Hugging Face Integration page (@simoninithomas)
|
||||
|
||||
|
||||
Release 1.6.2 (2022-10-10)
|
||||
--------------------------
|
||||
|
||||
|
|
@ -77,7 +76,6 @@ Documentation:
|
|||
^^^^^^^^^^^^^^
|
||||
- Extended docstring of the ``wrapper_class`` parameter in ``make_vec_env`` (@AlexPasqua)
|
||||
|
||||
|
||||
Release 1.6.1 (2022-09-29)
|
||||
---------------------------
|
||||
|
||||
|
|
|
|||
|
|
@ -19,10 +19,12 @@ It creates "virtual" transitions by relabeling transitions (changing the desired
|
|||
but a replay buffer class ``HerReplayBuffer`` that must be passed to an off-policy algorithm
|
||||
when using ``MultiInputPolicy`` (to have Dict observation support).
|
||||
|
||||
|
||||
.. warning::
|
||||
|
||||
HER requires the environment to inherits from `gym.GoalEnv <https://github.com/openai/gym/blob/3394e245727c1ae6851b504a50ba77c73cd4c65b/gym/core.py#L160>`_
|
||||
HER requires the environment to follow the legacy `gym_robotics.GoalEnv interface <https://github.com/Farama-Foundation/Gymnasium-Robotics/blob/a35b1c1fa669428bf640a2c7101e66eb1627ac3a/gym_robotics/core.py#L8>`_
|
||||
In short, the ``gym.Env`` must have:
|
||||
- a vectorized implementation of ``compute_reward()``
|
||||
- a dictionary observation space with three keys: ``observation``, ``achieved_goal`` and ``desired_goal``
|
||||
|
||||
|
||||
.. warning::
|
||||
|
|
|
|||
13
setup.py
13
setup.py
|
|
@ -48,13 +48,16 @@ env = gym.make("CartPole-v1")
|
|||
model = PPO("MlpPolicy", env, verbose=1)
|
||||
model.learn(total_timesteps=10_000)
|
||||
|
||||
obs = env.reset()
|
||||
vec_env = model.get_env()
|
||||
obs = vec_env.reset()
|
||||
for i in range(1000):
|
||||
action, _states = model.predict(obs, deterministic=True)
|
||||
obs, reward, done, info = env.step(action)
|
||||
env.render()
|
||||
if done:
|
||||
obs = env.reset()
|
||||
obs, reward, done, info = vec_env.step(action)
|
||||
vec_env.render()
|
||||
# VecEnv resets automatically
|
||||
# if done:
|
||||
# obs = vec_env.reset()
|
||||
|
||||
```
|
||||
|
||||
Or just train a model with a one liner if [the environment is registered in Gym](https://www.gymlibrary.ml/content/environment_creation/) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ class EpisodicLifeEnv(gym.Wrapper):
|
|||
# then update lives to handle bonus lives
|
||||
lives = self.env.unwrapped.ale.lives()
|
||||
if 0 < lives < self.lives:
|
||||
# for Qbert sometimes we stay in lives == 0 condtion for a few frames
|
||||
# for Qbert sometimes we stay in lives == 0 condition for a few frames
|
||||
# so it's important to keep lives > 0, so that we only reset once
|
||||
# the environment advertises done.
|
||||
done = True
|
||||
|
|
|
|||
|
|
@ -543,6 +543,7 @@ class BaseAlgorithm(ABC):
|
|||
return
|
||||
set_random_seed(seed, using_cuda=self.device.type == th.device("cuda").type)
|
||||
self.action_space.seed(seed)
|
||||
# self.env is always a VecEnv
|
||||
if self.env is not None:
|
||||
self.env.seed(seed)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import warnings
|
||||
from typing import Union
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -93,6 +93,62 @@ def _check_nan(env: gym.Env) -> None:
|
|||
_, _, _, _ = vec_env.step(action)
|
||||
|
||||
|
||||
def _is_goal_env(env: gym.Env) -> bool:
|
||||
"""
|
||||
Check if the env uses the convention for goal-conditioned envs (previously, the gym.GoalEnv interface)
|
||||
"""
|
||||
if isinstance(env, gym.Wrapper): # We need to unwrap the env since gym.Wrapper has the compute_reward method
|
||||
return _is_goal_env(env.unwrapped)
|
||||
return hasattr(env, "compute_reward")
|
||||
|
||||
|
||||
def _check_goal_env_obs(obs: dict, observation_space: spaces.Dict, method_name: str) -> None:
|
||||
"""
|
||||
Check that an environment implementing the `compute_reward()` method
|
||||
(previously known as GoalEnv in gym) contains three elements,
|
||||
namely `observation`, `desired_goal`, and `achieved_goal`.
|
||||
"""
|
||||
assert len(observation_space.spaces) == 3, (
|
||||
"A goal conditioned env must contain 3 observation keys: `observation`, `desired_goal`, and `achieved_goal`."
|
||||
f"The current observation contains {len(observation_space.spaces)} keys: {list(observation_space.spaces.keys())}"
|
||||
)
|
||||
|
||||
for key in ["observation", "achieved_goal", "desired_goal"]:
|
||||
if key not in observation_space.spaces:
|
||||
raise AssertionError(
|
||||
f"The observation returned by the `{method_name}()` method of a goal-conditioned env requires the '{key}' "
|
||||
"key to be part of the observation dictionary. "
|
||||
f"Current keys are {list(observation_space.spaces.keys())}"
|
||||
)
|
||||
|
||||
|
||||
def _check_goal_env_compute_reward(
|
||||
obs: Dict[str, Union[np.ndarray, int]],
|
||||
env: gym.Env,
|
||||
reward: float,
|
||||
info: Dict[str, Any],
|
||||
):
|
||||
"""
|
||||
Check that reward is computed with `compute_reward`
|
||||
and that the implementation is vectorized.
|
||||
"""
|
||||
achieved_goal, desired_goal = obs["achieved_goal"], obs["desired_goal"]
|
||||
assert reward == env.compute_reward( # type: ignore[attr-defined]
|
||||
achieved_goal, desired_goal, info
|
||||
), "The reward was not computed with `compute_reward()`"
|
||||
|
||||
achieved_goal, desired_goal = np.array(achieved_goal), np.array(desired_goal)
|
||||
batch_achieved_goals = np.array([achieved_goal, achieved_goal])
|
||||
batch_desired_goals = np.array([desired_goal, desired_goal])
|
||||
if isinstance(achieved_goal, int) or len(achieved_goal.shape) == 0:
|
||||
batch_achieved_goals = batch_achieved_goals.reshape(2, 1)
|
||||
batch_desired_goals = batch_desired_goals.reshape(2, 1)
|
||||
batch_infos = np.array([info, info])
|
||||
rewards = env.compute_reward(batch_achieved_goals, batch_desired_goals, batch_infos) # type: ignore[attr-defined]
|
||||
assert rewards.shape == (2,), f"Unexpected shape for vectorized computation of reward: {rewards.shape} != (2,)"
|
||||
assert rewards[0] == reward, f"Vectorized computation of reward differs from single computation: {rewards[0]} != {reward}"
|
||||
|
||||
|
||||
def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spaces.Space, method_name: str) -> None:
|
||||
"""
|
||||
Check that the observation returned by the environment
|
||||
|
|
@ -141,7 +197,11 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
|
|||
# because env inherits from gym.Env, we assume that `reset()` and `step()` methods exist
|
||||
obs = env.reset()
|
||||
|
||||
if isinstance(observation_space, spaces.Dict):
|
||||
if _is_goal_env(env):
|
||||
# Make mypy happy, already checked
|
||||
assert isinstance(observation_space, spaces.Dict)
|
||||
_check_goal_env_obs(obs, observation_space, "reset")
|
||||
elif isinstance(observation_space, spaces.Dict):
|
||||
assert isinstance(obs, dict), "The observation returned by `reset()` must be a dictionary"
|
||||
|
||||
if not obs.keys() == observation_space.spaces.keys():
|
||||
|
|
@ -167,7 +227,12 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
|
|||
# Unpack
|
||||
obs, reward, done, info = data
|
||||
|
||||
if isinstance(observation_space, spaces.Dict):
|
||||
if _is_goal_env(env):
|
||||
# Make mypy happy, already checked
|
||||
assert isinstance(observation_space, spaces.Dict)
|
||||
_check_goal_env_obs(obs, observation_space, "step")
|
||||
_check_goal_env_compute_reward(obs, env, reward, info)
|
||||
elif isinstance(observation_space, spaces.Dict):
|
||||
assert isinstance(obs, dict), "The observation returned by `step()` must be a dictionary"
|
||||
|
||||
if not obs.keys() == observation_space.spaces.keys():
|
||||
|
|
@ -190,15 +255,16 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
|
|||
assert isinstance(done, bool), "The `done` signal must be a boolean"
|
||||
assert isinstance(info, dict), "The `info` returned by `step()` must be a python dictionary"
|
||||
|
||||
if isinstance(env, gym.GoalEnv):
|
||||
# For a GoalEnv, the keys are checked at reset
|
||||
# Goal conditioned env
|
||||
if _is_goal_env(env):
|
||||
assert reward == env.compute_reward(obs["achieved_goal"], obs["desired_goal"], info)
|
||||
|
||||
|
||||
def _check_spaces(env: gym.Env) -> None:
|
||||
"""
|
||||
Check that the observation and action spaces are defined
|
||||
and inherit from gym.spaces.Space.
|
||||
Check that the observation and action spaces are defined and inherit from gym.spaces.Space. For
|
||||
envs that follow the goal-conditioned standard (previously, the gym.GoalEnv interface) we check
|
||||
the observation space is gym.spaces.Dict
|
||||
"""
|
||||
# Helper to link to the code, because gym has no proper documentation
|
||||
gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/"
|
||||
|
|
@ -209,6 +275,11 @@ def _check_spaces(env: gym.Env) -> None:
|
|||
assert isinstance(env.observation_space, spaces.Space), "The observation space must inherit from gym.spaces" + gym_spaces
|
||||
assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gym.spaces" + gym_spaces
|
||||
|
||||
if _is_goal_env(env):
|
||||
assert isinstance(
|
||||
env.observation_space, spaces.Dict
|
||||
), "Goal conditioned envs (previously gym.GoalEnv) require the observation space to be gym.spaces.Dict"
|
||||
|
||||
|
||||
# Check render cannot be covered by CI
|
||||
def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> None: # pragma: no cover
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@ from collections import OrderedDict
|
|||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from gym import GoalEnv, spaces
|
||||
from gym import Env, spaces
|
||||
from gym.envs.registration import EnvSpec
|
||||
|
||||
from stable_baselines3.common.type_aliases import GymStepReturn
|
||||
|
||||
|
||||
class BitFlippingEnv(GoalEnv):
|
||||
class BitFlippingEnv(Env):
|
||||
"""
|
||||
Simple bit flipping env, useful to test HER.
|
||||
The goal is to flip all the bits to get a vector of ones.
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ class ActionNoise(ABC):
|
|||
The action noise base class
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def reset(self) -> None:
|
||||
|
|
|
|||
|
|
@ -602,7 +602,6 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
# Log training infos
|
||||
if log_interval is not None and self._episode_num % log_interval == 0:
|
||||
self._dump_logs()
|
||||
|
||||
callback.on_rollout_end()
|
||||
|
||||
return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
|
||||
|
|
|
|||
|
|
@ -150,7 +150,6 @@ class TD3(OffPolicyAlgorithm):
|
|||
self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])
|
||||
|
||||
actor_losses, critic_losses = [], []
|
||||
|
||||
for _ in range(gradient_steps):
|
||||
|
||||
self._n_updates += 1
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ def test_callbacks(tmp_path, model_class):
|
|||
|
||||
def select_env(model_class) -> str:
|
||||
if model_class is DQN:
|
||||
return "CartPole-v0"
|
||||
return "CartPole-v1"
|
||||
else:
|
||||
return "Pendulum-v1"
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ ENV_CLASSES = [
|
|||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_id", ["CartPole-v0", "Pendulum-v1"])
|
||||
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
|
||||
def test_env(env_id):
|
||||
"""
|
||||
Check that environments integrated in Gym pass the test.
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ def test_auto_wrap(model_class):
|
|||
"""Test auto wrapping of env into a VecEnv."""
|
||||
# Use different environment for DQN
|
||||
if model_class is DQN:
|
||||
env_name = "CartPole-v0"
|
||||
env_name = "CartPole-v1"
|
||||
else:
|
||||
env_name = "Pendulum-v1"
|
||||
env = gym.make(env_name)
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ def select_env(model_class: BaseAlgorithm) -> gym.Env:
|
|||
if model_class == DQN:
|
||||
return IdentityEnv(10)
|
||||
else:
|
||||
return IdentityEnvBox(10)
|
||||
return IdentityEnvBox(-10, 10)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
|
|
|
|||
|
|
@ -248,7 +248,7 @@ def test_evaluate_policy_monitors(vec_env_class):
|
|||
# Also test VecEnvs
|
||||
n_eval_episodes = 3
|
||||
n_envs = 2
|
||||
env_id = "CartPole-v0"
|
||||
env_id = "CartPole-v1"
|
||||
model = A2C("MlpPolicy", env_id, seed=0)
|
||||
|
||||
def make_eval_env(with_monitor, wrapper_class=gym.Wrapper):
|
||||
|
|
|
|||
|
|
@ -140,7 +140,7 @@ def test_vec_monitor_ppo(recwarn):
|
|||
|
||||
# No warnings because using `VecMonitor`
|
||||
evaluate_policy(model, monitor_env)
|
||||
assert len(recwarn) == 0
|
||||
assert len(recwarn) == 0, f"{[str(warning) for warning in recwarn]}"
|
||||
|
||||
|
||||
def test_vec_monitor_warn():
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class DummyRewardEnv(gym.Env):
|
|||
return np.array([self.returned_rewards[self.return_reward_idx]])
|
||||
|
||||
|
||||
class DummyDictEnv(gym.GoalEnv):
|
||||
class DummyDictEnv(gym.Env):
|
||||
"""
|
||||
Dummy gym goal env for testing purposes
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue