Escape tensorboard log name (#857)

* escape tensorboard log name

Otherwise utils does not recognize the log.

* Added fix to changelog

* Modifications made by: make commit-checks .

* Revert "Modifications made by: make commit-checks ."

This reverts commit 529a275d9475f85ef031038a8f3565f7301e5371.

* Update changelog and add test

Co-authored-by: James Hirschorn <James.Hirschorn@quantitative-technologies.com>
This commit is contained in:
Antonin RAFFIN 2022-04-11 21:49:18 +02:00 committed by GitHub
parent 248f082cdc
commit 39a4f9379a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 20 additions and 5 deletions

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.5.1a3 (WIP)
Release 1.5.1a4 (WIP)
---------------------------
Breaking Changes:
@ -23,6 +23,7 @@ Bug Fixes:
^^^^^^^^^^
- Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517)
- Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec)
- Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies)
Deprecations:
^^^^^^^^^^^^^
@ -962,4 +963,4 @@ And all the contributors:
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
@Gregwar @ycheng517
@Gregwar @ycheng517 @quantitative-technologies

View file

@ -154,15 +154,18 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device:
return device
def get_latest_run_id(log_path: Optional[str] = None, log_name: str = "") -> int:
def get_latest_run_id(log_path: str = "", log_name: str = "") -> int:
"""
Returns the latest run number for the given log name and log path,
by finding the greatest number in the directories.
:param log_path: Path to the log folder containing several runs.
:param log_name: Name of the experiment. Each run is stored
in a folder named ``log_name_1``, ``log_name_2``, ...
:return: latest run number
"""
max_run_id = 0
for path in glob.glob(f"{log_path}/{log_name}_[0-9]*"):
for path in glob.glob(os.path.join(log_path, f"{glob.escape(log_name)}_[0-9]*")):
file_name = path.split(os.sep)[-1]
ext = file_name.split("_")[-1]
if log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id:

View file

@ -1 +1 @@
1.5.1a3
1.5.1a4

View file

@ -3,6 +3,7 @@ import os
import pytest
from stable_baselines3 import A2C, PPO, SAC, TD3
from stable_baselines3.common.utils import get_latest_run_id
MODEL_DICT = {
"a2c": (A2C, "CartPole-v1"),
@ -35,3 +36,13 @@ def test_tensorboard(tmp_path, model_name):
assert os.path.isdir(tmp_path / str(logname + "_1"))
# Check that the log dir name increments correctly
assert os.path.isdir(tmp_path / str(logname + "_2"))
def test_escape_log_name(tmp_path):
# Log name that must be escaped
log_name = "filename[16, 16]"
# Create folder
os.makedirs(str(tmp_path) + f"/{log_name}_1", exist_ok=True)
os.makedirs(str(tmp_path) + f"/{log_name}_2", exist_ok=True)
last_run_id = get_latest_run_id(tmp_path, log_name)
assert last_run_id == 2