mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Add check for common mistake when mixing Gym/VecEnv API (#1696)
This commit is contained in:
parent
b85fa7533e
commit
2ca94cb73d
4 changed files with 23 additions and 2 deletions
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.2.0a4 (WIP)
|
||||
Release 2.2.0a5 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -13,6 +13,7 @@ Breaking Changes:
|
|||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Improved error message of the ``env_checker`` for env wrongly detected as GoalEnv (``compute_reward()`` is defined)
|
||||
- Improved error message when mixing Gym API with VecEnv API (see GH#1694)
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -343,6 +343,17 @@ class BasePolicy(BaseModel, ABC):
|
|||
# Switch to eval mode (this affects batch norm / dropout)
|
||||
self.set_training_mode(False)
|
||||
|
||||
# Check for common mistake that the user does not mix Gym/VecEnv API
|
||||
# Tuple obs are not supported by SB3, so we can safely do that check
|
||||
if isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict):
|
||||
raise ValueError(
|
||||
"You have passed a tuple to the predict() function instead of a Numpy array or a Dict. "
|
||||
"You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) "
|
||||
"vs `obs = vec_env.reset()` (SB3 VecEnv). "
|
||||
"See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 "
|
||||
"and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
|
||||
)
|
||||
|
||||
observation, vectorized_env = self.obs_to_tensor(observation)
|
||||
|
||||
with th.no_grad():
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.2.0a4
|
||||
2.2.0a5
|
||||
|
|
|
|||
|
|
@ -116,3 +116,12 @@ def test_subclassed_space_env(model_class):
|
|||
model.learn(300)
|
||||
obs, _ = env.reset()
|
||||
env.step(model.predict(obs))
|
||||
|
||||
|
||||
def test_mixing_gym_vecenv_api():
|
||||
env = gym.make("CartPole-v1")
|
||||
model = PPO("MlpPolicy", env)
|
||||
# Reset return a tuple (obs, info)
|
||||
wrong_obs = env.reset()
|
||||
with pytest.raises(ValueError, match="mixing Gym API"):
|
||||
model.predict(wrong_obs)
|
||||
|
|
|
|||
Loading…
Reference in a new issue