Mirror of https://github.com/saymrwulf/stable-baselines3.git (synced 2026-05-16 21:10:08 +00:00)
* Add support for custom objects
* Add python 3.8 to the CI
* Bump version
* PyType fixes
* [ci skip] Fix typo
* Add note about slow-down + fix typos
* Minor edits to the doc
* Bug fix for DQN
* Update test
* Add test for custom objects
114 lines · 5 KiB · Python
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import gym
import numpy as np

from stable_baselines3.common import base_class
from stable_baselines3.common.vec_env import VecEnv
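

# ``evaluate_policy`` rolls out a trained agent for a fixed number of episodes
# and reports reward statistics; it is the only public helper in this module.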
def evaluate_policy(
    model: "base_class.BaseAlgorithm",
    env: Union[gym.Env, VecEnv],
    n_eval_episodes: int = 10,
    deterministic: bool = True,
    render: bool = False,
    callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
    reward_threshold: Optional[float] = None,
    return_episode_rewards: bool = False,
    warn: bool = True,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
"""
|
|
Runs policy for ``n_eval_episodes`` episodes and returns average reward.
|
|
This is made to work only with one env.
|
|
|
|
.. note::
|
|
If environment has not been wrapped with ``Monitor`` wrapper, reward and
|
|
episode lengths are counted as it appears with ``env.step`` calls. If
|
|
the environment contains wrappers that modify rewards or episode lengths
|
|
(e.g. reward scaling, early episode reset), these will affect the evaluation
|
|
results as well. You can avoid this by wrapping environment with ``Monitor``
|
|
wrapper before anything else.
|
|
|
|
:param model: The RL agent you want to evaluate.
|
|
:param env: The gym environment. In the case of a ``VecEnv``
|
|
this must contain only one environment.
|
|
:param n_eval_episodes: Number of episode to evaluate the agent
|
|
:param deterministic: Whether to use deterministic or stochastic actions
|
|
:param render: Whether to render the environment or not
|
|
:param callback: callback function to do additional checks,
|
|
called after each step. Gets locals() and globals() passed as parameters.
|
|
:param reward_threshold: Minimum expected reward per episode,
|
|
this will raise an error if the performance is not met
|
|
:param return_episode_rewards: If True, a list of rewards and episode lengths
|
|
per episode will be returned instead of the mean.
|
|
:param warn: If True (default), warns user about lack of a Monitor wrapper in the
|
|
evaluation environment.
|
|
:return: Mean reward per episode, std of reward per episode.
|
|
Returns ([float], [int]) when ``return_episode_rewards`` is True, first
|
|
list containing per-episode rewards and second containing per-episode lengths
|
|
(in number of steps).
|
|
"""
|
|
    is_monitor_wrapped = False
    # Avoid circular import
    from stable_baselines3.common.env_util import is_wrapped
    from stable_baselines3.common.monitor import Monitor

    if isinstance(env, VecEnv):
        assert env.num_envs == 1, "You must pass only one environment when using this function"
        is_monitor_wrapped = env.env_is_wrapped(Monitor)[0]
    else:
        is_monitor_wrapped = is_wrapped(env, Monitor)

    if not is_monitor_wrapped and warn:
        warnings.warn(
            "Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
            "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
            "Consider wrapping environment first with ``Monitor`` wrapper.",
            UserWarning,
        )
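
    # Evaluation loop: keep stepping the environment until ``n_eval_episodes``
    # complete episodes have been recorded.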
    episode_rewards, episode_lengths = [], []
    not_reset = True
    while len(episode_rewards) < n_eval_episodes:
        # Number of loops here might differ from true episodes
        # played, if underlying wrappers modify episode lengths.
        # Avoid double reset, as VecEnvs are reset automatically.
        if not isinstance(env, VecEnv) or not_reset:
            obs = env.reset()
            not_reset = False
        done, state = False, None
        episode_reward = 0.0
        episode_length = 0
        while not done:
            action, state = model.predict(obs, state=state, deterministic=deterministic)
            obs, reward, done, info = env.step(action)
            episode_reward += reward
            if callback is not None:
                callback(locals(), globals())
            episode_length += 1
            if render:
                env.render()

        if is_monitor_wrapped:
            # Do not trust "done" with episode endings.
            # Remove vecenv stacking (if any)
            if isinstance(env, VecEnv):
                info = info[0]
            if "episode" in info.keys():
                # Monitor wrapper includes "episode" key in info if environment
                # has been wrapped with it. Use those rewards instead.
                episode_rewards.append(info["episode"]["r"])
                episode_lengths.append(info["episode"]["l"])
        else:
            episode_rewards.append(episode_reward)
            episode_lengths.append(episode_length)
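
    # Aggregate the collected episodes into summary statistics.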
    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)
    if reward_threshold is not None:
        assert mean_reward > reward_threshold, f"Mean reward below threshold: {mean_reward:.2f} < {reward_threshold:.2f}"
    if return_episode_rewards:
        return episode_rewards, episode_lengths
    return mean_reward, std_reward
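

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, assuming ``stable_baselines3`` is
    # installed and the gym env "CartPole-v1" is available): train a small PPO
    # agent, then evaluate it with ``evaluate_policy``.
    from stable_baselines3 import PPO

    demo_model = PPO("MlpPolicy", "CartPole-v1", verbose=0)
    demo_model.learn(total_timesteps=2_000)

    # ``model.get_env()`` returns the single training VecEnv, which satisfies
    # the one-environment requirement documented above.
    mean_r, std_r = evaluate_policy(demo_model, demo_model.get_env(), n_eval_episodes=5)
    print(f"mean_reward={mean_r:.2f} +/- {std_r:.2f}")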