mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-02 03:55:39 +00:00
Adjust FPS calculation to accommodate for reset_num_timesteps=False (#636)
* Store number of timesteps at the beginning of each learn cycle * Update changelog * Set default _num_timesteps_at_start in the constructor * Test case for FPS logger * Adjust test to cover both on-policy and off-policy algorithms * Fix formatting * Update test and add comment * Fix test Co-authored-by: Oleksii Kachaiev <okachaiev@riotgames.com> Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
a2e3001598
commit
0c17fedfac
5 changed files with 67 additions and 3 deletions
|
|
@ -16,6 +16,7 @@ New Features:
|
|||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- FPS calculation is now based on the number of steps performed during the last ``learn`` call, even when ``reset_num_timesteps`` is set to ``False`` (@kachayev)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -122,6 +122,8 @@ class BaseAlgorithm(ABC):
|
|||
self.num_timesteps = 0
|
||||
# Used for updating schedules
|
||||
self._total_timesteps = 0
|
||||
# Used for computing fps, it is updated at each call of learn()
|
||||
self._num_timesteps_at_start = 0
|
||||
self.eval_env = None
|
||||
self.seed = seed
|
||||
self.action_noise = None # type: Optional[ActionNoise]
|
||||
|
|
@ -420,6 +422,7 @@ class BaseAlgorithm(ABC):
|
|||
# Make sure training timesteps are ahead of the internal counter
|
||||
total_timesteps += self.num_timesteps
|
||||
self._total_timesteps = total_timesteps
|
||||
self._num_timesteps_at_start = self.num_timesteps
|
||||
|
||||
# Avoid resetting the environment when calling ``.learn()`` consecutive times
|
||||
if reset_num_timesteps or self._last_obs is None:
|
||||
|
|
|
|||
|
|
@ -430,7 +430,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
Write log.
|
||||
"""
|
||||
time_elapsed = time.time() - self.start_time
|
||||
fps = int(self.num_timesteps / (time_elapsed + 1e-8))
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8))
|
||||
self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||
|
|
|
|||
|
|
@ -244,7 +244,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
|
||||
# Display training infos
|
||||
if log_interval is not None and iteration % log_interval == 0:
|
||||
fps = int(self.num_timesteps / (time.time() - self.start_time))
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
|
||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||
|
|
|
|||
|
|
@ -1,13 +1,15 @@
|
|||
import os
|
||||
import time
|
||||
from typing import Sequence
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch as th
|
||||
from matplotlib import pyplot as plt
|
||||
from pandas.errors import EmptyDataError
|
||||
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3 import A2C, DQN
|
||||
from stable_baselines3.common.logger import (
|
||||
DEBUG,
|
||||
INFO,
|
||||
|
|
@ -16,6 +18,7 @@ from stable_baselines3.common.logger import (
|
|||
FormatUnsupportedError,
|
||||
HumanOutputFormat,
|
||||
Image,
|
||||
Logger,
|
||||
TensorBoardOutputFormat,
|
||||
Video,
|
||||
configure,
|
||||
|
|
@ -290,3 +293,60 @@ def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_
|
|||
writer.write({"figure": figure}, key_excluded={"figure": ()})
|
||||
assert unsupported_format in str(exec_info.value)
|
||||
writer.close()
|
||||
|
||||
|
||||
class TimeDelayEnv(gym.Env):
    """
    Gym environment used to test FPS logging.

    Every call to ``step`` sleeps for ``delay`` seconds, returns a randomly
    sampled observation with zero reward, and flags the episode as done.

    :param delay: Time in seconds to sleep on each call to ``step``.
    """

    def __init__(self, delay: float = 0.01):
        super().__init__()
        self.delay = delay
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32)

    def reset(self):
        """Return a random sample from the observation space."""
        initial_obs = self.observation_space.sample()
        return initial_obs

    def step(self, action):
        """Sleep for ``self.delay`` seconds, then return a one-step episode transition."""
        # The sleep happens before sampling, so each step lasts at least ``delay`` seconds
        time.sleep(self.delay)
        reward, done, info = 0.0, True, {}
        return self.observation_space.sample(), reward, done, info
|
||||
|
||||
|
||||
class InMemoryLogger(Logger):
    """
    Logger built with an empty folder name and an empty list of writers,
    so that recorded key/value pairs are only kept in memory.
    """

    def __init__(self):
        folder, writers = "", []
        super().__init__(folder, writers)

    def dump(self, step: int = 0) -> None:
        """Intentionally a no-op: nothing is written out."""
        return None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("algo", [A2C, DQN])
def test_fps_logger(tmp_path, algo):
    """
    Check the logged ``time/fps`` value over consecutive ``learn()`` calls,
    including one made with ``reset_num_timesteps=False``.
    """
    max_fps = 1000
    logger = InMemoryLogger()
    # Each env step sleeps for 1 / max_fps seconds
    slow_env = TimeDelayEnv(1 / max_fps)
    model = algo("MlpPolicy", slow_env, verbose=1)
    model.set_logger(logger)

    def check_fps() -> None:
        # fps should be at most max_fps
        logged_fps = logger.name_to_value["time/fps"]
        assert max_fps / 4 <= logged_fps <= max_fps

    # First call to learn()
    model.learn(100, log_interval=1)
    check_fps()

    # Second call: the FPS should stay within the same bounds
    model.learn(100, log_interval=1)
    check_fps()

    # Artificially inflate num_timesteps to check that the
    # fps computation is reset at each call to learn()
    model.num_timesteps = 20_000

    # Third call, continuing the timestep counter: same bounds expected
    model.learn(100, log_interval=1, reset_num_timesteps=False)
    check_fps()
|
||||
|
|
|
|||
Loading…
Reference in a new issue