diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 2f7dbe2..dec4c2c 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,34 @@ Changelog ========== +Release 1.8.1a0 (WIP) +-------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Renamed environment output observations in ``evaluate_policy`` to prevent shadowing the input observations during callbacks (@npit) + +New Features: +^^^^^^^^^^^^^ + +`SB3-Contrib`_ +^^^^^^^^^^^^^^ + +`RL Zoo`_ +^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + +Documentation: +^^^^^^^^^^^^^^ + Release 1.8.0 (2023-04-07) -------------------------- @@ -1271,4 +1299,4 @@ And all the contributors: @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong -@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher +@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py index b65edf8..3b63437 100644 --- a/stable_baselines3/common/evaluation.py +++ b/stable_baselines3/common/evaluation.py @@ -86,7 +86,7 @@ def evaluate_policy( episode_starts = np.ones((env.num_envs,), dtype=bool) while (episode_counts < episode_count_targets).any(): actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic) - observations, rewards, dones, infos = env.step(actions) + new_observations, rewards, dones, infos = env.step(actions) current_rewards += rewards current_lengths += 1 for i in range(n_envs): @@ -120,6 +120,8 @@ def evaluate_policy( current_rewards[i] = 0 current_lengths[i] = 0 + observations = new_observations + if render: env.render() diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 27f9cd3..9eba1a1 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.8.0 +1.8.1a0 diff --git a/tests/test_utils.py b/tests/test_utils.py index c4399a8..88de942 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -183,6 +183,10 @@ def test_evaluate_policy(direct_policy: bool): def dummy_callback(locals_, _globals): locals_["model"].n_callback_calls += 1 + assert "observations" in locals_ + assert "new_observations" in locals_ + assert locals_["new_observations"] is not locals_["observations"] + assert not np.allclose(locals_["new_observations"], locals_["observations"]) assert model.policy is not None policy = model.policy if direct_policy else model