mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-31 23:28:05 +00:00
Test with async reset
This commit is contained in:
parent
d81cf71057
commit
bfcec7f697
1 changed file with 16 additions and 2 deletions
|
|
@@ -19,9 +19,10 @@ class DummyVecEnv(VecEnv):
|
|||
|
||||
:param env_fns: a list of functions
|
||||
that return environments to vectorize
|
||||
:param async_reset: Set artificial end of episodes to all env at the same time
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]]):
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]], async_reset: bool = True):
|
||||
self.envs = [fn() for fn in env_fns]
|
||||
env = self.envs[0]
|
||||
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
|
||||
|
|
@@ -34,6 +35,7 @@ class DummyVecEnv(VecEnv):
|
|||
self.buf_infos = [{} for _ in range(self.num_envs)]
|
||||
self.actions = None
|
||||
self.metadata = env.metadata
|
||||
self.async_reset = async_reset
|
||||
|
||||
def step_async(self, actions: np.ndarray) -> None:
|
||||
self.actions = actions
|
||||
|
|
@@ -43,11 +45,23 @@ class DummyVecEnv(VecEnv):
|
|||
obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step(
|
||||
self.actions[env_idx]
|
||||
)
|
||||
if self.buf_dones[env_idx]:
|
||||
if self.buf_dones[env_idx] and not self.async_reset:
|
||||
# save final observation where user can get it, then reset
|
||||
self.buf_infos[env_idx]["terminal_observation"] = obs
|
||||
obs = self.envs[env_idx].reset()
|
||||
self._save_obs(env_idx, obs)
|
||||
|
||||
if self.buf_dones.any() and self.async_reset:
|
||||
for env_idx in range(self.num_envs):
|
||||
if not self.buf_dones[env_idx]:
|
||||
self.buf_infos[env_idx]["terminal_observation"] = self.buf_obs[None][env_idx]
|
||||
self.buf_dones[env_idx] = True
|
||||
self.buf_infos[env_idx]["TimeLimit.truncated"] = True
|
||||
else:
|
||||
self.buf_infos[env_idx]["terminal_observation"] = self.buf_obs[None][env_idx]
|
||||
obs = self.envs[env_idx].reset()
|
||||
self._save_obs(env_idx, obs)
|
||||
|
||||
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
|
||||
|
|
|
|||
Loading…
Reference in a new issue