Automatically create missing directories of `filenames passed to ResultsWriter` (#1072)

* Create (if any) missing filename directories, passed into ResultsWriter

* Fixed incorrect ``filename`` docstring (if ``filename`` where ``None``, the string method ``filename.endswith(Monitor.EXT)`` would raise an ``AttributeError``), and renamed ``reset_keywords`` docstring.

* Added description of #1068

* Ignore pytype errors

* Update changelog.rst

Co-authored-by: dominicgkerr <dominicgkerr1@gmail.co>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Dominic Kerr 2022-09-21 12:14:38 +01:00 committed by GitHub
parent b7456392ac
commit 899eee6bd4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 21 additions and 8 deletions

View file

@ -38,8 +38,8 @@ Deprecations:
Others:
^^^^^^^
- Fixed ``DictReplayBuffer.next_observations`` typing (@qgallouedec)
- Added support for ``device="auto"`` in buffers and made it default (@qgallouedec)
- Updated ``ResultsWriter` (used internally by ``Monitor`` wrapper) to automatically create missing directories when ``filename`` is a path (@dominicgkerr)
Documentation:
^^^^^^^^^^^^^^
@ -1038,4 +1038,4 @@ And all the contributors:
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr

View file

@ -574,7 +574,7 @@ class DictReplayBuffer(ReplayBuffer):
reward: np.ndarray,
done: np.ndarray,
infos: List[Dict[str, Any]],
) -> None:
) -> None: # pytype: disable=signature-mismatch
# Copy to avoid modification by reference
for key in self.observations.keys():
# Reshape needed when using multiple envs with discrete observations
@ -711,7 +711,7 @@ class DictRolloutBuffer(RolloutBuffer):
episode_start: np.ndarray,
value: th.Tensor,
log_prob: th.Tensor,
) -> None:
) -> None: # pytype: disable=signature-mismatch
"""
:param obs: Observation
:param action: Action

View file

@ -163,9 +163,10 @@ class ResultsWriter:
"""
A result writer that saves the data from the `Monitor` class
:param filename: the location to save a log file, can be None for no log
:param filename: the location to save a log file. When it does not end in
the string ``"monitor.csv"``, this suffix will be appended to it
:param header: the header dictionary object of the saved csv
:param reset_keywords: the extra information to log, typically is composed of
:param extra_keys: the extra information to log, typically is composed of
``reset_keywords`` and ``info_keywords``
:param override_existing: appends to file if ``filename`` exists, otherwise
override existing files (default)
@ -185,6 +186,9 @@ class ResultsWriter:
filename = os.path.join(filename, Monitor.EXT)
else:
filename = filename + "." + Monitor.EXT
filename = os.path.realpath(filename)
# Create (if any) missing filename directories
os.makedirs(os.path.dirname(filename), exist_ok=True)
# Append mode when not overridding existing file
mode = "w" if override_existing else "a"
# Prevent newline issue on Windows, see GH issue #692

View file

@ -199,7 +199,7 @@ class StackedDictObservations(StackedObservations):
spaces_dict[key] = spaces.Box(low=low, high=high, dtype=subspace.dtype)
return spaces.Dict(spaces=spaces_dict)
def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: # pytype: disable=signature-mismatch
"""
Resets the stacked observations, adds the reset observation to the stack, and returns the stack
@ -219,7 +219,7 @@ class StackedDictObservations(StackedObservations):
observations: Dict[str, np.ndarray],
dones: np.ndarray,
infos: List[Dict[str, Any]],
) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]:
) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]: # pytype: disable=signature-mismatch
"""
Adds the observations to the stack and uses the dones to update the infos.

View file

@ -48,6 +48,15 @@ def test_monitor(tmp_path):
assert set(last_logline.keys()) == {"l", "t", "r"}, "Incorrect keys in monitor logline"
os.remove(monitor_file)
# Check missing filename directories are created
monitor_dir = os.path.join(str(tmp_path), "missing-dir")
monitor_file = os.path.join(monitor_dir, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
assert os.path.exists(monitor_dir) is False
_ = Monitor(env, monitor_file)
assert os.path.exists(monitor_dir) is True
os.remove(monitor_file)
os.rmdir(monitor_dir)
def test_monitor_load_results(tmp_path):
"""