From 4fb8aec215fd2dd5d668aae8285937c268baca97 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Thu, 3 Nov 2022 07:36:19 -0700 Subject: [PATCH] Update evaluate_policy type annotation to support policies as well as RL algorithms (#1146) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add PolicyPredictor protocol and use it in evaluate_policy * Update changelog * Move Protocol to type_aliases to avoid circular import * Add test for evaluate_policy on BasePolicy * Remove unused import * Use typing_extensions * Move typing_extensions to 3rd party * Add version range (typing_extensions uses SemVer) * Import Protocol from typing_extensions only on Python<3.8 Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Install typing_extensions only on Python<3.8 * Add missing sys import * Fix import ordering * Fix observation type hint in predict Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin GALLOUÉDEC --- docs/misc/changelog.rst | 1 + setup.py | 1 + stable_baselines3/common/base_class.py | 2 +- stable_baselines3/common/evaluation.py | 8 +++--- stable_baselines3/common/type_aliases.py | 31 +++++++++++++++++++++++- stable_baselines3/dqn/dqn.py | 4 +-- tests/test_utils.py | 16 ++++++------ 7 files changed, 49 insertions(+), 14 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 42ca89d..1941176 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -25,6 +25,7 @@ Bug Fixes: - Fix return type of ``evaluate_actions`` in ``ActorCritcPolicy`` to reflect that entropy is an optional tensor (@Rocamonde) - Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm`` - Allowed model trained with Python 3.7 to be loaded with Python 3.8+ without the ``custom_objects`` workaround +- Fix type annotation of ``model`` in ``evaluate_policy`` Deprecations: ^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index 58c04fa..bfcb56c 100644 --- a/setup.py +++ b/setup.py @@ -76,6 +76,7 @@ setup( "gym==0.21", # Fixed version due to breaking changes in 0.22 "numpy", "torch>=1.11", + 'typing_extensions>=4.0,<5; python_version < "3.8.0"', # For saving models "cloudpickle", # For reading logs diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index d993f9b..ebfb7de 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -513,7 +513,7 @@ class BaseAlgorithm(ABC): def predict( self, - observation: np.ndarray, + observation: Union[np.ndarray, Dict[str, np.ndarray]], state: Optional[Tuple[np.ndarray, ...]] = None, episode_start: Optional[np.ndarray] = None, deterministic: bool = False, diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py index e3f14d3..ff18137 100644 --- a/stable_baselines3/common/evaluation.py +++ b/stable_baselines3/common/evaluation.py @@ -4,12 +4,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import gym import numpy as np -from stable_baselines3.common import base_class +from stable_baselines3.common import type_aliases from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped def evaluate_policy( - model: "base_class.BaseAlgorithm", + model: "type_aliases.PolicyPredictor", env: Union[gym.Env, VecEnv], n_eval_episodes: int = 10, deterministic: bool = True, @@ -34,7 +34,9 @@ def evaluate_policy( results as well. You can avoid this by wrapping environment with ``Monitor`` wrapper before anything else. - :param model: The RL agent you want to evaluate. + :param model: The RL agent you want to evaluate. This can be any object + that implements a `predict` method, such as an RL algorithm (``BaseAlgorithm``) + or policy (``BasePolicy``). :param env: The gym environment or ``VecEnv`` environment. :param n_eval_episodes: Number of episode to evaluate the agent :param deterministic: Whether to use deterministic or stochastic actions diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index f4c29ab..4faad7d 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -1,12 +1,18 @@ """Common aliases for type hints""" +import sys from enum import Enum -from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union import gym import numpy as np import torch as th +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + from stable_baselines3.common import callbacks, vec_env GymEnv = Union[gym.Env, vec_env.VecEnv] @@ -69,3 +75,26 @@ class TrainFrequencyUnit(Enum): class TrainFreq(NamedTuple): frequency: int unit: TrainFrequencyUnit # either "step" or "episode" + + +class PolicyPredictor(Protocol): + def predict( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + state: Optional[Tuple[np.ndarray, ...]] = None, + episode_start: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + """ + Get the policy action from an observation (and optional hidden state). + Includes sugar-coating to handle different observations (e.g. normalizing images). + + :param observation: the input observation + :param state: The last hidden states (can be None, used in recurrent policies) + :param episode_start: The last masks (can be None, used in recurrent policies) + this correspond to beginning of episodes, + where the hidden states of the RNN must be reset. + :param deterministic: Whether or not to return deterministic actions. + :return: the model's action and the next hidden state + (used in recurrent policies) + """ diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 8c67838..9e074a9 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -224,7 +224,7 @@ class DQN(OffPolicyAlgorithm): def predict( self, - observation: np.ndarray, + observation: Union[np.ndarray, Dict[str, np.ndarray]], state: Optional[Tuple[np.ndarray, ...]] = None, episode_start: Optional[np.ndarray] = None, deterministic: bool = False, @@ -241,7 +241,7 @@ class DQN(OffPolicyAlgorithm): """ if not deterministic and np.random.rand() < self.exploration_rate: if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space): - if isinstance(self.observation_space, gym.spaces.Dict): + if isinstance(observation, dict): n_batch = observation[list(observation.keys())[0]].shape[0] else: n_batch = observation.shape[0] diff --git a/tests/test_utils.py b/tests/test_utils.py index 2a9eade..34db00e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -141,16 +141,18 @@ def test_custom_vec_env(tmp_path): make_vec_env("CartPole-v1", n_envs=1, vec_env_kwargs={"dummy": False}) -def test_evaluate_policy(): +@pytest.mark.parametrize("direct_policy", [False, True]) +def test_evaluate_policy(direct_policy: bool): model = A2C("MlpPolicy", "Pendulum-v1", seed=0) n_steps_per_episode, n_eval_episodes = 200, 2 - model.n_callback_calls = 0 def dummy_callback(locals_, _globals): locals_["model"].n_callback_calls += 1 + policy = model.policy if direct_policy else model + policy.n_callback_calls = 0 _, episode_lengths = evaluate_policy( - model, + policy, model.get_env(), n_eval_episodes, deterministic=True, @@ -162,19 +164,19 @@ def test_evaluate_policy(): n_steps = sum(episode_lengths) assert n_steps == n_steps_per_episode * n_eval_episodes - assert n_steps == model.n_callback_calls + assert n_steps == policy.n_callback_calls # Reaching a mean reward of zero is impossible with the Pendulum env with pytest.raises(AssertionError): - evaluate_policy(model, model.get_env(), n_eval_episodes, reward_threshold=0.0) + evaluate_policy(policy, model.get_env(), n_eval_episodes, reward_threshold=0.0) - episode_rewards, _ = evaluate_policy(model, model.get_env(), n_eval_episodes, return_episode_rewards=True) + episode_rewards, _ = evaluate_policy(policy, model.get_env(), n_eval_episodes, return_episode_rewards=True) assert len(episode_rewards) == n_eval_episodes # Test that warning is given about no monitor eval_env = gym.make("Pendulum-v1") with pytest.warns(UserWarning): - _ = evaluate_policy(model, eval_env, n_eval_episodes) + _ = evaluate_policy(policy, eval_env, n_eval_episodes) class ZeroRewardWrapper(gym.RewardWrapper):