mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
Escape tensorboard log name (#857)
* escape tensorboard log name Otherwise utils does not recognize the log. * Added fix to changelog * Modifications made by: make commit-checks . * Revert "Modifications made by: make commit-checks ." This reverts commit 529a275d9475f85ef031038a8f3565f7301e5371. * Update changelog and add test Co-authored-by: James Hirschorn <James.Hirschorn@quantitative-technologies.com>
This commit is contained in:
parent
248f082cdc
commit
39a4f9379a
4 changed files with 20 additions and 5 deletions
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.5.1a3 (WIP)
|
||||
Release 1.5.1a4 (WIP)
|
||||
---------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -23,6 +23,7 @@ Bug Fixes:
|
|||
^^^^^^^^^^
|
||||
- Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517)
|
||||
- Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec)
|
||||
- Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -962,4 +963,4 @@ And all the contributors:
|
|||
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
|
||||
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
|
||||
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
|
||||
@Gregwar @ycheng517
|
||||
@Gregwar @ycheng517 @quantitative-technologies
|
||||
|
|
|
|||
|
|
@ -154,15 +154,18 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device:
|
|||
return device
|
||||
|
||||
|
||||
def get_latest_run_id(log_path: Optional[str] = None, log_name: str = "") -> int:
|
||||
def get_latest_run_id(log_path: str = "", log_name: str = "") -> int:
|
||||
"""
|
||||
Returns the latest run number for the given log name and log path,
|
||||
by finding the greatest number in the directories.
|
||||
|
||||
:param log_path: Path to the log folder containing several runs.
|
||||
:param log_name: Name of the experiment. Each run is stored
|
||||
in a folder named ``log_name_1``, ``log_name_2``, ...
|
||||
:return: latest run number
|
||||
"""
|
||||
max_run_id = 0
|
||||
for path in glob.glob(f"{log_path}/{log_name}_[0-9]*"):
|
||||
for path in glob.glob(os.path.join(log_path, f"{glob.escape(log_name)}_[0-9]*")):
|
||||
file_name = path.split(os.sep)[-1]
|
||||
ext = file_name.split("_")[-1]
|
||||
if log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id:
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.5.1a3
|
||||
1.5.1a4
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import os
|
|||
import pytest
|
||||
|
||||
from stable_baselines3 import A2C, PPO, SAC, TD3
|
||||
from stable_baselines3.common.utils import get_latest_run_id
|
||||
|
||||
MODEL_DICT = {
|
||||
"a2c": (A2C, "CartPole-v1"),
|
||||
|
|
@ -35,3 +36,13 @@ def test_tensorboard(tmp_path, model_name):
|
|||
assert os.path.isdir(tmp_path / str(logname + "_1"))
|
||||
# Check that the log dir name increments correctly
|
||||
assert os.path.isdir(tmp_path / str(logname + "_2"))
|
||||
|
||||
|
||||
def test_escape_log_name(tmp_path):
|
||||
# Log name that must be escaped
|
||||
log_name = "filename[16, 16]"
|
||||
# Create folder
|
||||
os.makedirs(str(tmp_path) + f"/{log_name}_1", exist_ok=True)
|
||||
os.makedirs(str(tmp_path) + f"/{log_name}_2", exist_ok=True)
|
||||
last_run_id = get_latest_run_id(tmp_path, log_name)
|
||||
assert last_run_id == 2
|
||||
|
|
|
|||
Loading…
Reference in a new issue