Test with async reset

This commit is contained in:
Antonin Raffin 2021-05-30 14:40:23 +02:00
parent d81cf71057
commit bfcec7f697

View file

@ -19,9 +19,10 @@ class DummyVecEnv(VecEnv):
:param env_fns: a list of functions
that return environments to vectorize
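:param async_reset: If True, artificially end the episodes of all envs at the same time: as soon as any env is done, every other env is also marked as done and flagged with ``info["TimeLimit.truncated"] = True``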
"""
def __init__(self, env_fns: List[Callable[[], gym.Env]]):
def __init__(self, env_fns: List[Callable[[], gym.Env]], async_reset: bool = True):
self.envs = [fn() for fn in env_fns]
env = self.envs[0]
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
@ -34,6 +35,7 @@ class DummyVecEnv(VecEnv):
self.buf_infos = [{} for _ in range(self.num_envs)]
self.actions = None
self.metadata = env.metadata
self.async_reset = async_reset
def step_async(self, actions: np.ndarray) -> None:
# Store the batch of actions without stepping anything yet; the stepping
# code later reads them per environment as ``self.actions[env_idx]``.
self.actions = actions
@ -43,11 +45,23 @@ class DummyVecEnv(VecEnv):
obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step(
self.actions[env_idx]
)
if self.buf_dones[env_idx]:
if self.buf_dones[env_idx] and not self.async_reset:
# save final observation where user can get it, then reset
self.buf_infos[env_idx]["terminal_observation"] = obs
obs = self.envs[env_idx].reset()
self._save_obs(env_idx, obs)
if self.buf_dones.any() and self.async_reset:
for env_idx in range(self.num_envs):
if not self.buf_dones[env_idx]:
self.buf_infos[env_idx]["terminal_observation"] = self.buf_obs[None][env_idx]
self.buf_dones[env_idx] = True
self.buf_infos[env_idx]["TimeLimit.truncated"] = True
else:
self.buf_infos[env_idx]["terminal_observation"] = self.buf_obs[None][env_idx]
obs = self.envs[env_idx].reset()
self._save_obs(env_idx, obs)
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: