mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Rename the observations variable in the evaluation util to avoid shadowing (#1288)
* Rename the observations variable in the evaluation util to avoid shadowing.
  This enables a callback in evaluate_policy to have access to the observation
  vector that is fed to the environment step function, which is currently
  shadowed by the output observation.
* Update changelog
* Add test
* Move assignment outside of the loop
---------
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
This commit is contained in:
parent
84f5511e08
commit
4232f9daa9
4 changed files with 37 additions and 3 deletions
|
|
@@ -3,6 +3,34 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.8.1a0 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Renamed environment output observations in ``evaluate_policy`` to prevent shadowing the input observations during callbacks (@npit)
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
`SB3-Contrib`_
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
`RL Zoo`_
|
||||
^^^^^^^^^
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
Release 1.8.0 (2023-04-07)
|
||||
--------------------------
|
||||
|
|
@@ -1271,4 +1299,4 @@ And all the contributors:
|
|||
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
|
||||
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi
|
||||
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
|
||||
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher
|
||||
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit
|
||||
|
|
|
|||
|
|
@@ -86,7 +86,7 @@ def evaluate_policy(
|
|||
episode_starts = np.ones((env.num_envs,), dtype=bool)
|
||||
while (episode_counts < episode_count_targets).any():
|
||||
actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic)
|
||||
observations, rewards, dones, infos = env.step(actions)
|
||||
new_observations, rewards, dones, infos = env.step(actions)
|
||||
current_rewards += rewards
|
||||
current_lengths += 1
|
||||
for i in range(n_envs):
|
||||
|
|
@@ -120,6 +120,8 @@ def evaluate_policy(
|
|||
current_rewards[i] = 0
|
||||
current_lengths[i] = 0
|
||||
|
||||
observations = new_observations
|
||||
|
||||
if render:
|
||||
env.render()
|
||||
|
||||
|
|
|
|||
|
|
@@ -1 +1 @@
|
|||
1.8.0
|
||||
1.8.1a0
|
||||
|
|
|
|||
|
|
@@ -183,6 +183,10 @@ def test_evaluate_policy(direct_policy: bool):
|
|||
|
||||
def dummy_callback(locals_, _globals):
|
||||
locals_["model"].n_callback_calls += 1
|
||||
assert "observations" in locals_
|
||||
assert "new_observations" in locals_
|
||||
assert locals_["new_observations"] is not locals_["observations"]
|
||||
assert not np.allclose(locals_["new_observations"], locals_["observations"])
|
||||
|
||||
assert model.policy is not None
|
||||
policy = model.policy if direct_policy else model
|
||||
|
|
|
|||
Loading…
Reference in a new issue