From 2ca94cb73dfa311fd3e0dddbf9e516d84ddb8da8 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 25 Sep 2023 12:39:22 +0200 Subject: [PATCH] Add check for common mistake when mixing Gym/VecEnv API (#1696) --- docs/misc/changelog.rst | 3 ++- stable_baselines3/common/policies.py | 11 +++++++++++ stable_baselines3/version.txt | 2 +- tests/test_predict.py | 9 +++++++++ 4 files changed, 23 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index a0a29b0..3932918 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.2.0a4 (WIP) +Release 2.2.0a5 (WIP) -------------------------- Breaking Changes: @@ -13,6 +13,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Improved error message of the ``env_checker`` for env wrongly detected as GoalEnv (``compute_reward()`` is defined) +- Improved error message when mixing Gym API with VecEnv API (see GH#1694) Bug Fixes: ^^^^^^^^^^ diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index a4b462b..e975349 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -343,6 +343,17 @@ class BasePolicy(BaseModel, ABC): # Switch to eval mode (this affects batch norm / dropout) self.set_training_mode(False) + # Check for common mistake that the user does not mix Gym/VecEnv API + # Tuple obs are not supported by SB3, so we can safely do that check + if isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict): + raise ValueError( + "You have passed a tuple to the predict() function instead of a Numpy array or a Dict. " + "You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) " + "vs `obs = vec_env.reset()` (SB3 VecEnv). " + "See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 " + "and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api" + ) + observation, vectorized_env = self.obs_to_tensor(observation) with th.no_grad(): diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index ddcf092..210ed6b 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.2.0a4 +2.2.0a5 diff --git a/tests/test_predict.py b/tests/test_predict.py index aac6b16..9a84523 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -116,3 +116,12 @@ def test_subclassed_space_env(model_class): model.learn(300) obs, _ = env.reset() env.step(model.predict(obs)) + + +def test_mixing_gym_vecenv_api(): + env = gym.make("CartPole-v1") + model = PPO("MlpPolicy", env) + # Reset return a tuple (obs, info) + wrong_obs = env.reset() + with pytest.raises(ValueError, match="mixing Gym API"): + model.predict(wrong_obs)