Update evaluate_policy type annotation to support policies as well as RL algorithms (#1146)

* Add PolicyPredictor protocol and use it in evaluate_policy

* Update changelog

* Move Protocol to type_aliases to avoid circular import

* Add test for evaluate_policy on BasePolicy

* Remove unused import

* Use typing_extensions

* Move typing_extensions to 3rd party

* Add version range (typing_extensions uses SemVer)

* Import Protocol from typing_extensions only on Python<3.8

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Install typing_extensions only on Python<3.8

* Add missing sys import

* Fix import ordering

* Fix observation type hint in predict

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
This commit is contained in:
Adam Gleave 2022-11-03 07:36:19 -07:00 committed by GitHub
parent 0532a5719c
commit 4fb8aec215
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 49 additions and 14 deletions

View file

@ -25,6 +25,7 @@ Bug Fixes:
- Fix return type of ``evaluate_actions`` in ``ActorCriticPolicy`` to reflect that entropy is an optional tensor (@Rocamonde)
- Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm``
- Allowed model trained with Python 3.7 to be loaded with Python 3.8+ without the ``custom_objects`` workaround
- Fix type annotation of ``model`` in ``evaluate_policy``
Deprecations:
^^^^^^^^^^^^^

View file

@ -76,6 +76,7 @@ setup(
"gym==0.21", # Fixed version due to breaking changes in 0.22
"numpy",
"torch>=1.11",
'typing_extensions>=4.0,<5; python_version < "3.8.0"',
# For saving models
"cloudpickle",
# For reading logs

View file

@ -513,7 +513,7 @@ class BaseAlgorithm(ABC):
def predict(
self,
observation: np.ndarray,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,

View file

@ -4,12 +4,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import gym
import numpy as np
from stable_baselines3.common import base_class
from stable_baselines3.common import type_aliases
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
def evaluate_policy(
model: "base_class.BaseAlgorithm",
model: "type_aliases.PolicyPredictor",
env: Union[gym.Env, VecEnv],
n_eval_episodes: int = 10,
deterministic: bool = True,
@ -34,7 +34,9 @@ def evaluate_policy(
results as well. You can avoid this by wrapping environment with ``Monitor``
wrapper before anything else.
:param model: The RL agent you want to evaluate.
:param model: The RL agent you want to evaluate. This can be any object
that implements a `predict` method, such as an RL algorithm (``BaseAlgorithm``)
or policy (``BasePolicy``).
:param env: The gym environment or ``VecEnv`` environment.
:param n_eval_episodes: Number of episodes to evaluate the agent
:param deterministic: Whether to use deterministic or stochastic actions

View file

@ -1,12 +1,18 @@
"""Common aliases for type hints"""
import sys
from enum import Enum
from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union
import gym
import numpy as np
import torch as th
if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol
from stable_baselines3.common import callbacks, vec_env
GymEnv = Union[gym.Env, vec_env.VecEnv]
@ -69,3 +75,26 @@ class TrainFrequencyUnit(Enum):
class TrainFreq(NamedTuple):
frequency: int
unit: TrainFrequencyUnit # either "step" or "episode"
class PolicyPredictor(Protocol):
    """Structural type for any object that exposes a ``predict`` method with this signature."""

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get the policy action from an observation (and optional hidden state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: the input observation
        :param state: The last hidden states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies);
            this corresponds to the beginning of episodes,
            where the hidden states of the RNN must be reset.
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next hidden state
            (used in recurrent policies)
        """

View file

@ -224,7 +224,7 @@ class DQN(OffPolicyAlgorithm):
def predict(
self,
observation: np.ndarray,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
@ -241,7 +241,7 @@ class DQN(OffPolicyAlgorithm):
"""
if not deterministic and np.random.rand() < self.exploration_rate:
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
if isinstance(self.observation_space, gym.spaces.Dict):
if isinstance(observation, dict):
n_batch = observation[list(observation.keys())[0]].shape[0]
else:
n_batch = observation.shape[0]

View file

@ -141,16 +141,18 @@ def test_custom_vec_env(tmp_path):
make_vec_env("CartPole-v1", n_envs=1, vec_env_kwargs={"dummy": False})
def test_evaluate_policy():
@pytest.mark.parametrize("direct_policy", [False, True])
def test_evaluate_policy(direct_policy: bool):
model = A2C("MlpPolicy", "Pendulum-v1", seed=0)
n_steps_per_episode, n_eval_episodes = 200, 2
model.n_callback_calls = 0
def dummy_callback(locals_, _globals):
locals_["model"].n_callback_calls += 1
policy = model.policy if direct_policy else model
policy.n_callback_calls = 0
_, episode_lengths = evaluate_policy(
model,
policy,
model.get_env(),
n_eval_episodes,
deterministic=True,
@ -162,19 +164,19 @@ def test_evaluate_policy():
n_steps = sum(episode_lengths)
assert n_steps == n_steps_per_episode * n_eval_episodes
assert n_steps == model.n_callback_calls
assert n_steps == policy.n_callback_calls
# Reaching a mean reward of zero is impossible with the Pendulum env
with pytest.raises(AssertionError):
evaluate_policy(model, model.get_env(), n_eval_episodes, reward_threshold=0.0)
evaluate_policy(policy, model.get_env(), n_eval_episodes, reward_threshold=0.0)
episode_rewards, _ = evaluate_policy(model, model.get_env(), n_eval_episodes, return_episode_rewards=True)
episode_rewards, _ = evaluate_policy(policy, model.get_env(), n_eval_episodes, return_episode_rewards=True)
assert len(episode_rewards) == n_eval_episodes
# Test that warning is given about no monitor
eval_env = gym.make("Pendulum-v1")
with pytest.warns(UserWarning):
_ = evaluate_policy(model, eval_env, n_eval_episodes)
_ = evaluate_policy(policy, eval_env, n_eval_episodes)
class ZeroRewardWrapper(gym.RewardWrapper):