Log success rate for on policy algorithms (#1870)

* Add success rate in monitor for on policy algorithms

* Update changelog

* make commit-checks refactoring

* Assert buffers are not None in _dump_logs

* Automatic refactoring of the type hinting

* Add success_rate logging test for on policy algorithms

* Update changelog

* Reformat

* Fix tests and update changelog

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
Corentin authored on 2024-03-22 12:13:48 +01:00 (committed by GitHub)
parent 8b3723c6d8
commit 071226d3e8
4 changed files with 120 additions and 14 deletions
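
For context, a minimal sketch of how the new logging is used. It is not taken from this commit; ``CoinFlipEnv`` is a toy environment made up for illustration. Any environment that writes ``is_success`` into the ``info`` dict at the end of an episode now gets ``rollout/success_rate`` in the training logs of on-policy algorithms:

import gymnasium as gym
from gymnasium import spaces

from stable_baselines3 import PPO


class CoinFlipEnv(gym.Env):
    """Toy env that reports info["is_success"] when an episode ends."""

    def __init__(self, ep_steps=8):
        self.ep_steps = ep_steps
        self.n_steps = 0
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Discrete(2)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.n_steps = 0
        return 0, {}

    def step(self, action):
        self.n_steps += 1
        truncated = self.n_steps >= self.ep_steps
        # Report success only at episode end, which is when it is recorded
        info = {"is_success": bool(action == 1)} if truncated else {}
        return 0, 0.0, False, truncated, info


model = PPO("MlpPolicy", CoinFlipEnv(), n_steps=64, batch_size=64, verbose=1)
model.learn(total_timesteps=128)  # console logs now include rollout/success_rate

With the default ``monitor_wrapper=True`` (now forwarded correctly by this commit), the environment is wrapped in ``Monitor`` automatically, so ``rollout/ep_rew_mean`` and ``rollout/ep_len_mean`` are logged as well.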

docs/misc/changelog.rst

@@ -3,7 +3,7 @@
 Changelog
 ==========

-Release 2.3.0a4 (WIP)
+Release 2.3.0a5 (WIP)
 --------------------------

 Breaking Changes:
@@ -41,9 +41,11 @@ Breaking Changes:
 New Features:
 ^^^^^^^^^^^^^
+- Log success rate ``rollout/success_rate`` when available for on policy algorithms (@corentinlger)

 Bug Fixes:
 ^^^^^^^^^^
+- Fixed ``monitor_wrapper`` argument that was not passed to the parent class, and dones argument that wasn't passed to ``_update_info_buffer`` (@corentinlger)

 `SB3-Contrib`_
 ^^^^^^^^^^^^^^

stable_baselines3/common/on_policy_algorithm.py

@@ -92,6 +92,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
             support_multi_env=True,
+            monitor_wrapper=monitor_wrapper,
             seed=seed,
             stats_window_size=stats_window_size,
             tensorboard_log=tensorboard_log,
@@ -200,7 +201,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
             if not callback.on_step():
                 return False

-            self._update_info_buffer(infos)
+            self._update_info_buffer(infos, dones)
             n_steps += 1

             if isinstance(self.action_space, spaces.Discrete):
@@ -250,6 +251,28 @@ class OnPolicyAlgorithm(BaseAlgorithm):
         """
         raise NotImplementedError

+    def _dump_logs(self, iteration: int) -> None:
+        """
+        Write log.
+
+        :param iteration: Current logging iteration
+        """
+        assert self.ep_info_buffer is not None
+        assert self.ep_success_buffer is not None
+        time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
+        fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
+        self.logger.record("time/iterations", iteration, exclude="tensorboard")
+        if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
+            self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
+            self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
+        self.logger.record("time/fps", fps)
+        self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
+        self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
+        if len(self.ep_success_buffer) > 0:
+            self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer))
+        self.logger.dump(step=self.num_timesteps)
+
     def learn(
         self: SelfOnPolicyAlgorithm,
         total_timesteps: int,
@@ -285,16 +308,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
             # Display training infos
             if log_interval is not None and iteration % log_interval == 0:
                 assert self.ep_info_buffer is not None
-                time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
-                fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
-                self.logger.record("time/iterations", iteration, exclude="tensorboard")
-                if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
-                    self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
-                    self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
-                self.logger.record("time/fps", fps)
-                self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
-                self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
-                self.logger.dump(step=self.num_timesteps)
+                self._dump_logs(iteration)

             self.train()
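
The success values themselves are collected in ``BaseAlgorithm``, not in this file. A simplified sketch, paraphrased rather than taken from this diff, of what ``_update_info_buffer`` does with the newly passed ``dones``:

import numpy as np

# Simplified sketch of BaseAlgorithm._update_info_buffer (paraphrased)
def _update_info_buffer(self, infos, dones=None):
    if dones is None:
        dones = np.array([False] * len(infos))
    for idx, info in enumerate(infos):
        maybe_ep_info = info.get("episode")  # added by the Monitor wrapper
        maybe_is_success = info.get("is_success")
        if maybe_ep_info is not None:
            self.ep_info_buffer.extend([maybe_ep_info])
        # Successes are only recorded at episode boundaries
        if maybe_is_success is not None and dones[idx]:
            self.ep_success_buffer.append(maybe_is_success)

With the old call that omitted ``dones``, the success condition could never fire, which is why ``rollout/success_rate`` was missing for on-policy algorithms.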

stable_baselines3/version.txt

@@ -1 +1 @@
-2.3.0a4
+2.3.0a5

tests/test_logger.py

@@ -14,7 +14,7 @@ from gymnasium import spaces
 from matplotlib import pyplot as plt
 from pandas.errors import EmptyDataError

-from stable_baselines3 import A2C, DQN
+from stable_baselines3 import A2C, DQN, PPO
 from stable_baselines3.common.env_checker import check_env
 from stable_baselines3.common.logger import (
     DEBUG,
@@ -33,6 +33,7 @@ from stable_baselines3.common.logger import (
     read_csv,
     read_json,
 )
+from stable_baselines3.common.monitor import Monitor

 KEY_VALUES = {
     "test": 1,
@@ -474,3 +475,92 @@ def test_human_output_format_custom_test_io(base_class):
     """
     assert printed == desired_printed
+
+
+class DummySuccessEnv(gym.Env):
+    """
+    Dummy environment that returns whether an episode was successful
+    (``info["is_success"]`` set to True or False) at the end of each episode,
+    according to its list of dummy successes.
+    """
+
+    def __init__(self, dummy_successes, ep_steps):
+        """
+        Init the dummy success env.
+
+        :param dummy_successes: list of size (n_logs_iterations, n_episodes_per_log) that specifies
+            the success value of log iteration i at episode j
+        :param ep_steps: number of steps per episode (after which the episode is truncated)
+        """
+        self.n_steps = 0
+        self.log_id = 0
+        self.ep_id = 0
+
+        self.ep_steps = ep_steps
+
+        self.dummy_success = dummy_successes
+        self.num_logs = len(dummy_successes)
+        self.ep_per_log = len(dummy_successes[0])
+        self.steps_per_log = self.ep_per_log * self.ep_steps
+
+        self.action_space = spaces.Discrete(2)
+        self.observation_space = spaces.Discrete(2)
+
+    def reset(self, seed=None, options=None):
+        """
+        Reset the env and advance to the next episode id to fetch the next dummy success.
+        """
+        self.n_steps = 0
+        if self.ep_id == self.ep_per_log:
+            self.ep_id = 0
+            self.log_id = (self.log_id + 1) % self.num_logs
+        return self.observation_space.sample(), {}
+
+    def step(self, action):
+        """
+        Step and return a dummy success when the episode is truncated.
+        """
+        self.n_steps += 1
+        truncated = self.n_steps >= self.ep_steps
+        info = {}
+        if truncated:
+            maybe_success = self.dummy_success[self.log_id][self.ep_id]
+            info["is_success"] = maybe_success
+            self.ep_id += 1
+        return self.observation_space.sample(), 0.0, False, truncated, info
+
+
+def test_rollout_success_rate_on_policy_algorithm(tmp_path):
+    """
+    Test that the rollout/success_rate information is correctly logged with on policy algorithms.
+
+    To do so, create a dummy environment that takes a list of dummy successes as argument
+    (i.e., whether each episode is going to be successful or not).
+    """
+    STATS_WINDOW_SIZE = 10
+
+    # Dummy successes with 0.3, 0.5 and 0.8 success rates, each window of length STATS_WINDOW_SIZE
+    dummy_successes = [
+        [True] * 3 + [False] * 7,
+        [True] * 5 + [False] * 5,
+        [True] * 8 + [False] * 2,
+    ]
+    ep_steps = 64
+
+    # Monitor the env to track the success info
+    monitor_file = str(tmp_path / "monitor.csv")
+    env = Monitor(DummySuccessEnv(dummy_successes, ep_steps), filename=monitor_file, info_keywords=("is_success",))
+
+    # Equip the model with a custom logger to check the success_rate info
+    model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.steps_per_log, verbose=1)
+    logger = InMemoryLogger()
+    model.set_logger(logger)
+
+    # Make the model learn and check that the success rate matches the ratio of dummy successes
+    model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
+    assert logger.name_to_value["rollout/success_rate"] == 0.3
+
+    model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
+    assert logger.name_to_value["rollout/success_rate"] == 0.5
+
+    model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
+    assert logger.name_to_value["rollout/success_rate"] == 0.8
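
A note on the expected values: ``stats_window_size`` and ``ep_per_log`` are both 10, and each ``learn()`` call starts with a fresh success buffer (``reset_num_timesteps`` defaults to True) and collects exactly one rollout of ``n_steps = steps_per_log`` timesteps, i.e. one row of ``dummy_successes``. The logged success rate is therefore the fraction of True entries in that row: 3/10 = 0.3, 5/10 = 0.5, and 8/10 = 0.8.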