mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-01 03:45:11 +00:00
Fix VecExtractDictObs does not handle terminal observation (#1443)
* VecExtractDictObs handle terminal_observation
* Added VecExtractDictObs handle terminal_output to changelog
* Update changelog.rst
* Update test_vec_extract_dict_obs.py

  Add random dones in env to test if terminal_observation is properly handled
* Made test deterministic
* Fixed bug in test
* Improved test
* Fix format in test
* Update test
* Fix type hint
* Ignore pytype warning
* Ignore pytype

---------

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
4232f9daa9
commit
15c9daa2ba
5 changed files with 36 additions and 8 deletions
|
|
@ -21,6 +21,7 @@ New Features:
|
|||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed ``VecExtractDictObs`` does not handle terminal observation (@WeberSamuel)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -1299,4 +1300,4 @@ And all the contributors:
|
|||
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
|
||||
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi
|
||||
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
|
||||
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit
|
||||
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel
|
||||
|
|
@ -339,7 +339,7 @@ class VecEnvWrapper(VecEnv):
|
|||
|
||||
return attr
|
||||
|
||||
def getattr_depth_check(self, name: str, already_found: bool) -> str:
|
||||
def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]:
|
||||
"""See base class.
|
||||
|
||||
:return: name of module whose attribute is being shadowed, if any.
|
||||
|
|
|
|||
|
|
@ -135,14 +135,16 @@ class StackedObservations(Generic[TObs]):
|
|||
:return: The stacked reset observation
|
||||
"""
|
||||
if isinstance(observation, dict):
|
||||
return {key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items()}
|
||||
return {
|
||||
key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items()
|
||||
} # pytype: disable=bad-return-type
|
||||
|
||||
self.stacked_obs[...] = 0
|
||||
if self.channels_first:
|
||||
self.stacked_obs[:, -observation.shape[self.stack_dimension] :, ...] = observation
|
||||
else:
|
||||
self.stacked_obs[..., -observation.shape[self.stack_dimension] :] = observation
|
||||
return self.stacked_obs
|
||||
return self.stacked_obs # pytype: disable=bad-return-type
|
||||
|
||||
def update(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -20,5 +20,8 @@ class VecExtractDictObs(VecEnvWrapper):
|
|||
return obs[self.key]
|
||||
|
||||
def step_wait(self) -> VecEnvStepReturn:
|
||||
obs, reward, done, info = self.venv.step_wait()
|
||||
return obs[self.key], reward, done, info
|
||||
obs, reward, done, infos = self.venv.step_wait()
|
||||
for info in infos:
|
||||
if "terminal_observation" in info:
|
||||
info["terminal_observation"] = info["terminal_observation"][self.key]
|
||||
return obs[self.key], reward, done, infos
|
||||
|
|
|
|||
|
|
@ -14,19 +14,31 @@ class DictObsVecEnv:
|
|||
self.num_envs = 4
|
||||
self.action_space = spaces.Discrete(2)
|
||||
self.observation_space = spaces.Dict({"rgb": spaces.Box(low=0.0, high=255.0, shape=(86, 86), dtype=np.float32)})
|
||||
self.n_steps = 0
|
||||
self.max_steps = 5
|
||||
|
||||
def step_async(self, actions):
|
||||
self.actions = actions
|
||||
|
||||
def step_wait(self):
|
||||
self.n_steps += 1
|
||||
done = self.n_steps >= self.max_steps
|
||||
if done:
|
||||
infos = [
|
||||
{"terminal_observation": {"rgb": np.zeros((86, 86))}, "TimeLimit.truncated": True}
|
||||
for _ in range(self.num_envs)
|
||||
]
|
||||
else:
|
||||
infos = []
|
||||
return (
|
||||
{"rgb": np.zeros((self.num_envs, 86, 86))},
|
||||
np.zeros((self.num_envs,)),
|
||||
np.zeros((self.num_envs,), dtype=bool),
|
||||
[{} for _ in range(self.num_envs)],
|
||||
np.ones((self.num_envs,), dtype=bool) * done,
|
||||
infos,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
self.n_steps = 0
|
||||
return {"rgb": np.zeros((self.num_envs, 86, 86))}
|
||||
|
||||
def render(self, mode="human", close=False):
|
||||
|
|
@ -40,6 +52,16 @@ def test_extract_dict_obs():
|
|||
env = VecExtractDictObs(env, "rgb")
|
||||
assert env.reset().shape == (4, 86, 86)
|
||||
|
||||
for _ in range(10):
|
||||
obs, _, dones, infos = env.step([env.action_space.sample() for _ in range(env.num_envs)])
|
||||
assert obs.shape == (4, 86, 86)
|
||||
for idx, info in enumerate(infos):
|
||||
if "terminal_observation" in info:
|
||||
assert dones[idx]
|
||||
assert info["terminal_observation"].shape == (86, 86)
|
||||
else:
|
||||
assert not dones[idx]
|
||||
|
||||
|
||||
def test_vec_with_ppo():
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue