mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-22 22:10:16 +00:00
Automatically create missing directories of filenames passed to `ResultsWriter` (#1072)
* Create (if any) missing filename directories, passed into ResultsWriter * Fixed incorrect ``filename`` docstring (if ``filename`` were ``None``, the string method ``filename.endswith(Monitor.EXT)`` would raise an ``AttributeError``), and renamed ``reset_keywords`` docstring. * Added description of #1068 * Ignore pytype errors * Update changelog.rst Co-authored-by: dominicgkerr <dominicgkerr1@gmail.co> Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de> Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
b7456392ac
commit
899eee6bd4
5 changed files with 21 additions and 8 deletions
|
|
@ -38,8 +38,8 @@ Deprecations:
|
|||
Others:
|
||||
^^^^^^^
|
||||
- Fixed ``DictReplayBuffer.next_observations`` typing (@qgallouedec)
|
||||
|
||||
- Added support for ``device="auto"`` in buffers and made it default (@qgallouedec)
|
||||
- Updated ``ResultsWriter`` (used internally by ``Monitor`` wrapper) to automatically create missing directories when ``filename`` is a path (@dominicgkerr)
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
@ -1038,4 +1038,4 @@ And all the contributors:
|
|||
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
|
||||
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
|
||||
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
|
||||
@anand-bala @hughperkins @sidney-tio @AlexPasqua
|
||||
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr
|
||||
|
|
|
|||
|
|
@ -574,7 +574,7 @@ class DictReplayBuffer(ReplayBuffer):
|
|||
reward: np.ndarray,
|
||||
done: np.ndarray,
|
||||
infos: List[Dict[str, Any]],
|
||||
) -> None:
|
||||
) -> None: # pytype: disable=signature-mismatch
|
||||
# Copy to avoid modification by reference
|
||||
for key in self.observations.keys():
|
||||
# Reshape needed when using multiple envs with discrete observations
|
||||
|
|
@ -711,7 +711,7 @@ class DictRolloutBuffer(RolloutBuffer):
|
|||
episode_start: np.ndarray,
|
||||
value: th.Tensor,
|
||||
log_prob: th.Tensor,
|
||||
) -> None:
|
||||
) -> None: # pytype: disable=signature-mismatch
|
||||
"""
|
||||
:param obs: Observation
|
||||
:param action: Action
|
||||
|
|
|
|||
|
|
@ -163,9 +163,10 @@ class ResultsWriter:
|
|||
"""
|
||||
A result writer that saves the data from the `Monitor` class
|
||||
|
||||
:param filename: the location to save a log file, can be None for no log
|
||||
:param filename: the location to save a log file. When it does not end in
|
||||
the string ``"monitor.csv"``, this suffix will be appended to it
|
||||
:param header: the header dictionary object of the saved csv
|
||||
:param reset_keywords: the extra information to log, typically is composed of
|
||||
:param extra_keys: the extra information to log, typically is composed of
|
||||
``reset_keywords`` and ``info_keywords``
|
||||
:param override_existing: appends to file if ``filename`` exists, otherwise
|
||||
override existing files (default)
|
||||
|
|
@ -185,6 +186,9 @@ class ResultsWriter:
|
|||
filename = os.path.join(filename, Monitor.EXT)
|
||||
else:
|
||||
filename = filename + "." + Monitor.EXT
|
||||
filename = os.path.realpath(filename)
|
||||
# Create (if any) missing filename directories
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
# Append mode when not overriding existing file
|
||||
mode = "w" if override_existing else "a"
|
||||
# Prevent newline issue on Windows, see GH issue #692
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ class StackedDictObservations(StackedObservations):
|
|||
spaces_dict[key] = spaces.Box(low=low, high=high, dtype=subspace.dtype)
|
||||
return spaces.Dict(spaces=spaces_dict)
|
||||
|
||||
def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
|
||||
def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: # pytype: disable=signature-mismatch
|
||||
"""
|
||||
Resets the stacked observations, adds the reset observation to the stack, and returns the stack
|
||||
|
||||
|
|
@ -219,7 +219,7 @@ class StackedDictObservations(StackedObservations):
|
|||
observations: Dict[str, np.ndarray],
|
||||
dones: np.ndarray,
|
||||
infos: List[Dict[str, Any]],
|
||||
) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]:
|
||||
) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]: # pytype: disable=signature-mismatch
|
||||
"""
|
||||
Adds the observations to the stack and uses the dones to update the infos.
|
||||
|
||||
|
|
|
|||
|
|
@ -48,6 +48,15 @@ def test_monitor(tmp_path):
|
|||
assert set(last_logline.keys()) == {"l", "t", "r"}, "Incorrect keys in monitor logline"
|
||||
os.remove(monitor_file)
|
||||
|
||||
# Check missing filename directories are created
|
||||
monitor_dir = os.path.join(str(tmp_path), "missing-dir")
|
||||
monitor_file = os.path.join(monitor_dir, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
|
||||
assert os.path.exists(monitor_dir) is False
|
||||
_ = Monitor(env, monitor_file)
|
||||
assert os.path.exists(monitor_dir) is True
|
||||
os.remove(monitor_file)
|
||||
os.rmdir(monitor_dir)
|
||||
|
||||
|
||||
def test_monitor_load_results(tmp_path):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue