mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-28 22:56:53 +00:00
Update evaluate_policy type annotation to support policies as well as RL algorithms (#1146)
* Add PolicyPredictor protocol and use it in evaluate_policy * Update changelog * Move Protocol to type_aliases to avoid circular import * Add test for evaluate_policy on BasePolicy * Remove unused import * Use typing_extensions * Move typing_extensions to 3rd party * Add version range (typing_extensions uses SemVer) * Import Protocol from typing_extensions only on Python<3.8 Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Install typing_extensions only on Python<3.8 * Add missing sys import * Fix import ordering * Fix observation type hint in predict Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
This commit is contained in:
parent
0532a5719c
commit
4fb8aec215
7 changed files with 49 additions and 14 deletions
|
|
@ -25,6 +25,7 @@ Bug Fixes:
|
|||
- Fix return type of ``evaluate_actions`` in ``ActorCriticPolicy`` to reflect that entropy is an optional tensor (@Rocamonde)
|
||||
- Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm``
|
||||
- Allowed model trained with Python 3.7 to be loaded with Python 3.8+ without the ``custom_objects`` workaround
|
||||
- Fix type annotation of ``model`` in ``evaluate_policy``
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
1
setup.py
1
setup.py
|
|
@ -76,6 +76,7 @@ setup(
|
|||
"gym==0.21", # Fixed version due to breaking changes in 0.22
|
||||
"numpy",
|
||||
"torch>=1.11",
|
||||
'typing_extensions>=4.0,<5; python_version < "3.8.0"',
|
||||
# For saving models
|
||||
"cloudpickle",
|
||||
# For reading logs
|
||||
|
|
|
|||
|
|
@ -513,7 +513,7 @@ class BaseAlgorithm(ABC):
|
|||
|
||||
def predict(
|
||||
self,
|
||||
observation: np.ndarray,
|
||||
observation: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||
episode_start: Optional[np.ndarray] = None,
|
||||
deterministic: bool = False,
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|||
import gym
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3.common import base_class
|
||||
from stable_baselines3.common import type_aliases
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
|
||||
|
||||
|
||||
def evaluate_policy(
|
||||
model: "base_class.BaseAlgorithm",
|
||||
model: "type_aliases.PolicyPredictor",
|
||||
env: Union[gym.Env, VecEnv],
|
||||
n_eval_episodes: int = 10,
|
||||
deterministic: bool = True,
|
||||
|
|
@ -34,7 +34,9 @@ def evaluate_policy(
|
|||
results as well. You can avoid this by wrapping environment with ``Monitor``
|
||||
wrapper before anything else.
|
||||
|
||||
:param model: The RL agent you want to evaluate.
|
||||
:param model: The RL agent you want to evaluate. This can be any object
|
||||
that implements a ``predict`` method, such as an RL algorithm (``BaseAlgorithm``)
|
||||
or policy (``BasePolicy``).
|
||||
:param env: The gym environment or ``VecEnv`` environment.
|
||||
:param n_eval_episodes: Number of episodes to evaluate the agent
|
||||
:param deterministic: Whether to use deterministic or stochastic actions
|
||||
|
|
|
|||
|
|
@ -1,12 +1,18 @@
|
|||
"""Common aliases for type hints"""
|
||||
|
||||
import sys
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import Protocol
|
||||
else:
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from stable_baselines3.common import callbacks, vec_env
|
||||
|
||||
GymEnv = Union[gym.Env, vec_env.VecEnv]
|
||||
|
|
@ -69,3 +75,26 @@ class TrainFrequencyUnit(Enum):
|
|||
class TrainFreq(NamedTuple):
|
||||
frequency: int
|
||||
unit: TrainFrequencyUnit # either "step" or "episode"
|
||||
|
||||
|
||||
class PolicyPredictor(Protocol):
|
||||
def predict(
|
||||
self,
|
||||
observation: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||
episode_start: Optional[np.ndarray] = None,
|
||||
deterministic: bool = False,
|
||||
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
|
||||
"""
|
||||
Get the policy action from an observation (and optional hidden state).
|
||||
Includes sugar-coating to handle different observations (e.g. normalizing images).
|
||||
|
||||
:param observation: the input observation
|
||||
:param state: The last hidden states (can be None, used in recurrent policies)
|
||||
:param episode_start: The last masks (can be None, used in recurrent policies)
|
||||
this corresponds to the beginning of episodes,
|
||||
where the hidden states of the RNN must be reset.
|
||||
:param deterministic: Whether or not to return deterministic actions.
|
||||
:return: the model's action and the next hidden state
|
||||
(used in recurrent policies)
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -224,7 +224,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
|
||||
def predict(
|
||||
self,
|
||||
observation: np.ndarray,
|
||||
observation: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||
episode_start: Optional[np.ndarray] = None,
|
||||
deterministic: bool = False,
|
||||
|
|
@ -241,7 +241,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
"""
|
||||
if not deterministic and np.random.rand() < self.exploration_rate:
|
||||
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
|
||||
if isinstance(self.observation_space, gym.spaces.Dict):
|
||||
if isinstance(observation, dict):
|
||||
n_batch = observation[list(observation.keys())[0]].shape[0]
|
||||
else:
|
||||
n_batch = observation.shape[0]
|
||||
|
|
|
|||
|
|
@ -141,16 +141,18 @@ def test_custom_vec_env(tmp_path):
|
|||
make_vec_env("CartPole-v1", n_envs=1, vec_env_kwargs={"dummy": False})
|
||||
|
||||
|
||||
def test_evaluate_policy():
|
||||
@pytest.mark.parametrize("direct_policy", [False, True])
|
||||
def test_evaluate_policy(direct_policy: bool):
|
||||
model = A2C("MlpPolicy", "Pendulum-v1", seed=0)
|
||||
n_steps_per_episode, n_eval_episodes = 200, 2
|
||||
model.n_callback_calls = 0
|
||||
|
||||
def dummy_callback(locals_, _globals):
|
||||
locals_["model"].n_callback_calls += 1
|
||||
|
||||
policy = model.policy if direct_policy else model
|
||||
policy.n_callback_calls = 0
|
||||
_, episode_lengths = evaluate_policy(
|
||||
model,
|
||||
policy,
|
||||
model.get_env(),
|
||||
n_eval_episodes,
|
||||
deterministic=True,
|
||||
|
|
@ -162,19 +164,19 @@ def test_evaluate_policy():
|
|||
|
||||
n_steps = sum(episode_lengths)
|
||||
assert n_steps == n_steps_per_episode * n_eval_episodes
|
||||
assert n_steps == model.n_callback_calls
|
||||
assert n_steps == policy.n_callback_calls
|
||||
|
||||
# Reaching a mean reward of zero is impossible with the Pendulum env
|
||||
with pytest.raises(AssertionError):
|
||||
evaluate_policy(model, model.get_env(), n_eval_episodes, reward_threshold=0.0)
|
||||
evaluate_policy(policy, model.get_env(), n_eval_episodes, reward_threshold=0.0)
|
||||
|
||||
episode_rewards, _ = evaluate_policy(model, model.get_env(), n_eval_episodes, return_episode_rewards=True)
|
||||
episode_rewards, _ = evaluate_policy(policy, model.get_env(), n_eval_episodes, return_episode_rewards=True)
|
||||
assert len(episode_rewards) == n_eval_episodes
|
||||
|
||||
# Test that warning is given about no monitor
|
||||
eval_env = gym.make("Pendulum-v1")
|
||||
with pytest.warns(UserWarning):
|
||||
_ = evaluate_policy(model, eval_env, n_eval_episodes)
|
||||
_ = evaluate_policy(policy, eval_env, n_eval_episodes)
|
||||
|
||||
|
||||
class ZeroRewardWrapper(gym.RewardWrapper):
|
||||
|
|
|
|||
Loading…
Reference in a new issue