Fix evaluation script for recurrent policies (#678)

* Fix evaluation script for RNN

* Add error message

* Revert "Add error message"

This reverts commit 8d69b6cf4de2cd13aecfb425bd3145fad6a6c49a.

* Fix for pytype

* Rename mask to `episode_start`

* Fix type hint

* Fix type hints

* Remove confusing part of sentence

Co-authored-by: Anssi <kaneran21@hotmail.com>
This commit is contained in:
Antonin RAFFIN 2021-11-30 13:49:06 +01:00 committed by GitHub
parent 8e5ede783f
commit 52c29dc497
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 34 additions and 28 deletions

View file

@ -4,11 +4,12 @@ Changelog
==========
Release 1.3.1a2 (WIP)
Release 1.3.1a3 (WIP)
---------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Renamed ``mask`` argument of the ``predict()`` method to ``episode_start`` (used with RNN policies only)
New Features:
^^^^^^^^^^^^^
@ -20,7 +21,7 @@ Bug Fixes:
^^^^^^^^^^
- Fixed a bug where ``set_env()`` with ``VecNormalize`` would result in an error with off-policy algorithms (thanks @cleversonahum)
- FPS calculation is now performed based on number of steps performed during last ``learn`` call, even when ``reset_num_timesteps`` is set to ``False`` (@kachayev)
- Fixed evaluation script for recurrent policies (experimental feature in SB3 contrib)
Deprecations:
^^^^^^^^^^^^^

View file

@ -542,21 +542,24 @@ class BaseAlgorithm(ABC):
def predict(
self,
observation: np.ndarray,
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the model's action(s) from an observation
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param state: The last hidden states (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
this correspond to beginning of episodes,
where the hidden states of the RNN must be reset.
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next state
:return: the model's action and the next hidden state
(used in recurrent policies)
"""
return self.policy.predict(observation, state, mask, deterministic)
return self.policy.predict(observation, state, episode_start, deterministic)
def set_random_seed(self, seed: Optional[int] = None) -> None:
"""

View file

@ -81,8 +81,9 @@ def evaluate_policy(
current_lengths = np.zeros(n_envs, dtype="int")
observations = env.reset()
states = None
episode_starts = np.ones((env.num_envs,), dtype=bool)
while (episode_counts < episode_count_targets).any():
actions, states = model.predict(observations, state=states, deterministic=deterministic)
actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic)
observations, rewards, dones, infos = env.step(actions)
current_rewards += rewards
current_lengths += 1
@ -93,6 +94,7 @@ def evaluate_policy(
reward = rewards[i]
done = dones[i]
info = infos[i]
episode_starts[i] = done
if callback is not None:
callback(locals(), globals())
@ -116,8 +118,6 @@ def evaluate_policy(
episode_counts[i] += 1
current_rewards[i] = 0
current_lengths[i] = 0
if states is not None:
states[i] *= 0
if render:
env.render()

View file

@ -307,26 +307,28 @@ class BasePolicy(BaseModel):
def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the policy action and state from an observation (and optional state).
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param state: The last hidden states (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
this correspond to beginning of episodes,
where the hidden states of the RNN must be reset.
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next state
:return: the model's action and the next hidden state
(used in recurrent policies)
"""
# TODO (GH/1): add support for RNN policies
# if state is None:
# state = self.initial_state
# if mask is None:
# mask = [False for _ in range(self.n_envs)]
# if episode_start is None:
# episode_start = [False for _ in range(self.n_envs)]
# Switch to eval mode (this affects batch norm / dropout)
self.set_training_mode(False)

View file

@ -198,16 +198,16 @@ class DQN(OffPolicyAlgorithm):
def predict(
self,
observation: np.ndarray,
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Overrides the base_class predict function to include epsilon-greedy exploration.
:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next state
(used in recurrent policies)
@ -222,7 +222,7 @@ class DQN(OffPolicyAlgorithm):
else:
action = np.array(self.action_space.sample())
else:
action, state = self.policy.predict(observation, state, mask, deterministic)
action, state = self.policy.predict(observation, state, episode_start, deterministic)
return action, state
def learn(

View file

@ -1 +1 @@
1.3.1a2
1.3.1a3