From 4232f9daa97408dd58844e5587cd8a5adcaffaf5 Mon Sep 17 00:00:00 2001 From: npit Date: Tue, 11 Apr 2023 19:00:33 +0300 Subject: [PATCH] Rename the observations variable in the evaluation util to avoid shadowing (#1288) * Rename the observations variable in the evaluation util to avoid shadowing This enables a callback in evaluate_policy to have access to the observation vector that was fed to the policy to compute the actions passed to the environment step function; that input is currently shadowed by the observation the step returns. * Update changelog * Add test * Move assignment outside of the loop --------- Co-authored-by: Antonin RAFFIN Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 30 +++++++++++++++++++++++++- stable_baselines3/common/evaluation.py | 4 +++- stable_baselines3/version.txt | 2 +- tests/test_utils.py | 4 ++++ 4 files changed, 37 insertions(+), 3 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 2f7dbe2..dec4c2c 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,34 @@ Changelog ========== +Release 1.8.1a0 (WIP) +-------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Renamed environment output observations in ``evaluate_policy`` to prevent shadowing the input observations during callbacks (@npit) + +New Features: +^^^^^^^^^^^^^ + +`SB3-Contrib`_ +^^^^^^^^^^^^^^ + +`RL Zoo`_ +^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + +Documentation: +^^^^^^^^^^^^^^ + Release 1.8.0 (2023-04-07) -------------------------- @@ -1271,4 +1299,4 @@ And all the contributors: @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong -@DavyMorgan 
@luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py index b65edf8..3b63437 100644 --- a/stable_baselines3/common/evaluation.py +++ b/stable_baselines3/common/evaluation.py @@ -86,7 +86,7 @@ def evaluate_policy( episode_starts = np.ones((env.num_envs,), dtype=bool) while (episode_counts < episode_count_targets).any(): actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic) - observations, rewards, dones, infos = env.step(actions) + new_observations, rewards, dones, infos = env.step(actions) current_rewards += rewards current_lengths += 1 for i in range(n_envs): @@ -120,6 +120,8 @@ def evaluate_policy( current_rewards[i] = 0 current_lengths[i] = 0 + observations = new_observations + if render: env.render() diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 27f9cd3..9eba1a1 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.8.0 +1.8.1a0 diff --git a/tests/test_utils.py b/tests/test_utils.py index c4399a8..88de942 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -183,6 +183,10 @@ def test_evaluate_policy(direct_policy: bool): def dummy_callback(locals_, _globals): locals_["model"].n_callback_calls += 1 + assert "observations" in locals_ + assert "new_observations" in locals_ + assert locals_["new_observations"] is not locals_["observations"] + assert not np.allclose(locals_["new_observations"], locals_["observations"]) assert model.policy is not None policy = model.policy if direct_policy else model