Rename the observations variable in the evaluation util to avoid shadowing (#1288)

* Rename the observations variable in the evaluation util to avoid shadowing

This enables a callback in ``evaluate_policy`` to access the observation
vector fed to the environment step function, which was previously
shadowed by the output observation.
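With the rename in place, a callback passed to ``evaluate_policy`` can read both the pre-step and post-step observations from the evaluation loop's local scope. A minimal sketch (the variable names ``observations`` and ``new_observations`` match the renamed locals in ``evaluate_policy``; the usage line assumes a trained model and a vectorized env):

```python
import numpy as np

# Collect (observation, next_observation) pairs seen during evaluation.
transitions = []

def record_transition(locals_, globals_):
    # Before this change, "observations" was overwritten by env.step's
    # output, so the pre-step observation was inaccessible here.
    obs = locals_["observations"]           # observation fed to the policy
    next_obs = locals_["new_observations"]  # observation returned by env.step
    transitions.append((np.asarray(obs), np.asarray(next_obs)))

# Usage (assumes a trained model and a vectorized env):
# from stable_baselines3.common.evaluation import evaluate_policy
# evaluate_policy(model, env, n_eval_episodes=5, callback=record_transition)
```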

* Update changelog

* Add test

* Move assignment outside of the loop

---------

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
npit 2023-04-11 19:00:33 +03:00 committed by GitHub
parent 84f5511e08
commit 4232f9daa9
4 changed files with 37 additions and 3 deletions


@@ -3,6 +3,34 @@
Changelog
==========
Release 1.8.1a0 (WIP)
--------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Renamed environment output observations in ``evaluate_policy`` to prevent shadowing the input observations during callbacks (@npit)
New Features:
^^^^^^^^^^^^^
`SB3-Contrib`_
^^^^^^^^^^^^^^
`RL Zoo`_
^^^^^^^^^
Bug Fixes:
^^^^^^^^^^
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
Documentation:
^^^^^^^^^^^^^^
Release 1.8.0 (2023-04-07)
--------------------------
@@ -1271,4 +1299,4 @@ And all the contributors:
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
-@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher
+@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit


@@ -86,7 +86,7 @@ def evaluate_policy(
episode_starts = np.ones((env.num_envs,), dtype=bool)
while (episode_counts < episode_count_targets).any():
actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic)
-observations, rewards, dones, infos = env.step(actions)
+new_observations, rewards, dones, infos = env.step(actions)
current_rewards += rewards
current_lengths += 1
for i in range(n_envs):
@@ -120,6 +120,8 @@ def evaluate_policy(
current_rewards[i] = 0
current_lengths[i] = 0
+observations = new_observations
if render:
env.render()


@@ -1 +1 @@
-1.8.0
+1.8.1a0


@@ -183,6 +183,10 @@ def test_evaluate_policy(direct_policy: bool):
def dummy_callback(locals_, _globals):
locals_["model"].n_callback_calls += 1
assert "observations" in locals_
+assert "new_observations" in locals_
+assert locals_["new_observations"] is not locals_["observations"]
+assert not np.allclose(locals_["new_observations"], locals_["observations"])
assert model.policy is not None
policy = model.policy if direct_policy else model