diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 047f95c..b36c358 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -16,6 +16,8 @@ New Features:
 
 Bug Fixes:
 ^^^^^^^^^^
+- FPS is now computed from the number of steps performed during the last ``learn`` call, even when ``reset_num_timesteps`` is set to ``False`` (@kachayev)
+- Added a small epsilon to the elapsed time in the on-policy FPS computation to avoid a division by zero, as already done for off-policy algorithms
 
 Deprecations:
 ^^^^^^^^^^^^^
diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py
index 8872f41..21c2748 100644
--- a/stable_baselines3/common/base_class.py
+++ b/stable_baselines3/common/base_class.py
@@ -122,6 +122,8 @@ class BaseAlgorithm(ABC):
         self.num_timesteps = 0
         # Used for updating schedules
         self._total_timesteps = 0
+        # Value of num_timesteps at the start of the current learn() call, used for computing fps
+        self._num_timesteps_at_start = 0
         self.eval_env = None
         self.seed = seed
         self.action_noise = None  # type: Optional[ActionNoise]
@@ -420,6 +422,7 @@ class BaseAlgorithm(ABC):
             # Make sure training timesteps are ahead of the internal counter
             total_timesteps += self.num_timesteps
         self._total_timesteps = total_timesteps
+        self._num_timesteps_at_start = self.num_timesteps
 
         # Avoid resetting the environment when calling ``.learn()`` consecutive times
         if reset_num_timesteps or self._last_obs is None:
diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py
index 999365a..8b3c667 100644
--- a/stable_baselines3/common/off_policy_algorithm.py
+++ b/stable_baselines3/common/off_policy_algorithm.py
@@ -430,7 +430,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
         Write log.
         """
         time_elapsed = time.time() - self.start_time
-        fps = int(self.num_timesteps / (time_elapsed + 1e-8))
+        fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8))
         self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
         if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
             self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index 41e193d..8e783dd 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -244,7 +244,8 @@ class OnPolicyAlgorithm(BaseAlgorithm):
 
             # Display training infos
             if log_interval is not None and iteration % log_interval == 0:
-                fps = int(self.num_timesteps / (time.time() - self.start_time))
+                # Small epsilon avoids a ZeroDivisionError when the clock has not advanced since start_time
+                fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time + 1e-8))
                 self.logger.record("time/iterations", iteration, exclude="tensorboard")
                 if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                     self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
diff --git a/tests/test_logger.py b/tests/test_logger.py
index e516171..57463b0 100644
--- a/tests/test_logger.py
+++ b/tests/test_logger.py
@@ -1,13 +1,15 @@
 import os
+import time
 from typing import Sequence
 
+import gym
 import numpy as np
 import pytest
 import torch as th
 from matplotlib import pyplot as plt
 from pandas.errors import EmptyDataError
 
-from stable_baselines3 import A2C
+from stable_baselines3 import A2C, DQN
 from stable_baselines3.common.logger import (
     DEBUG,
     INFO,
@@ -16,6 +18,7 @@ from stable_baselines3.common.logger import (
     FormatUnsupportedError,
     HumanOutputFormat,
     Image,
+    Logger,
     TensorBoardOutputFormat,
     Video,
     configure,
@@ -290,3 +293,61 @@ def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_
         writer.write({"figure": figure}, key_excluded={"figure": ()})
     assert unsupported_format in str(exec_info.value)
     writer.close()
+
+
+class TimeDelayEnv(gym.Env):
+    """
+    Gym env that sleeps for ``delay`` seconds on every step, used for testing FPS logging.
+    """
+
+    def __init__(self, delay: float = 0.01):
+        super().__init__()
+        self.delay = delay
+        self.observation_space = gym.spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32)
+        self.action_space = gym.spaces.Discrete(2)
+
+    def reset(self):
+        return self.observation_space.sample()
+
+    def step(self, action):
+        time.sleep(self.delay)
+        obs = self.observation_space.sample()
+        return obs, 0.0, True, {}
+
+
+class InMemoryLogger(Logger):
+    """
+    Logger that keeps key/value pairs in memory without any writers.
+    """
+
+    def __init__(self):
+        super().__init__("", [])
+
+    def dump(self, step: int = 0) -> None:
+        pass
+
+
+@pytest.mark.parametrize("algo", [A2C, DQN])
+def test_fps_logger(tmp_path, algo):
+    """Check that ``time/fps`` stays within the expected range across successive ``learn()`` calls."""
+    logger = InMemoryLogger()
+    max_fps = 1000
+    env = TimeDelayEnv(1 / max_fps)
+    model = algo("MlpPolicy", env, verbose=1)
+    model.set_logger(logger)
+
+    # each step sleeps for 1 / max_fps seconds, so fps should be at most max_fps
+    model.learn(100, log_interval=1)
+    assert max_fps / 4 <= logger.name_to_value["time/fps"] <= max_fps
+
+    # second time, FPS should be the same
+    model.learn(100, log_interval=1)
+    assert max_fps / 4 <= logger.name_to_value["time/fps"] <= max_fps
+
+    # Artificially increase num_timesteps to check
+    # that fps computation is reset at each call to learn()
+    model.num_timesteps = 20_000
+
+    # third time, FPS should be the same
+    model.learn(100, log_interval=1, reset_num_timesteps=False)
+    assert max_fps / 4 <= logger.name_to_value["time/fps"] <= max_fps