mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Support for VecMonitor for gym3-style environments (#311)
* add vectorized monitor * auto format of the code * add documentation and VecExtractDictObs * refactor and add test cases * add test cases and format * avoid circular import and fix doc * fix type * fix type * oops * Update stable_baselines3/common/monitor.py Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update stable_baselines3/common/monitor.py Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * add test cases * update changelog * fix mutable argument * quick fix * Apply suggestions from code review * fix terminal observation for gym3 envs * delete comment * Update doc and bump version * Add warning when already using `Monitor` wrapper * Update vecmonitor tests * Fixes Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
1ed15bf6ee
commit
ddbe0e93f9
13 changed files with 424 additions and 36 deletions
|
|
@ -642,6 +642,30 @@ A2C policy gradient updates on the model.
|
|||
print(f"Best fitness: {top_candidates[0][1]:.2f}")
|
||||
|
||||
|
||||
SB3 and ProcgenEnv
|
||||
------------------
|
||||
|
||||
Some environments like `Procgen <https://github.com/openai/procgen>`_ already produce a vectorized
|
||||
environment (see discussion in `issue #314 <https://github.com/DLR-RM/stable-baselines3/issues/314>`_). In order to use it with SB3, you must wrap it in a ``VecMonitor`` wrapper which will also allow you
|
||||
to keep track of the agent progress.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from procgen import ProcgenEnv
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor
|
||||
|
||||
# ProcgenEnv is already vectorized
|
||||
venv = ProcgenEnv(num_envs=2, env_name='starpilot')
|
||||
# PPO does not currently support Dict observations
|
||||
# this will be solved in https://github.com/DLR-RM/stable-baselines3/pull/243
|
||||
venv = VecExtractDictObs(venv, "rgb")
|
||||
venv = VecMonitor(venv=venv)
|
||||
|
||||
model = PPO("MlpPolicy", venv, verbose=1)
|
||||
model.learn(10000)
|
||||
|
||||
|
||||
Record a Video
|
||||
--------------
|
||||
|
|
|
|||
|
|
@ -27,14 +27,22 @@ SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️
|
|||
|
||||
When using vectorized environments, the environments are automatically reset at the end of each episode.
|
||||
Thus, the observation returned for the i-th environment when ``done[i]`` is true will in fact be the first observation of the next episode, not the last observation of the episode that has just terminated.
|
||||
You can access the "real" final observation of the terminated episode—that is, the one that accompanied the ``done`` event provided by the underlying environment—using the ``terminal_observation`` keys in the info dicts returned by the vecenv.
|
||||
You can access the "real" final observation of the terminated episode—that is, the one that accompanied the ``done`` event provided by the underlying environment—using the ``terminal_observation`` keys in the info dicts returned by the ``VecEnv``.
|
||||
|
||||
|
||||
.. warning::
|
||||
|
||||
When using ``SubprocVecEnv``, users must wrap the code in an ``if __name__ == "__main__":`` if using the ``forkserver`` or ``spawn`` start method (default on Windows).
|
||||
On Linux, the default start method is ``fork`` which is not thread safe and can create deadlocks.
|
||||
When defining a custom ``VecEnv`` (for instance, using gym3 ``ProcgenEnv``), you should provide ``terminal_observation`` keys in the info dicts returned by the ``VecEnv``
|
||||
(cf. note above).
|
||||
|
||||
|
||||
.. warning::
|
||||
|
||||
When using ``SubprocVecEnv``, users must wrap the code in an ``if __name__ == "__main__":`` if using the ``forkserver`` or ``spawn`` start method (default on Windows).
|
||||
On Linux, the default start method is ``fork`` which is not thread safe and can create deadlocks.
|
||||
|
||||
For more information, see Python's `multiprocessing guidelines <https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods>`_.
|
||||
|
||||
For more information, see Python's `multiprocessing guidelines <https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods>`_.
|
||||
|
||||
VecEnv
|
||||
------
|
||||
|
|
@ -90,3 +98,15 @@ VecTransposeImage
|
|||
|
||||
.. autoclass:: VecTransposeImage
|
||||
:members:
|
||||
|
||||
VecMonitor
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: VecMonitor
|
||||
:members:
|
||||
|
||||
VecExtractDictObs
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: VecExtractDictObs
|
||||
:members:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.1.0a1 (WIP)
|
||||
Release 1.1.0a2 (WIP)
|
||||
---------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -12,6 +12,11 @@ Breaking Changes:
|
|||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Added `VecMonitor <https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/vec_monitor.py>`_ and
|
||||
`VecExtractDictObs <https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/vec_extract_dict_obs.py>`_ wrappers
|
||||
to handle gym3-style vectorized environments (@vwxyzjn)
|
||||
- Ignored the terminal observation if it is not provided by the environment
|
||||
such as the gym3-style vectorized environments. (@vwxyzjn)
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
@ -33,6 +38,8 @@ Documentation:
|
|||
- Clarify channel-first/channel-last recommendation
|
||||
- Update sphinx environment installation instructions (@tom-doerr)
|
||||
- Clarify pip installation in Zsh (@tom-doerr)
|
||||
- Added example for using ``ProcgenEnv``
|
||||
|
||||
|
||||
Release 1.0 (2021-03-15)
|
||||
------------------------
|
||||
|
|
@ -54,6 +61,7 @@ New Features:
|
|||
^^^^^^^^^^^^^
|
||||
- Added support for ``custom_objects`` when loading models
|
||||
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed a bug with ``DQN`` predict method when using ``deterministic=False`` with image space
|
||||
|
|
@ -640,5 +648,5 @@ And all the contributors:
|
|||
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
|
||||
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
|
||||
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
|
||||
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8
|
||||
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
|
||||
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import gym
|
|||
import numpy as np
|
||||
|
||||
from stable_baselines3.common import base_class
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
from stable_baselines3.common.vec_env import VecEnv, VecMonitor, is_vecenv_wrapped
|
||||
|
||||
|
||||
def evaluate_policy(
|
||||
|
|
@ -57,7 +57,7 @@ def evaluate_policy(
|
|||
|
||||
if isinstance(env, VecEnv):
|
||||
assert env.num_envs == 1, "You must pass only one environment when using this function"
|
||||
is_monitor_wrapped = env.env_is_wrapped(Monitor)[0]
|
||||
is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]
|
||||
else:
|
||||
is_monitor_wrapped = is_wrapped(env, Monitor)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
__all__ = ["Monitor", "get_monitor_files", "load_results"]
|
||||
__all__ = ["Monitor", "ResultsWriter", "get_monitor_files", "load_results"]
|
||||
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from glob import glob
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -38,27 +38,20 @@ class Monitor(gym.Wrapper):
|
|||
):
|
||||
super(Monitor, self).__init__(env=env)
|
||||
self.t_start = time.time()
|
||||
if filename is None:
|
||||
self.file_handler = None
|
||||
self.logger = None
|
||||
if filename is not None:
|
||||
self.results_writer = ResultsWriter(
|
||||
filename,
|
||||
header={"t_start": self.t_start, "env_id": env.spec and env.spec.id},
|
||||
extra_keys=reset_keywords + info_keywords,
|
||||
)
|
||||
else:
|
||||
if not filename.endswith(Monitor.EXT):
|
||||
if os.path.isdir(filename):
|
||||
filename = os.path.join(filename, Monitor.EXT)
|
||||
else:
|
||||
filename = filename + "." + Monitor.EXT
|
||||
self.file_handler = open(filename, "wt")
|
||||
self.file_handler.write("#%s\n" % json.dumps({"t_start": self.t_start, "env_id": env.spec and env.spec.id}))
|
||||
self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + reset_keywords + info_keywords)
|
||||
self.logger.writeheader()
|
||||
self.file_handler.flush()
|
||||
|
||||
self.results_writer = None
|
||||
self.reset_keywords = reset_keywords
|
||||
self.info_keywords = info_keywords
|
||||
self.allow_early_resets = allow_early_resets
|
||||
self.rewards = None
|
||||
self.needs_reset = True
|
||||
self.episode_rewards = []
|
||||
self.episode_returns = []
|
||||
self.episode_lengths = []
|
||||
self.episode_times = []
|
||||
self.total_steps = 0
|
||||
|
|
@ -81,7 +74,7 @@ class Monitor(gym.Wrapper):
|
|||
for key in self.reset_keywords:
|
||||
value = kwargs.get(key)
|
||||
if value is None:
|
||||
raise ValueError("Expected you to pass kwarg {} into reset".format(key))
|
||||
raise ValueError(f"Expected you to pass keyword argument {key} into reset")
|
||||
self.current_reset_info[key] = value
|
||||
return self.env.reset(**kwargs)
|
||||
|
||||
|
|
@ -103,13 +96,12 @@ class Monitor(gym.Wrapper):
|
|||
ep_info = {"r": round(ep_rew, 6), "l": ep_len, "t": round(time.time() - self.t_start, 6)}
|
||||
for key in self.info_keywords:
|
||||
ep_info[key] = info[key]
|
||||
self.episode_rewards.append(ep_rew)
|
||||
self.episode_returns.append(ep_rew)
|
||||
self.episode_lengths.append(ep_len)
|
||||
self.episode_times.append(time.time() - self.t_start)
|
||||
ep_info.update(self.current_reset_info)
|
||||
if self.logger:
|
||||
self.logger.writerow(ep_info)
|
||||
self.file_handler.flush()
|
||||
if self.results_writer:
|
||||
self.results_writer.write_row(ep_info)
|
||||
info["episode"] = ep_info
|
||||
self.total_steps += 1
|
||||
return observation, reward, done, info
|
||||
|
|
@ -119,8 +111,8 @@ class Monitor(gym.Wrapper):
|
|||
Closes the environment
|
||||
"""
|
||||
super(Monitor, self).close()
|
||||
if self.file_handler is not None:
|
||||
self.file_handler.close()
|
||||
if self.results_writer is not None:
|
||||
self.results_writer.close()
|
||||
|
||||
def get_total_steps(self) -> int:
|
||||
"""
|
||||
|
|
@ -136,7 +128,7 @@ class Monitor(gym.Wrapper):
|
|||
|
||||
:return:
|
||||
"""
|
||||
return self.episode_rewards
|
||||
return self.episode_returns
|
||||
|
||||
def get_episode_lengths(self) -> List[int]:
|
||||
"""
|
||||
|
|
@ -163,6 +155,52 @@ class LoadMonitorResultsError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class ResultsWriter:
    """
    A result writer that saves the data from the `Monitor` class

    :param filename: the location to save a log file. If it does not already end
        with ``Monitor.EXT``, the extension is appended; if it is an existing
        directory, a file named ``Monitor.EXT`` is created inside it
    :param header: the header dictionary object of the saved csv,
        written as a JSON comment on the first line of the file
    :param extra_keys: the extra information to log, typically is composed of
        ``reset_keywords`` and ``info_keywords``
    """

    def __init__(
        self,
        filename: str = "",
        header: Optional[Dict[str, Union[float, str]]] = None,
        extra_keys: Tuple[str, ...] = (),
    ):
        if header is None:
            header = {}
        if not filename.endswith(Monitor.EXT):
            if os.path.isdir(filename):
                filename = os.path.join(filename, Monitor.EXT)
            else:
                filename = filename + "." + Monitor.EXT
        # The csv module emits its own "\r\n" line terminators; disabling newline
        # translation prevents them from becoming "\r\r\n" on Windows.
        self.file_handler = open(filename, "wt", newline="")
        # First line: JSON metadata prefixed with "#", then the csv column header
        self.file_handler.write("#%s\n" % json.dumps(header))
        self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + extra_keys)
        self.logger.writeheader()
        self.file_handler.flush()

    def write_row(self, epinfo: Dict[str, Union[float, int]]) -> None:
        """
        Write one row of episode data to the csv log file and flush it to disk

        :param epinfo: the information on episodic return, length, and time
        """
        if self.logger:
            self.logger.writerow(epinfo)
            self.file_handler.flush()

    def close(self) -> None:
        """
        Close the file handler
        """
        self.file_handler.close()
|
||||
|
||||
|
||||
def get_monitor_files(path: str) -> List[str]:
|
||||
"""
|
||||
get all the monitor files in the given path
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@ from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, Ve
|
|||
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
|
||||
from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
|
||||
from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs
|
||||
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
|
||||
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
|
||||
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
|
||||
from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage
|
||||
from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
|
||||
|
|
|
|||
24
stable_baselines3/common/vec_env/vec_extract_dict_obs.py
Normal file
24
stable_baselines3/common/vec_env/vec_extract_dict_obs.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
import numpy as np
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
|
||||
|
||||
|
||||
class VecExtractDictObs(VecEnvWrapper):
    """
    A vectorized wrapper for extracting dictionary observations.

    Only the entry stored under ``key`` is exposed as the observation,
    and the observation space is narrowed accordingly.

    :param venv: The vectorized environment
    :param key: The key of the dictionary observation
    """

    def __init__(self, venv: VecEnv, key: str):
        self.key = key
        super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key])

    def reset(self) -> np.ndarray:
        """
        Reset the wrapped environment.

        :return: the ``key`` entry of the dictionary observation
        """
        obs = self.venv.reset()
        return obs[self.key]

    def step_wait(self) -> VecEnvStepReturn:
        """
        Wait for the step of the wrapped environment.

        :return: the ``key`` entry of the observation, plus rewards, dones and infos
        """
        obs, reward, done, infos = self.venv.step_wait()
        # Keep the terminal observation (when the wrapped env provides one)
        # consistent with the extracted observation returned to the caller.
        for info in infos:
            if "terminal_observation" in info:
                info["terminal_observation"] = info["terminal_observation"][self.key]
        return obs[self.key], reward, done, infos
|
||||
98
stable_baselines3/common/vec_env/vec_monitor.py
Normal file
98
stable_baselines3/common/vec_env/vec_monitor.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
import time
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
|
||||
|
||||
|
||||
class VecMonitor(VecEnvWrapper):
    """
    A vectorized monitor wrapper for *vectorized* Gym environments,
    it is used to record the episode reward, length, time and other data.

    Some environments like `openai/procgen <https://github.com/openai/procgen>`_
    or `gym3 <https://github.com/openai/gym3>`_ directly initialize the
    vectorized environments, without giving us a chance to use the ``Monitor``
    wrapper. So this class simply does the job of the ``Monitor`` wrapper on
    a vectorized level.

    :param venv: The vectorized environment
    :param filename: the location to save a log file, can be None for no log
    :param info_keywords: extra information to log, from the information return of env.step()
    """

    def __init__(
        self,
        venv: VecEnv,
        filename: Optional[str] = None,
        info_keywords: Tuple[str, ...] = (),
    ):
        # Avoid circular import
        from stable_baselines3.common.monitor import Monitor, ResultsWriter

        # This check is not valid for special `VecEnv`
        # like the ones created by Procgen, that do not completely follow
        # the `VecEnv` interface
        try:
            is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0]
        except AttributeError:
            is_wrapped_with_monitor = False

        if is_wrapped_with_monitor:
            # Adjacent literals are concatenated as-is: each piece needs its trailing space
            warnings.warn(
                "The environment is already wrapped with a `Monitor` wrapper "
                "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be "
                "overwritten by the `VecMonitor` ones.",
                UserWarning,
            )

        VecEnvWrapper.__init__(self, venv)
        # Per-environment running statistics, allocated in `reset()`
        self.episode_returns = None
        self.episode_lengths = None
        self.episode_count = 0
        self.t_start = time.time()

        env_id = None
        if hasattr(venv, "spec") and venv.spec is not None:
            env_id = venv.spec.id

        if filename:
            self.results_writer = ResultsWriter(
                filename, header={"t_start": self.t_start, "env_id": env_id}, extra_keys=info_keywords
            )
        else:
            self.results_writer = None
        self.info_keywords = info_keywords

    def reset(self) -> VecEnvObs:
        """
        Reset the wrapped environment and zero the running return/length of every env.

        :return: the observation returned by the wrapped environment
        """
        obs = self.venv.reset()
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return obs

    def step_wait(self) -> VecEnvStepReturn:
        """
        Wait for the step of the wrapped environment and update the episode statistics.

        For every environment whose episode just ended, an ``"episode"`` entry
        (keys ``"r"``, ``"l"``, ``"t"`` plus the ``info_keywords``) is added to a
        copy of its info dict and, if a log file was requested, written to it.

        :return: observations, rewards, dones and the (possibly augmented) infos
        """
        obs, rewards, dones, infos = self.venv.step_wait()
        self.episode_returns += rewards
        self.episode_lengths += 1
        new_infos = list(infos[:])
        for i, done in enumerate(dones):
            if done:
                # Copy so the info dict owned by the wrapped env is left untouched
                info = infos[i].copy()
                episode_return = self.episode_returns[i]
                episode_length = self.episode_lengths[i]
                episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)}
                # Log the requested extra keys: they are part of the csv header,
                # so leaving them out would produce empty columns.
                for key in self.info_keywords:
                    episode_info[key] = info[key]
                info["episode"] = episode_info
                self.episode_count += 1
                self.episode_returns[i] = 0
                self.episode_lengths[i] = 0
                if self.results_writer:
                    self.results_writer.write_row(episode_info)
                new_infos[i] = info
        return obs, rewards, dones, new_infos

    def close(self) -> None:
        """
        Close the log file (if any) and the wrapped environment.
        """
        if self.results_writer:
            self.results_writer.close()
        return self.venv.close()
|
||||
|
|
@ -131,7 +131,8 @@ class VecNormalize(VecEnvWrapper):
|
|||
for idx, done in enumerate(dones):
|
||||
if not done:
|
||||
continue
|
||||
infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"])
|
||||
if "terminal_observation" in infos[idx]:
|
||||
infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"])
|
||||
|
||||
self.ret[dones] = 0
|
||||
return obs, rewards, dones, infos
|
||||
|
|
|
|||
|
|
@ -51,7 +51,8 @@ class VecTransposeImage(VecEnvWrapper):
|
|||
for idx, done in enumerate(dones):
|
||||
if not done:
|
||||
continue
|
||||
infos[idx]["terminal_observation"] = self.transpose_image(infos[idx]["terminal_observation"])
|
||||
if "terminal_observation" in infos[idx]:
|
||||
infos[idx]["terminal_observation"] = self.transpose_image(infos[idx]["terminal_observation"])
|
||||
|
||||
return self.transpose_image(observations), rewards, dones, infos
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.1.0a1
|
||||
1.1.0a2
|
||||
|
|
|
|||
52
tests/test_vec_extract_dict_obs.py
Normal file
52
tests/test_vec_extract_dict_obs.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor
|
||||
|
||||
|
||||
class DictObsVecEnv:
    """Custom Environment that produces observation in a dictionary like the procgen env"""

    metadata = {"render.modes": ["human"]}

    # Shape of a single "rgb" observation, shared by the space and the generated data
    OBS_SHAPE = (86, 86)

    def __init__(self):
        self.num_envs = 4
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Dict(
            {"rgb": spaces.Box(low=0.0, high=255.0, shape=self.OBS_SHAPE, dtype=np.float32)}
        )

    def _zero_obs(self):
        """Return an all-zero batched observation matching the declared space (float32)."""
        return {"rgb": np.zeros((self.num_envs,) + self.OBS_SHAPE, dtype=np.float32)}

    def step_async(self, actions):
        """Store the actions; they are not used by this dummy environment."""
        self.actions = actions

    def step_wait(self):
        """Return zero observations and rewards, no episode termination and empty infos."""
        return (
            self._zero_obs(),
            np.zeros((self.num_envs,)),
            np.zeros((self.num_envs,), dtype=bool),
            [{} for _ in range(self.num_envs)],
        )

    def reset(self):
        """Return an all-zero dictionary observation for every environment."""
        return self._zero_obs()

    def render(self, mode="human", close=False):
        """Rendering is a no-op for this dummy environment."""
        pass
|
||||
|
||||
|
||||
def test_extract_dict_obs():
    """Check that `VecExtractDictObs` returns only the "rgb" entry on reset"""
    wrapped_env = VecExtractDictObs(DictObsVecEnv(), "rgb")
    first_obs = wrapped_env.reset()
    assert first_obs.shape == (4, 86, 86)
|
||||
|
||||
|
||||
def test_vec_with_ppo():
    """
    Train PPO briefly on a `VecExtractDictObs` env wrapped in a `VecMonitor`
    """
    extracted_env = VecExtractDictObs(DictObsVecEnv(), "rgb")
    monitored_env = VecMonitor(extracted_env)
    agent = PPO("MlpPolicy", monitored_env, verbose=1, n_steps=64, device="cpu")
    agent.learn(total_timesteps=250)
|
||||
120
tests/test_vec_monitor.py
Normal file
120
tests/test_vec_monitor.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
import json
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import gym
|
||||
import pandas
|
||||
import pytest
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.monitor import Monitor, get_monitor_files, load_results
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor, VecNormalize
|
||||
|
||||
|
||||
def test_vec_monitor(tmp_path):
    """
    Check the episode statistics and the log file produced by `VecMonitor`
    """
    venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    venv.seed(0)
    log_name = f"stable_baselines-test-{uuid.uuid4()}.monitor.csv"
    monitor_file = os.path.join(str(tmp_path), log_name)
    monitor_env = VecMonitor(venv, monitor_file)
    monitor_env.reset()

    # Track return and length by hand and compare with what the wrapper reports
    steps_in_episode, return_so_far = 0, 0
    for _ in range(1000):
        action = [monitor_env.action_space.sample()]
        _, rewards, dones, infos = monitor_env.step(action)
        steps_in_episode += 1
        return_so_far += rewards[0]
        if not dones[0]:
            continue
        episode_stats = infos[0]["episode"]
        assert return_so_far == episode_stats["r"]
        assert steps_in_episode == episode_stats["l"]
        steps_in_episode, return_so_far = 0, 0

    monitor_env.close()

    with open(monitor_file, "rt") as log_file:
        # The first line is a "#"-prefixed JSON header, the rest is plain csv
        header_line = log_file.readline()
        assert header_line.startswith("#")
        header = json.loads(header_line[1:])
        assert set(header.keys()) == {"t_start", "env_id"}, "Incorrect keys in monitor metadata"

        logged_rows = pandas.read_csv(log_file, index_col=None)
        assert set(logged_rows.keys()) == {"l", "t", "r"}, "Incorrect keys in monitor logline"
    os.remove(monitor_file)
|
||||
|
||||
|
||||
def _count_random_episodes(monitor_env, n_steps=1000):
    """Step ``monitor_env`` with random actions and return the number of finished episodes."""
    monitor_env.reset()
    episode_count = 0
    for _ in range(n_steps):
        _, _, dones, _ = monitor_env.step([monitor_env.action_space.sample()])
        if dones[0]:
            episode_count += 1
            monitor_env.reset()
    return episode_count


def test_vec_monitor_load_results(tmp_path):
    """
    test load_results on log files produced by the monitor wrapper
    """
    tmp_path = str(tmp_path)
    env1 = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    env1.seed(0)
    monitor_file1 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
    monitor_env1 = VecMonitor(env1, monitor_file1)

    monitor_files = get_monitor_files(tmp_path)
    assert len(monitor_files) == 1
    assert monitor_file1 in monitor_files

    episode_count1 = _count_random_episodes(monitor_env1)
    # Release the log file handle: rows are flushed as they are written, and an
    # open handle would make the final `os.remove` fail on Windows.
    monitor_env1.close()

    results_size1 = len(load_results(tmp_path).index)
    assert results_size1 == episode_count1

    env2 = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    env2.seed(0)
    monitor_file2 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
    monitor_env2 = VecMonitor(env2, monitor_file2)
    monitor_files = get_monitor_files(tmp_path)
    assert len(monitor_files) == 2
    assert monitor_file1 in monitor_files
    assert monitor_file2 in monitor_files

    episode_count2 = _count_random_episodes(monitor_env2)
    monitor_env2.close()

    # load_results aggregates every monitor file found in the directory
    results_size2 = len(load_results(tmp_path).index)

    assert results_size2 == (results_size1 + episode_count2)

    os.remove(monitor_file1)
    os.remove(monitor_file2)
|
||||
|
||||
|
||||
def test_vec_monitor_ppo(recwarn):
    """
    Train and evaluate PPO on a `VecMonitor`-wrapped env without any warning
    """
    venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    venv.seed(0)
    monitored_venv = VecMonitor(venv)
    agent = PPO("MlpPolicy", monitored_venv, verbose=1, n_steps=64, device="cpu")
    agent.learn(total_timesteps=250)

    # `evaluate_policy` must recognise `VecMonitor` and therefore stay silent
    evaluate_policy(agent, monitored_venv)
    assert len(recwarn) == 0
|
||||
|
||||
|
||||
def test_vec_monitor_warn():
    """
    `VecMonitor` should warn when the wrapped env already uses a `Monitor` wrapper
    """
    monitored_venv = DummyVecEnv([lambda: Monitor(gym.make("CartPole-v1"))])
    # Check both the bare vec env and the same env behind another vec wrapper;
    # the extra wrapper is built inside the `warns` context, as in a direct call.
    for add_wrapper in (lambda venv: venv, VecNormalize):
        with pytest.warns(UserWarning):
            VecMonitor(add_wrapper(monitored_venv))
|
||||
Loading…
Reference in a new issue