Support for VecMonitor for gym3-style environments (#311)

* add vectorized monitor

* auto format of the code

* add documentation and VecExtractDictObs

* refactor and add test cases

* add test cases and format

* avoid circular import and fix doc

* fix type

* fix type

* oops

* Update stable_baselines3/common/monitor.py

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Update stable_baselines3/common/monitor.py

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* add test cases

* update changelog

* fix mutable argument

* quick fix

* Apply suggestions from code review

* fix terminal observation for gym3 envs

* delete comment

* Update doc and bump version

* Add warning when already using `Monitor` wrapper

* Update vecmonitor tests

* Fixes

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Costa Huang 2021-04-13 12:09:31 -04:00 committed by GitHub
parent 1ed15bf6ee
commit ddbe0e93f9
13 changed files with 424 additions and 36 deletions

View file

@ -642,6 +642,30 @@ A2C policy gradient updates on the model.
print(f"Best fitness: {top_candidates[0][1]:.2f}")
SB3 and ProcgenEnv
------------------

Some environments like `Procgen <https://github.com/openai/procgen>`_ already produce a vectorized
environment (see discussion in `issue #314 <https://github.com/DLR-RM/stable-baselines3/issues/314>`_). To use it with SB3, you must wrap it in a ``VecMonitor`` wrapper, which also lets you
keep track of the agent's progress.

.. code-block:: python

  from procgen import ProcgenEnv

  from stable_baselines3 import PPO
  from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor

  # ProcgenEnv is already vectorized
  venv = ProcgenEnv(num_envs=2, env_name='starpilot')

  # PPO does not currently support Dict observations
  # this will be solved in https://github.com/DLR-RM/stable-baselines3/pull/243
  venv = VecExtractDictObs(venv, "rgb")
  venv = VecMonitor(venv=venv)

  model = PPO("MlpPolicy", venv, verbose=1)
  model.learn(10000)

Record a Video
--------------

View file

@ -27,14 +27,22 @@ SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️
When using vectorized environments, the environments are automatically reset at the end of each episode.
Thus, the observation returned for the i-th environment when ``done[i]`` is true will in fact be the first observation of the next episode, not the last observation of the episode that has just terminated.
You can access the "real" final observation of the terminated episode—that is, the one that accompanied the ``done`` event provided by the underlying environment—using the ``terminal_observation`` keys in the info dicts returned by the vecenv.
You can access the "real" final observation of the terminated episode—that is, the one that accompanied the ``done`` event provided by the underlying environment—using the ``terminal_observation`` keys in the info dicts returned by the ``VecEnv``.
.. warning::
When using ``SubprocVecEnv``, users must wrap the code in an ``if __name__ == "__main__":`` if using the ``forkserver`` or ``spawn`` start method (default on Windows).
On Linux, the default start method is ``fork`` which is not thread safe and can create deadlocks.
When defining a custom ``VecEnv`` (for instance, using gym3 ``ProcgenEnv``), you should provide ``terminal_observation`` keys in the info dicts returned by the ``VecEnv``
(cf. note above).
.. warning::
When using ``SubprocVecEnv``, users must wrap the code in an ``if __name__ == "__main__":`` if using the ``forkserver`` or ``spawn`` start method (default on Windows).
On Linux, the default start method is ``fork`` which is not thread safe and can create deadlocks.
For more information, see Python's `multiprocessing guidelines <https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods>`_.
For more information, see Python's `multiprocessing guidelines <https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods>`_.
VecEnv
------
@ -90,3 +98,15 @@ VecTransposeImage
.. autoclass:: VecTransposeImage
  :members:

VecMonitor
~~~~~~~~~~~~~~~~~

.. autoclass:: VecMonitor
  :members:

VecExtractDictObs
~~~~~~~~~~~~~~~~~

.. autoclass:: VecExtractDictObs
  :members:
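
Illustrative sketch, not part of this commit: the ``terminal_observation`` convention described in the note above can be mimicked with a toy auto-resetting step function. ``CountdownEnv`` and ``auto_reset_step`` are hypothetical names invented for this example and only imitate the behaviour the documentation describes; they are not SB3 APIs.

```python
from typing import Any, Dict, List, Tuple


class CountdownEnv:
    """Hypothetical toy environment: the episode ends after `horizon` steps."""

    def __init__(self, horizon: int) -> None:
        self.horizon = horizon
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: int) -> Tuple[int, float, bool, Dict[str, Any]]:
        self.t += 1
        return self.t, 1.0, self.t >= self.horizon, {}


def auto_reset_step(envs: List[CountdownEnv], actions: List[int]):
    """Step every env; when one is done, keep its last observation and reset it."""
    obs_batch, rewards, dones, infos = [], [], [], []
    for env, action in zip(envs, actions):
        obs, reward, done, info = env.step(action)
        if done:
            # Store the true final observation before the reset overwrites it
            info["terminal_observation"] = obs
            obs = env.reset()
        obs_batch.append(obs)
        rewards.append(reward)
        dones.append(done)
        infos.append(info)
    return obs_batch, rewards, dones, infos


envs = [CountdownEnv(horizon=2), CountdownEnv(horizon=3)]
for env in envs:
    env.reset()

auto_reset_step(envs, [0, 0])
obs, rewards, dones, infos = auto_reset_step(envs, [0, 0])
```

After the second call, the first environment has finished: its returned observation is already the first observation of the next episode, while the last observation of the finished episode is only available through ``infos[0]["terminal_observation"]``.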

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.1.0a1 (WIP)
Release 1.1.0a2 (WIP)
---------------------------
Breaking Changes:
@ -12,6 +12,11 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Added `VecMonitor <https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/vec_monitor.py>`_ and
`VecExtractDictObs <https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/vec_extract_dict_obs.py>`_ wrappers
to handle gym3-style vectorized environments (@vwxyzjn)
- Ignored the terminal observation if it is not provided by the environment,
as is the case with gym3-style vectorized environments. (@vwxyzjn)
Bug Fixes:
^^^^^^^^^^
@ -33,6 +38,8 @@ Documentation:
- Clarify channel-first/channel-last recommendation
- Update sphinx environment installation instructions (@tom-doerr)
- Clarify pip installation in Zsh (@tom-doerr)
- Added example for using ``ProcgenEnv``
Release 1.0 (2021-03-15)
------------------------
@ -54,6 +61,7 @@ New Features:
^^^^^^^^^^^^^
- Added support for ``custom_objects`` when loading models
Bug Fixes:
^^^^^^^^^^
- Fixed a bug with ``DQN`` predict method when using ``deterministic=False`` with image space
@ -640,5 +648,5 @@ And all the contributors:
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr

View file

@ -5,7 +5,7 @@ import gym
import numpy as np
from stable_baselines3.common import base_class
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.vec_env import VecEnv, VecMonitor, is_vecenv_wrapped
def evaluate_policy(
@ -57,7 +57,7 @@ def evaluate_policy(
if isinstance(env, VecEnv):
assert env.num_envs == 1, "You must pass only one environment when using this function"
is_monitor_wrapped = env.env_is_wrapped(Monitor)[0]
is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]
else:
is_monitor_wrapped = is_wrapped(env, Monitor)
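
Illustrative sketch, not part of this commit: the new condition first asks whether the vectorized environment itself is wrapped in a ``VecMonitor`` and only then falls back to the per-environment ``Monitor`` check. One way such a wrapper-chain lookup could work is shown below. Every class and the helper ``is_wrapped_with`` are hypothetical stand-ins; this is not the implementation of ``is_vecenv_wrapped``, only an assumption about the general idea of walking the ``venv`` attribute.

```python
class ToyVecEnv:
    """Hypothetical base vectorized environment (no wrapper)."""


class ToyVecWrapper:
    """Hypothetical vectorized wrapper that stores the wrapped env in `venv`."""

    def __init__(self, venv) -> None:
        self.venv = venv


class ToyVecMonitor(ToyVecWrapper):
    pass


class ToyVecNormalize(ToyVecWrapper):
    pass


def is_wrapped_with(env, wrapper_class) -> bool:
    """Walk the wrapper chain from the outside in, looking for `wrapper_class`."""
    current = env
    while isinstance(current, ToyVecWrapper):
        if isinstance(current, wrapper_class):
            return True
        current = current.venv
    return False


monitored = ToyVecNormalize(ToyVecMonitor(ToyVecEnv()))
unmonitored = ToyVecNormalize(ToyVecEnv())
```

The lookup succeeds even when the monitor is not the outermost wrapper, which is why wrapping order does not matter for this kind of check.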

View file

@ -1,11 +1,11 @@
__all__ = ["Monitor", "get_monitor_files", "load_results"]
__all__ = ["Monitor", "ResultsWriter", "get_monitor_files", "load_results"]
import csv
import json
import os
import time
from glob import glob
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import gym
import numpy as np
@ -38,27 +38,20 @@ class Monitor(gym.Wrapper):
):
super(Monitor, self).__init__(env=env)
self.t_start = time.time()
if filename is None:
self.file_handler = None
self.logger = None
if filename is not None:
self.results_writer = ResultsWriter(
filename,
header={"t_start": self.t_start, "env_id": env.spec and env.spec.id},
extra_keys=reset_keywords + info_keywords,
)
else:
if not filename.endswith(Monitor.EXT):
if os.path.isdir(filename):
filename = os.path.join(filename, Monitor.EXT)
else:
filename = filename + "." + Monitor.EXT
self.file_handler = open(filename, "wt")
self.file_handler.write("#%s\n" % json.dumps({"t_start": self.t_start, "env_id": env.spec and env.spec.id}))
self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + reset_keywords + info_keywords)
self.logger.writeheader()
self.file_handler.flush()
self.results_writer = None
self.reset_keywords = reset_keywords
self.info_keywords = info_keywords
self.allow_early_resets = allow_early_resets
self.rewards = None
self.needs_reset = True
self.episode_rewards = []
self.episode_returns = []
self.episode_lengths = []
self.episode_times = []
self.total_steps = 0
@ -81,7 +74,7 @@ class Monitor(gym.Wrapper):
for key in self.reset_keywords:
value = kwargs.get(key)
if value is None:
raise ValueError("Expected you to pass kwarg {} into reset".format(key))
raise ValueError(f"Expected you to pass keyword argument {key} into reset")
self.current_reset_info[key] = value
return self.env.reset(**kwargs)
@ -103,13 +96,12 @@ class Monitor(gym.Wrapper):
ep_info = {"r": round(ep_rew, 6), "l": ep_len, "t": round(time.time() - self.t_start, 6)}
for key in self.info_keywords:
ep_info[key] = info[key]
self.episode_rewards.append(ep_rew)
self.episode_returns.append(ep_rew)
self.episode_lengths.append(ep_len)
self.episode_times.append(time.time() - self.t_start)
ep_info.update(self.current_reset_info)
if self.logger:
self.logger.writerow(ep_info)
self.file_handler.flush()
if self.results_writer:
self.results_writer.write_row(ep_info)
info["episode"] = ep_info
self.total_steps += 1
return observation, reward, done, info
@ -119,8 +111,8 @@ class Monitor(gym.Wrapper):
Closes the environment
"""
super(Monitor, self).close()
if self.file_handler is not None:
self.file_handler.close()
if self.results_writer is not None:
self.results_writer.close()
def get_total_steps(self) -> int:
"""
@ -136,7 +128,7 @@ class Monitor(gym.Wrapper):
:return:
"""
return self.episode_rewards
return self.episode_returns
def get_episode_lengths(self) -> List[int]:
"""
@ -163,6 +155,52 @@ class LoadMonitorResultsError(Exception):
pass
class ResultsWriter:
    """
    A result writer that saves the data from the `Monitor` class

    :param filename: the location to save a log file
    :param header: the header dictionary object of the saved csv
    :param extra_keys: the extra information to log, typically is composed of
        ``reset_keywords`` and ``info_keywords``
    """

    def __init__(
        self,
        filename: str = "",
        header: Optional[Dict[str, Union[float, str]]] = None,
        extra_keys: Tuple[str, ...] = (),
    ):
        if header is None:
            header = {}
        if not filename.endswith(Monitor.EXT):
            if os.path.isdir(filename):
                filename = os.path.join(filename, Monitor.EXT)
            else:
                filename = filename + "." + Monitor.EXT
        self.file_handler = open(filename, "wt")
        self.file_handler.write("#%s\n" % json.dumps(header))
        self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + extra_keys)
        self.logger.writeheader()
        self.file_handler.flush()

    def write_row(self, epinfo: Dict[str, Union[float, int]]) -> None:
        """
        Write one row to the log file

        :param epinfo: the information on episodic return, length, and time
        """
        if self.logger:
            self.logger.writerow(epinfo)
            self.file_handler.flush()

    def close(self) -> None:
        """
        Close the file handler
        """
        self.file_handler.close()


def get_monitor_files(path: str) -> List[str]:
"""
get all the monitor files in the given path
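
Illustrative sketch, not part of this commit: the file layout produced by ``ResultsWriter`` is a single ``#``-prefixed JSON header line followed by ordinary CSV with the columns ``r``, ``l``, ``t`` plus any extra keys. The snippet below reproduces that layout in memory with the standard library only and reads it back. The header values and the episode row are made-up sample data.

```python
import csv
import io
import json

# Write: one "#"-prefixed JSON header line, then a regular CSV table
buffer = io.StringIO()
header = {"t_start": 0.0, "env_id": None}  # sample values for illustration
buffer.write("#%s\n" % json.dumps(header))
writer = csv.DictWriter(buffer, fieldnames=("r", "l", "t"))
writer.writeheader()
writer.writerow({"r": 21.0, "l": 21, "t": 0.5})

# Read back: strip the leading "#" from the first line, parse the rest as CSV
buffer.seek(0)
first_line = buffer.readline()
parsed_header = json.loads(first_line[1:])
rows = list(csv.DictReader(buffer))
```

Reading the header line before handing the file object to the CSV reader is the same order of operations the tests in this commit use with ``pandas.read_csv``.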

View file

@ -7,7 +7,9 @@ from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, Ve
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage
from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder

View file

@ -0,0 +1,24 @@
import numpy as np

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper


class VecExtractDictObs(VecEnvWrapper):
    """
    A vectorized wrapper for extracting dictionary observations.

    :param venv: The vectorized environment
    :param key: The key of the dictionary observation
    """

    def __init__(self, venv: VecEnv, key: str):
        self.key = key
        super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key])

    def reset(self) -> np.ndarray:
        obs = self.venv.reset()
        return obs[self.key]

    def step_wait(self) -> VecEnvStepReturn:
        obs, reward, done, info = self.venv.step_wait()
        return obs[self.key], reward, done, info
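
Illustrative sketch, not part of this commit: a gym3-style environment returns one dictionary for the whole batch, with each value holding the observations of all environments. That layout is why the wrapper can simply index with the key. The helper ``extract_key`` and the sample data are hypothetical and use plain lists instead of arrays to stay dependency-free.

```python
from typing import Dict, List


def extract_key(batched_obs: Dict[str, List], key: str) -> List:
    """Return the batch of observations stored under `key`."""
    if key not in batched_obs:
        raise KeyError(f"Observation has no key {key!r}; available keys: {sorted(batched_obs)}")
    return batched_obs[key]


# Two environments, each with a 2x2 single-channel "image" and a level id
batched_obs = {
    "rgb": [[[0, 0], [0, 0]], [[1, 1], [1, 1]]],
    "level": [3, 7],
}
rgb = extract_key(batched_obs, "rgb")
```

The dictionary is indexed first and the environment second, so the extracted value keeps its leading batch dimension.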

View file

@ -0,0 +1,98 @@
import time
import warnings
from typing import Optional, Tuple

import numpy as np

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper


class VecMonitor(VecEnvWrapper):
    """
    A vectorized monitor wrapper for *vectorized* Gym environments.
    It is used to record the episode reward, length, time and other data.

    Some environments like `openai/procgen <https://github.com/openai/procgen>`_
    or `gym3 <https://github.com/openai/gym3>`_ directly initialize the
    vectorized environments, without giving us a chance to use the ``Monitor``
    wrapper. So this class simply does the job of the ``Monitor`` wrapper on
    a vectorized level.

    :param venv: The vectorized environment
    :param filename: the location to save a log file, can be None for no log
    :param info_keywords: extra information to log, from the information return of env.step()
    """

    def __init__(
        self,
        venv: VecEnv,
        filename: Optional[str] = None,
        info_keywords: Tuple[str, ...] = (),
    ):
        # Avoid circular import
        from stable_baselines3.common.monitor import Monitor, ResultsWriter

        # This check is not valid for special `VecEnv`
        # like the ones created by Procgen, that do not completely follow
        # the `VecEnv` interface
        try:
            is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0]
        except AttributeError:
            is_wrapped_with_monitor = False

        if is_wrapped_with_monitor:
            warnings.warn(
                "The environment is already wrapped with a `Monitor` wrapper "
                "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be "
                "overwritten by the `VecMonitor` ones.",
                UserWarning,
            )

        VecEnvWrapper.__init__(self, venv)
        self.episode_returns = None
        self.episode_lengths = None
        self.episode_count = 0
        self.t_start = time.time()

        env_id = None
        if hasattr(venv, "spec") and venv.spec is not None:
            env_id = venv.spec.id

        if filename:
            self.results_writer = ResultsWriter(
                filename, header={"t_start": self.t_start, "env_id": env_id}, extra_keys=info_keywords
            )
        else:
            self.results_writer = None
        self.info_keywords = info_keywords

    def reset(self) -> VecEnvObs:
        obs = self.venv.reset()
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return obs

    def step_wait(self) -> VecEnvStepReturn:
        obs, rewards, dones, infos = self.venv.step_wait()
        self.episode_returns += rewards
        self.episode_lengths += 1
        new_infos = list(infos[:])
        for i in range(len(dones)):
            if dones[i]:
                info = infos[i].copy()
                episode_return = self.episode_returns[i]
                episode_length = self.episode_lengths[i]
                episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)}
                info["episode"] = episode_info
                self.episode_count += 1
                self.episode_returns[i] = 0
                self.episode_lengths[i] = 0
                if self.results_writer:
                    self.results_writer.write_row(episode_info)
                new_infos[i] = info
        return obs, rewards, dones, new_infos

    def close(self) -> None:
        if self.results_writer:
            self.results_writer.close()
        return self.venv.close()
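
Illustrative sketch, not part of this commit: the bookkeeping in ``step_wait`` boils down to one running return and one running length per environment, emitted and reset whenever that environment reports ``done``. The standalone ``EpisodeTracker`` below is a hypothetical, dependency-free rendition of that idea using plain lists. It omits the elapsed-time field and the file logging.

```python
from typing import Dict, List, Optional


class EpisodeTracker:
    """Hypothetical helper: accumulate per-environment episode return and length."""

    def __init__(self, num_envs: int) -> None:
        self.returns = [0.0] * num_envs
        self.lengths = [0] * num_envs

    def update(self, rewards: List[float], dones: List[bool]) -> List[Optional[Dict[str, float]]]:
        """Add one step of rewards; return episode stats for envs that just finished."""
        finished: List[Optional[Dict[str, float]]] = []
        for i, (reward, done) in enumerate(zip(rewards, dones)):
            self.returns[i] += reward
            self.lengths[i] += 1
            if done:
                finished.append({"r": self.returns[i], "l": self.lengths[i]})
                # Start counting the next episode from zero
                self.returns[i] = 0.0
                self.lengths[i] = 0
            else:
                finished.append(None)
        return finished


tracker = EpisodeTracker(num_envs=2)
tracker.update([1.0, 0.5], [False, False])
result = tracker.update([1.0, 0.5], [True, False])
```

Only the finished environment is reset, so environments with episodes of different lengths are tracked independently.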

View file

@ -131,7 +131,8 @@ class VecNormalize(VecEnvWrapper):
for idx, done in enumerate(dones):
if not done:
continue
infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"])
if "terminal_observation" in infos[idx]:
infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"])
self.ret[dones] = 0
return obs, rewards, dones, infos

View file

@ -51,7 +51,8 @@ class VecTransposeImage(VecEnvWrapper):
for idx, done in enumerate(dones):
if not done:
continue
infos[idx]["terminal_observation"] = self.transpose_image(infos[idx]["terminal_observation"])
if "terminal_observation" in infos[idx]:
infos[idx]["terminal_observation"] = self.transpose_image(infos[idx]["terminal_observation"])
return self.transpose_image(observations), rewards, dones, infos
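
Illustrative sketch, not part of this commit: the guard added to ``VecNormalize`` and ``VecTransposeImage`` follows one pattern, namely transforming ``terminal_observation`` only for finished environments whose info dict actually carries that key. The helper ``transform_terminal_obs`` is a hypothetical generalisation written for this example.

```python
from typing import Any, Callable, Dict, List


def transform_terminal_obs(
    infos: List[Dict[str, Any]],
    dones: List[bool],
    transform: Callable[[Any], Any],
) -> List[Dict[str, Any]]:
    """Apply `transform` to the terminal observation of finished envs, if present."""
    for idx, done in enumerate(dones):
        if not done:
            continue
        # gym3-style environments may not provide the key at all
        if "terminal_observation" in infos[idx]:
            infos[idx]["terminal_observation"] = transform(infos[idx]["terminal_observation"])
    return infos


infos = [{"terminal_observation": [1, 2]}, {}, {"terminal_observation": [3]}]
dones = [True, True, False]
out = transform_terminal_obs(infos, dones, lambda obs: [2 * x for x in obs])
```

The second environment is done but has no ``terminal_observation``, so it is skipped instead of raising a ``KeyError``.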

View file

@ -1 +1 @@
1.1.0a1
1.1.0a2

View file

@ -0,0 +1,52 @@
import numpy as np
from gym import spaces

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor


class DictObsVecEnv:
    """Custom Environment that produces observation in a dictionary like the procgen env"""

    metadata = {"render.modes": ["human"]}

    def __init__(self):
        self.num_envs = 4
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Dict({"rgb": spaces.Box(low=0.0, high=255.0, shape=(86, 86), dtype=np.float32)})

    def step_async(self, actions):
        self.actions = actions

    def step_wait(self):
        return (
            {"rgb": np.zeros((self.num_envs, 86, 86))},
            np.zeros((self.num_envs,)),
            np.zeros((self.num_envs,), dtype=bool),
            [{} for _ in range(self.num_envs)],
        )

    def reset(self):
        return {"rgb": np.zeros((self.num_envs, 86, 86))}

    def render(self, mode="human", close=False):
        pass


def test_extract_dict_obs():
    """Test VecExtractDictObs"""

    env = DictObsVecEnv()
    env = VecExtractDictObs(env, "rgb")
    assert env.reset().shape == (4, 86, 86)


def test_vec_with_ppo():
    """
    Test the `VecExtractDictObs` with PPO
    """
    env = DictObsVecEnv()
    env = VecExtractDictObs(env, "rgb")
    monitor_env = VecMonitor(env)
    model = PPO("MlpPolicy", monitor_env, verbose=1, n_steps=64, device="cpu")
    model.learn(total_timesteps=250)

tests/test_vec_monitor.py Normal file
View file

@ -0,0 +1,120 @@
import json
import os
import uuid

import gym
import pandas
import pytest

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor, get_monitor_files, load_results
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor, VecNormalize


def test_vec_monitor(tmp_path):
    """
    Test the `VecMonitor` wrapper
    """
    env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    env.seed(0)
    monitor_file = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
    monitor_env = VecMonitor(env, monitor_file)
    monitor_env.reset()
    total_steps = 1000
    ep_len, ep_reward = 0, 0
    for _ in range(total_steps):
        _, rewards, dones, infos = monitor_env.step([monitor_env.action_space.sample()])
        ep_len += 1
        ep_reward += rewards[0]
        if dones[0]:
            assert ep_reward == infos[0]["episode"]["r"]
            assert ep_len == infos[0]["episode"]["l"]
            ep_len, ep_reward = 0, 0

    monitor_env.close()

    with open(monitor_file, "rt") as file_handler:
        first_line = file_handler.readline()
        assert first_line.startswith("#")
        metadata = json.loads(first_line[1:])
        assert set(metadata.keys()) == {"t_start", "env_id"}, "Incorrect keys in monitor metadata"

        last_logline = pandas.read_csv(file_handler, index_col=None)
        assert set(last_logline.keys()) == {"l", "t", "r"}, "Incorrect keys in monitor logline"
    os.remove(monitor_file)


def test_vec_monitor_load_results(tmp_path):
    """
    test load_results on log files produced by the monitor wrapper
    """
    tmp_path = str(tmp_path)
    env1 = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    env1.seed(0)
    monitor_file1 = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
    monitor_env1 = VecMonitor(env1, monitor_file1)

    monitor_files = get_monitor_files(tmp_path)
    assert len(monitor_files) == 1
    assert monitor_file1 in monitor_files

    monitor_env1.reset()
    episode_count1 = 0
    for _ in range(1000):
        _, _, dones, _ = monitor_env1.step([monitor_env1.action_space.sample()])
        if dones[0]:
            episode_count1 += 1
            monitor_env1.reset()

    results_size1 = len(load_results(os.path.join(tmp_path)).index)
    assert results_size1 == episode_count1

    env2 = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    env2.seed(0)
    monitor_file2 = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
    monitor_env2 = VecMonitor(env2, monitor_file2)
    monitor_files = get_monitor_files(tmp_path)
    assert len(monitor_files) == 2
    assert monitor_file1 in monitor_files
    assert monitor_file2 in monitor_files

    monitor_env2.reset()
    episode_count2 = 0
    for _ in range(1000):
        _, _, dones, _ = monitor_env2.step([monitor_env2.action_space.sample()])
        if dones[0]:
            episode_count2 += 1
            monitor_env2.reset()

    results_size2 = len(load_results(os.path.join(tmp_path)).index)

    assert results_size2 == (results_size1 + episode_count2)

    os.remove(monitor_file1)
    os.remove(monitor_file2)


def test_vec_monitor_ppo(recwarn):
    """
    Test the `VecMonitor` with PPO
    """
    env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    env.seed(0)
    monitor_env = VecMonitor(env)
    model = PPO("MlpPolicy", monitor_env, verbose=1, n_steps=64, device="cpu")
    model.learn(total_timesteps=250)

    # No warnings because using `VecMonitor`
    evaluate_policy(model, monitor_env)
    assert len(recwarn) == 0


def test_vec_monitor_warn():
    env = DummyVecEnv([lambda: Monitor(gym.make("CartPole-v1"))])
    # We should warn the user when the env is already wrapped with a Monitor wrapper
    with pytest.warns(UserWarning):
        VecMonitor(env)

    with pytest.warns(UserWarning):
        VecMonitor(VecNormalize(env))