diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index dec4c2c..4ae5e47 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -21,6 +21,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ +- Fixed ``VecExtractDictObs`` does not handle terminal observation (@WeberSamuel) Deprecations: ^^^^^^^^^^^^^ @@ -1299,4 +1300,4 @@ And all the contributors: @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong -@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit +@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel \ No newline at end of file diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 0b3e1b4..708c021 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -339,7 +339,7 @@ class VecEnvWrapper(VecEnv): return attr - def getattr_depth_check(self, name: str, already_found: bool) -> str: + def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]: """See base class. :return: name of module whose attribute is being shadowed, if any. diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index ae7aebc..e674ee0 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -135,14 +135,16 @@ class StackedObservations(Generic[TObs]): :return: The stacked reset observation """ if isinstance(observation, dict): - return {key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items()} + return { + key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items() + } # pytype: disable=bad-return-type self.stacked_obs[...] = 0 if self.channels_first: self.stacked_obs[:, -observation.shape[self.stack_dimension] :, ...] = observation else: self.stacked_obs[..., -observation.shape[self.stack_dimension] :] = observation - return self.stacked_obs + return self.stacked_obs # pytype: disable=bad-return-type def update( self, diff --git a/stable_baselines3/common/vec_env/vec_extract_dict_obs.py b/stable_baselines3/common/vec_env/vec_extract_dict_obs.py index 8582b7a..66872dd 100644 --- a/stable_baselines3/common/vec_env/vec_extract_dict_obs.py +++ b/stable_baselines3/common/vec_env/vec_extract_dict_obs.py @@ -20,5 +20,8 @@ class VecExtractDictObs(VecEnvWrapper): return obs[self.key] def step_wait(self) -> VecEnvStepReturn: - obs, reward, done, info = self.venv.step_wait() - return obs[self.key], reward, done, info + obs, reward, done, infos = self.venv.step_wait() + for info in infos: + if "terminal_observation" in info: + info["terminal_observation"] = info["terminal_observation"][self.key] + return obs[self.key], reward, done, infos diff --git a/tests/test_vec_extract_dict_obs.py b/tests/test_vec_extract_dict_obs.py index 1507442..17728bb 100644 --- a/tests/test_vec_extract_dict_obs.py +++ b/tests/test_vec_extract_dict_obs.py @@ -14,19 +14,31 @@ class DictObsVecEnv: self.num_envs = 4 self.action_space = spaces.Discrete(2) self.observation_space = spaces.Dict({"rgb": spaces.Box(low=0.0, high=255.0, shape=(86, 86), dtype=np.float32)}) + self.n_steps = 0 + self.max_steps = 5 def step_async(self, actions): self.actions = actions def step_wait(self): + self.n_steps += 1 + done = self.n_steps >= self.max_steps + if done: + infos = [ + {"terminal_observation": {"rgb": np.zeros((86, 86))}, "TimeLimit.truncated": True} + for _ in range(self.num_envs) + ] + else: + infos = [] return ( {"rgb": np.zeros((self.num_envs, 86, 86))}, np.zeros((self.num_envs,)), - np.zeros((self.num_envs,), dtype=bool), - [{} for _ in range(self.num_envs)], + np.ones((self.num_envs,), dtype=bool) * done, + infos, ) def reset(self): + self.n_steps = 0 return {"rgb": np.zeros((self.num_envs, 86, 86))} def render(self, mode="human", close=False): @@ -40,6 +52,16 @@ def test_extract_dict_obs(): env = VecExtractDictObs(env, "rgb") assert env.reset().shape == (4, 86, 86) + for _ in range(10): + obs, _, dones, infos = env.step([env.action_space.sample() for _ in range(env.num_envs)]) + assert obs.shape == (4, 86, 86) + for idx, info in enumerate(infos): + if "terminal_observation" in info: + assert dones[idx] + assert info["terminal_observation"].shape == (86, 86) + else: + assert not dones[idx] + def test_vec_with_ppo(): """