diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3179605..3895380 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.5.1a3 (WIP) +Release 1.5.1a4 (WIP) --------------------------- Breaking Changes: @@ -23,6 +23,7 @@ Bug Fixes: ^^^^^^^^^^ - Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517) - Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec) +- Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies) Deprecations: ^^^^^^^^^^^^^ @@ -962,4 +963,4 @@ And all the contributors: @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 -@Gregwar @ycheng517 +@Gregwar @ycheng517 @quantitative-technologies diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 8504c8d..94cd658 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -154,15 +154,18 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device: return device -def get_latest_run_id(log_path: Optional[str] = None, log_name: str = "") -> int: +def get_latest_run_id(log_path: str = "", log_name: str = "") -> int: """ Returns the latest run number for the given log name and log path, by finding the greatest number in the directories. + :param log_path: Path to the log folder containing several runs. + :param log_name: Name of the experiment. Each run is stored + in a folder named ``log_name_1``, ``log_name_2``, ... :return: latest run number """ max_run_id = 0 - for path in glob.glob(f"{log_path}/{log_name}_[0-9]*"): + for path in glob.glob(os.path.join(log_path, f"{glob.escape(log_name)}_[0-9]*")): file_name = path.split(os.sep)[-1] ext = file_name.split("_")[-1] if log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 8d61b2f..d6a9f8c 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a3 +1.5.1a4 diff --git a/tests/test_tensorboard.py b/tests/test_tensorboard.py index 20f58b9..6dccf41 100644 --- a/tests/test_tensorboard.py +++ b/tests/test_tensorboard.py @@ -3,6 +3,7 @@ import os import pytest from stable_baselines3 import A2C, PPO, SAC, TD3 +from stable_baselines3.common.utils import get_latest_run_id MODEL_DICT = { "a2c": (A2C, "CartPole-v1"), @@ -35,3 +36,13 @@ def test_tensorboard(tmp_path, model_name): assert os.path.isdir(tmp_path / str(logname + "_1")) # Check that the log dir name increments correctly assert os.path.isdir(tmp_path / str(logname + "_2")) + + +def test_escape_log_name(tmp_path): + # Log name that must be escaped + log_name = "filename[16, 16]" + # Create folder + os.makedirs(str(tmp_path) + f"/{log_name}_1", exist_ok=True) + os.makedirs(str(tmp_path) + f"/{log_name}_2", exist_ok=True) + last_run_id = get_latest_run_id(tmp_path, log_name) + assert last_run_id == 2