mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Log success rate for on policy algorithms (#1870)
* Add success rate in monitor for on policy algorithms * Update changelog * make commit-checks refactoring * Assert buffers are not none in _dump_logs * Automatic refactoring of the type hinting * Add success_rate logging test for on policy algorithms * Update changelog * Reformat * Fix tests and update changelog --------- Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
8b3723c6d8
commit
071226d3e8
4 changed files with 120 additions and 14 deletions
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.3.0a4 (WIP)
|
||||
Release 2.3.0a5 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -41,9 +41,11 @@ Breaking Changes:
|
|||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Log success rate ``rollout/success_rate`` when available for on policy algorithms (@corentinlger)
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed ``monitor_wrapper`` argument that was not passed to the parent class, and dones argument that wasn't passed to ``_update_into_buffer`` (@corentinlger)
|
||||
|
||||
`SB3-Contrib`_
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -92,6 +92,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
use_sde=use_sde,
|
||||
sde_sample_freq=sde_sample_freq,
|
||||
support_multi_env=True,
|
||||
monitor_wrapper=monitor_wrapper,
|
||||
seed=seed,
|
||||
stats_window_size=stats_window_size,
|
||||
tensorboard_log=tensorboard_log,
|
||||
|
|
@ -200,7 +201,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
if not callback.on_step():
|
||||
return False
|
||||
|
||||
self._update_info_buffer(infos)
|
||||
self._update_info_buffer(infos, dones)
|
||||
n_steps += 1
|
||||
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
|
|
@ -250,6 +251,28 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _dump_logs(self, iteration: int) -> None:
|
||||
"""
|
||||
Write log.
|
||||
|
||||
:param iteration: Current logging iteration
|
||||
"""
|
||||
assert self.ep_info_buffer is not None
|
||||
assert self.ep_success_buffer is not None
|
||||
|
||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("time/fps", fps)
|
||||
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
|
||||
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
if len(self.ep_success_buffer) > 0:
|
||||
self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer))
|
||||
self.logger.dump(step=self.num_timesteps)
|
||||
|
||||
def learn(
|
||||
self: SelfOnPolicyAlgorithm,
|
||||
total_timesteps: int,
|
||||
|
|
@ -285,16 +308,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
# Display training infos
|
||||
if log_interval is not None and iteration % log_interval == 0:
|
||||
assert self.ep_info_buffer is not None
|
||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("time/fps", fps)
|
||||
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
|
||||
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
self.logger.dump(step=self.num_timesteps)
|
||||
self._dump_logs(iteration)
|
||||
|
||||
self.train()
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.3.0a4
|
||||
2.3.0a5
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from gymnasium import spaces
|
|||
from matplotlib import pyplot as plt
|
||||
from pandas.errors import EmptyDataError
|
||||
|
||||
from stable_baselines3 import A2C, DQN
|
||||
from stable_baselines3 import A2C, DQN, PPO
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
from stable_baselines3.common.logger import (
|
||||
DEBUG,
|
||||
|
|
@ -33,6 +33,7 @@ from stable_baselines3.common.logger import (
|
|||
read_csv,
|
||||
read_json,
|
||||
)
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
|
||||
KEY_VALUES = {
|
||||
"test": 1,
|
||||
|
|
@ -474,3 +475,92 @@ def test_human_output_format_custom_test_io(base_class):
|
|||
"""
|
||||
|
||||
assert printed == desired_printed
|
||||
|
||||
|
||||
class DummySuccessEnv(gym.Env):
|
||||
"""
|
||||
Create a dummy success environment that returns wether True or False for info['is_success']
|
||||
at the end of an episode according to its dummy successes list
|
||||
"""
|
||||
|
||||
def __init__(self, dummy_successes, ep_steps):
|
||||
"""Init the dummy success env
|
||||
|
||||
:param dummy_successes: list of size (n_logs_iterations, n_episodes_per_log) that specifies
|
||||
the success value of log iteration i at episode j
|
||||
:param ep_steps: number of steps per episode (to activate truncated)
|
||||
"""
|
||||
self.n_steps = 0
|
||||
self.log_id = 0
|
||||
self.ep_id = 0
|
||||
|
||||
self.ep_steps = ep_steps
|
||||
|
||||
self.dummy_success = dummy_successes
|
||||
self.num_logs = len(dummy_successes)
|
||||
self.ep_per_log = len(dummy_successes[0])
|
||||
self.steps_per_log = self.ep_per_log * self.ep_steps
|
||||
|
||||
self.action_space = spaces.Discrete(2)
|
||||
self.observation_space = spaces.Discrete(2)
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
"""
|
||||
Reset the env and advance to the next episode_id to get the next dummy success
|
||||
"""
|
||||
self.n_steps = 0
|
||||
|
||||
if self.ep_id == self.ep_per_log:
|
||||
self.ep_id = 0
|
||||
self.log_id = (self.log_id + 1) % self.num_logs
|
||||
|
||||
return self.observation_space.sample(), {}
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
Step and return a dummy success when an episode is truncated
|
||||
"""
|
||||
self.n_steps += 1
|
||||
truncated = self.n_steps >= self.ep_steps
|
||||
|
||||
info = {}
|
||||
if truncated:
|
||||
maybe_success = self.dummy_success[self.log_id][self.ep_id]
|
||||
info["is_success"] = maybe_success
|
||||
self.ep_id += 1
|
||||
return self.observation_space.sample(), 0.0, False, truncated, info
|
||||
|
||||
|
||||
def test_rollout_success_rate_on_policy_algorithm(tmp_path):
|
||||
"""
|
||||
Test if the rollout/success_rate information is correctly logged with on policy algorithms
|
||||
|
||||
To do so, create a dummy environment that takes as argument dummy successes (i.e when an episode)
|
||||
is going to be successfull or not.
|
||||
"""
|
||||
|
||||
STATS_WINDOW_SIZE = 10
|
||||
# Add dummy successes with 0.3, 0.5 and 0.8 success_rate of length STATS_WINDOW_SIZE
|
||||
dummy_successes = [
|
||||
[True] * 3 + [False] * 7,
|
||||
[True] * 5 + [False] * 5,
|
||||
[True] * 8 + [False] * 2,
|
||||
]
|
||||
ep_steps = 64
|
||||
|
||||
# Monitor the env to track the success info
|
||||
monitor_file = str(tmp_path / "monitor.csv")
|
||||
env = Monitor(DummySuccessEnv(dummy_successes, ep_steps), filename=monitor_file, info_keywords=("is_success",))
|
||||
|
||||
# Equip the model of a custom logger to check the success_rate info
|
||||
model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.steps_per_log, verbose=1)
|
||||
logger = InMemoryLogger()
|
||||
model.set_logger(logger)
|
||||
|
||||
# Make the model learn and check that the success rate corresponds to the ratio of dummy successes
|
||||
model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
|
||||
assert logger.name_to_value["rollout/success_rate"] == 0.3
|
||||
model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
|
||||
assert logger.name_to_value["rollout/success_rate"] == 0.5
|
||||
model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
|
||||
assert logger.name_to_value["rollout/success_rate"] == 0.8
|
||||
|
|
|
|||
Loading…
Reference in a new issue