Adjust FPS calculation to accommodate for reset_num_timesteps=False (#636)

* Store number of timesteps at the beginning of each learn cycle

* Update changelog

* Set default _num_timesteps_at_start in the contructor

* Test case for FPS logger

* Adjust test to cover both on-policy and off-policy algorithms

* Fix formatting

* Update test and add comment

* Fix test

Co-authored-by: Oleksii Kachaiev <okachaiev@riotgames.com>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Oleksii Kachaiev 2021-10-31 10:19:03 -07:00 committed by GitHub
parent a2e3001598
commit 0c17fedfac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 67 additions and 3 deletions

View file

@ -16,6 +16,7 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- FPS calculation is now performed based on number of steps performed during last ``learn`` call, even when ``reset_num_timesteps`` is set to ``False`` (@kachayev)
Deprecations:
^^^^^^^^^^^^^

View file

@ -122,6 +122,8 @@ class BaseAlgorithm(ABC):
self.num_timesteps = 0
# Used for updating schedules
self._total_timesteps = 0
# Used for computing fps, it is updated at each call of learn()
self._num_timesteps_at_start = 0
self.eval_env = None
self.seed = seed
self.action_noise = None # type: Optional[ActionNoise]
@ -420,6 +422,7 @@ class BaseAlgorithm(ABC):
# Make sure training timesteps are ahead of the internal counter
total_timesteps += self.num_timesteps
self._total_timesteps = total_timesteps
self._num_timesteps_at_start = self.num_timesteps
# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:

View file

@ -430,7 +430,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
Write log.
"""
time_elapsed = time.time() - self.start_time
fps = int(self.num_timesteps / (time_elapsed + 1e-8))
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8))
self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))

View file

@ -244,7 +244,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
# Display training infos
if log_interval is not None and iteration % log_interval == 0:
fps = int(self.num_timesteps / (time.time() - self.start_time))
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
self.logger.record("time/iterations", iteration, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))

View file

@ -1,13 +1,15 @@
import os
import time
from typing import Sequence
import gym
import numpy as np
import pytest
import torch as th
from matplotlib import pyplot as plt
from pandas.errors import EmptyDataError
from stable_baselines3 import A2C
from stable_baselines3 import A2C, DQN
from stable_baselines3.common.logger import (
DEBUG,
INFO,
@ -16,6 +18,7 @@ from stable_baselines3.common.logger import (
FormatUnsupportedError,
HumanOutputFormat,
Image,
Logger,
TensorBoardOutputFormat,
Video,
configure,
@ -290,3 +293,60 @@ def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_
writer.write({"figure": figure}, key_excluded={"figure": ()})
assert unsupported_format in str(exec_info.value)
writer.close()
class TimeDelayEnv(gym.Env):
"""
Gym env for testing FPS logging.
"""
def __init__(self, delay: float = 0.01):
super().__init__()
self.delay = delay
self.observation_space = gym.spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
return self.observation_space.sample()
def step(self, action):
time.sleep(self.delay)
obs = self.observation_space.sample()
return obs, 0.0, True, {}
class InMemoryLogger(Logger):
"""
Logger that keeps key/value pairs in memory without any writers.
"""
def __init__(self):
super().__init__("", [])
def dump(self, step: int = 0) -> None:
pass
@pytest.mark.parametrize("algo", [A2C, DQN])
def test_fps_logger(tmp_path, algo):
logger = InMemoryLogger()
max_fps = 1000
env = TimeDelayEnv(1 / max_fps)
model = algo("MlpPolicy", env, verbose=1)
model.set_logger(logger)
# fps should be at most max_fps
model.learn(100, log_interval=1)
assert max_fps / 4 <= logger.name_to_value["time/fps"] <= max_fps
# second time, FPS should be the same
model.learn(100, log_interval=1)
assert max_fps / 4 <= logger.name_to_value["time/fps"] <= max_fps
# Artificially increase num_timesteps to check
# that fps computation is reset at each call to learn()
model.num_timesteps = 20_000
# third time, FPS should be the same
model.learn(100, log_interval=1, reset_num_timesteps=False)
assert max_fps / 4 <= logger.name_to_value["time/fps"] <= max_fps