mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-26 22:45:15 +00:00
Fix several VecEnv issues, add fork start method to tests (#43)
* Fix several VecEnv issues, add `fork` start method to tests * Fix signature
This commit is contained in:
parent
403fff5d50
commit
353ea81080
6 changed files with 59 additions and 19 deletions
|
|
@ -3,6 +3,34 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
|
||||
Pre-Release 0.7.0a0 (WIP)
|
||||
------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- ``render()`` method of ``VecEnvs`` now only accept one argument: ``mode``
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed ``render()`` method for ``VecEnvs``
|
||||
- Fixed ``seed()``` method for ``SubprocVecEnv``
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Re-enable unsafe ``fork`` start method in the tests (was causing a deadlock with tensorflow)
|
||||
- Added a test for seeding ``SubprocVecEnv``` and rendering
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
Pre-Release 0.6.0 (2020-06-01)
|
||||
------------------------------
|
||||
|
||||
|
|
|
|||
|
|
@ -162,22 +162,22 @@ class VecEnv(ABC):
|
|||
self.step_async(actions)
|
||||
return self.step_wait()
|
||||
|
||||
def get_images(self, *args, **kwargs) -> Sequence[np.ndarray]:
|
||||
def get_images(self) -> Sequence[np.ndarray]:
|
||||
"""
|
||||
Return RGB images from each environment
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def render(self, *args, mode: str = 'human', **kwargs):
|
||||
def render(self, mode: str = 'human'):
|
||||
"""
|
||||
Gym environment rendering
|
||||
|
||||
:param mode: the rendering type
|
||||
"""
|
||||
try:
|
||||
imgs = self.get_images(*args, **kwargs)
|
||||
imgs = self.get_images()
|
||||
except NotImplementedError:
|
||||
logger.warn('Render not defined for {}'.format(self))
|
||||
logger.warn(f'Render not defined for {self}')
|
||||
return
|
||||
|
||||
# Create a big image by tiling images from subprocesses
|
||||
|
|
@ -189,7 +189,7 @@ class VecEnv(ABC):
|
|||
elif mode == 'rgb_array':
|
||||
return bigimg
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError(f'Render mode {mode} is not supported by VecEnvs')
|
||||
|
||||
@abstractmethod
|
||||
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
|
||||
|
|
@ -268,8 +268,8 @@ class VecEnvWrapper(VecEnv):
|
|||
def close(self):
|
||||
return self.venv.close()
|
||||
|
||||
def render(self, *args, **kwargs):
|
||||
return self.venv.render(*args, **kwargs)
|
||||
def render(self, mode: str = 'human'):
|
||||
return self.venv.render(mode=mode)
|
||||
|
||||
def get_images(self):
|
||||
return self.venv.get_images()
|
||||
|
|
|
|||
|
|
@ -66,10 +66,10 @@ class DummyVecEnv(VecEnv):
|
|||
for env in self.envs:
|
||||
env.close()
|
||||
|
||||
def get_images(self, *args, **kwargs) -> Sequence[np.ndarray]:
|
||||
return [env.render(*args, mode='rgb_array', **kwargs) for env in self.envs]
|
||||
def get_images(self) -> Sequence[np.ndarray]:
|
||||
return [env.render(mode='rgb_array') for env in self.envs]
|
||||
|
||||
def render(self, *args, **kwargs):
|
||||
def render(self, mode: str = 'human'):
|
||||
"""
|
||||
Gym environment rendering. If there are multiple environments then
|
||||
they are tiled together in one image via ``BaseVecEnv.render()``.
|
||||
|
|
@ -82,9 +82,9 @@ class DummyVecEnv(VecEnv):
|
|||
:param mode: The rendering type.
|
||||
"""
|
||||
if self.num_envs == 1:
|
||||
return self.envs[0].render(*args, **kwargs)
|
||||
return self.envs[0].render(mode=mode)
|
||||
else:
|
||||
return super().render(*args, **kwargs)
|
||||
return super().render(mode=mode)
|
||||
|
||||
def _save_obs(self, env_idx, obs):
|
||||
for key in self.keys:
|
||||
|
|
|
|||
|
|
@ -21,11 +21,13 @@ def _worker(remote, parent_remote, env_fn_wrapper):
|
|||
info['terminal_observation'] = observation
|
||||
observation = env.reset()
|
||||
remote.send((observation, reward, done, info))
|
||||
elif cmd == 'seed':
|
||||
remote.send(env.seed(data))
|
||||
elif cmd == 'reset':
|
||||
observation = env.reset()
|
||||
remote.send(observation)
|
||||
elif cmd == 'render':
|
||||
remote.send(env.render(*data[0], **data[1]))
|
||||
remote.send(env.render(data))
|
||||
elif cmd == 'close':
|
||||
remote.close()
|
||||
break
|
||||
|
|
@ -39,7 +41,7 @@ def _worker(remote, parent_remote, env_fn_wrapper):
|
|||
elif cmd == 'set_attr':
|
||||
remote.send(setattr(env, data[0], data[1]))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
|
|
@ -129,11 +131,11 @@ class SubprocVecEnv(VecEnv):
|
|||
process.join()
|
||||
self.closed = True
|
||||
|
||||
def get_images(self, *args, **kwargs) -> Sequence[np.ndarray]:
|
||||
def get_images(self) -> Sequence[np.ndarray]:
|
||||
for pipe in self.remotes:
|
||||
# gather images from subprocesses
|
||||
# `mode` will be taken into account later
|
||||
pipe.send(('render', (args, {'mode': 'rgb_array', **kwargs})))
|
||||
pipe.send(('render', 'rgb_array'))
|
||||
imgs = [pipe.recv() for pipe in self.remotes]
|
||||
return imgs
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
0.6.0
|
||||
0.7.0a0
|
||||
|
|
|
|||
|
|
@ -40,6 +40,10 @@ class CustomGymEnv(gym.Env):
|
|||
self.state = self.observation_space.sample()
|
||||
|
||||
def render(self, mode='human'):
|
||||
if mode == 'rgb_array':
|
||||
return np.zeros((4, 4, 3))
|
||||
|
||||
def seed(self, seed=None):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -71,6 +75,11 @@ def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper):
|
|||
else:
|
||||
vec_env = vec_env_wrapper(vec_env)
|
||||
|
||||
# Test seed method
|
||||
vec_env.seed(0)
|
||||
# Test render method call
|
||||
# vec_env.render() # we need a X server to test the "human" mode
|
||||
vec_env.render(mode='rgb_array')
|
||||
env_method_results = vec_env.env_method('custom_method', 1, indices=None, dim_1=2)
|
||||
setattr_results = []
|
||||
# Set current_step to an arbitrary value
|
||||
|
|
@ -271,9 +280,10 @@ def test_vecenv_tuple_spaces(vec_env_class):
|
|||
def test_subproc_start_method():
|
||||
start_methods = [None]
|
||||
# Only test thread-safe methods. Others may deadlock tests! (gh/428)
|
||||
safe_methods = {'forkserver', 'spawn'}
|
||||
# Note: adding unsafe `fork` method as we are now using PyTorch
|
||||
all_methods = {'forkserver', 'spawn', 'fork'}
|
||||
available_methods = multiprocessing.get_all_start_methods()
|
||||
start_methods += list(safe_methods.intersection(available_methods))
|
||||
start_methods += list(all_methods.intersection(available_methods))
|
||||
space = gym.spaces.Discrete(2)
|
||||
|
||||
def obs_assert(obs):
|
||||
|
|
|
|||
Loading…
Reference in a new issue