diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index adb59d3..2c2e5ff 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -38,8 +38,8 @@ Deprecations: Others: ^^^^^^^ - Fixed ``DictReplayBuffer.next_observations`` typing (@qgallouedec) - - Added support for ``device="auto"`` in buffers and made it default (@qgallouedec) +- Updated ``ResultsWriter`` (used internally by ``Monitor`` wrapper) to automatically create missing directories when ``filename`` is a path (@dominicgkerr) Documentation: ^^^^^^^^^^^^^^ @@ -1038,4 +1038,4 @@ And all the contributors: @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 -@anand-bala @hughperkins @sidney-tio @AlexPasqua +@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 5972531..1896924 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -574,7 +574,7 @@ class DictReplayBuffer(ReplayBuffer): reward: np.ndarray, done: np.ndarray, infos: List[Dict[str, Any]], - ) -> None: + ) -> None: # pytype: disable=signature-mismatch # Copy to avoid modification by reference for key in self.observations.keys(): # Reshape needed when using multiple envs with discrete observations @@ -711,7 +711,7 @@ class DictRolloutBuffer(RolloutBuffer): episode_start: np.ndarray, value: th.Tensor, log_prob: th.Tensor, - ) -> None: + ) -> None: # pytype: disable=signature-mismatch """ :param obs: Observation :param action: Action diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index 9a07b03..1e56fdb 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -163,9 +163,10 @@ class 
ResultsWriter: """ A result writer that saves the data from the `Monitor` class - :param filename: the location to save a log file, can be None for no log + :param filename: the location to save a log file. When it does not end in + the string ``"monitor.csv"``, this suffix will be appended to it :param header: the header dictionary object of the saved csv - :param reset_keywords: the extra information to log, typically is composed of + :param extra_keys: the extra information to log, typically is composed of ``reset_keywords`` and ``info_keywords`` :param override_existing: appends to file if ``filename`` exists, otherwise override existing files (default) @@ -185,6 +186,9 @@ class ResultsWriter: filename = os.path.join(filename, Monitor.EXT) else: filename = filename + "." + Monitor.EXT + filename = os.path.realpath(filename) + # Create (if any) missing filename directories + os.makedirs(os.path.dirname(filename), exist_ok=True) # Append mode when not overridding existing file mode = "w" if override_existing else "a" # Prevent newline issue on Windows, see GH issue #692 diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index 733b728..88d725e 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -199,7 +199,7 @@ class StackedDictObservations(StackedObservations): spaces_dict[key] = spaces.Box(low=low, high=high, dtype=subspace.dtype) return spaces.Dict(spaces=spaces_dict) - def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: # pytype: disable=signature-mismatch """ Resets the stacked observations, adds the reset observation to the stack, and returns the stack @@ -219,7 +219,7 @@ class StackedDictObservations(StackedObservations): observations: Dict[str, np.ndarray], dones: np.ndarray, infos: 
List[Dict[str, Any]], - ) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]: + ) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]: # pytype: disable=signature-mismatch """ Adds the observations to the stack and uses the dones to update the infos. diff --git a/tests/test_monitor.py b/tests/test_monitor.py index e5cb7f9..17002f3 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -48,6 +48,15 @@ def test_monitor(tmp_path): assert set(last_logline.keys()) == {"l", "t", "r"}, "Incorrect keys in monitor logline" os.remove(monitor_file) + # Check missing filename directories are created + monitor_dir = os.path.join(str(tmp_path), "missing-dir") + monitor_file = os.path.join(monitor_dir, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") + assert os.path.exists(monitor_dir) is False + _ = Monitor(env, monitor_file) + assert os.path.exists(monitor_dir) is True + os.remove(monitor_file) + os.rmdir(monitor_dir) + def test_monitor_load_results(tmp_path): """