From bfcec7f69723ecfb6070b7f8aa1c9843edbf8b34 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 30 May 2021 14:40:23 +0200 Subject: [PATCH] Test with async reset --- .../common/vec_env/dummy_vec_env.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 5eb87cd..d0267a1 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -19,9 +19,10 @@ class DummyVecEnv(VecEnv): :param env_fns: a list of functions that return environments to vectorize + :param async_reset: Set artificial end of episodes to all env at the same time """ - def __init__(self, env_fns: List[Callable[[], gym.Env]]): + def __init__(self, env_fns: List[Callable[[], gym.Env]], async_reset: bool = True): self.envs = [fn() for fn in env_fns] env = self.envs[0] VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) @@ -34,6 +35,7 @@ class DummyVecEnv(VecEnv): self.buf_infos = [{} for _ in range(self.num_envs)] self.actions = None self.metadata = env.metadata + self.async_reset = async_reset def step_async(self, actions: np.ndarray) -> None: self.actions = actions @@ -43,11 +45,23 @@ class DummyVecEnv(VecEnv): obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step( self.actions[env_idx] ) - if self.buf_dones[env_idx]: + if self.buf_dones[env_idx] and not self.async_reset: # save final observation where user can get it, then reset self.buf_infos[env_idx]["terminal_observation"] = obs obs = self.envs[env_idx].reset() self._save_obs(env_idx, obs) + + if self.buf_dones.any() and self.async_reset: + for env_idx in range(self.num_envs): + if not self.buf_dones[env_idx]: + self.buf_infos[env_idx]["terminal_observation"] = self.buf_obs[None][env_idx] + self.buf_dones[env_idx] = True + self.buf_infos[env_idx]["TimeLimit.truncated"] = True + else: + self.buf_infos[env_idx]["terminal_observation"] = self.buf_obs[None][env_idx] + obs = self.envs[env_idx].reset() + self._save_obs(env_idx, obs) + return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: