mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-31 23:28:05 +00:00
Fix evaluation script for recurrent policies (#678)
* Fix evaluation script for RNN * Add error message * Revert "Add error message" This reverts commit 8d69b6cf4de2cd13aecfb425bd3145fad6a6c49a. * Fix for pytype * Rename mask to `episode_start` * Fix type hint * Fix type hints * Remove confusing part of sentence Co-authored-by: Anssi <kaneran21@hotmail.com>
This commit is contained in:
parent
8e5ede783f
commit
52c29dc497
6 changed files with 34 additions and 28 deletions
|
|
@ -4,11 +4,12 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.3.1a2 (WIP)
|
||||
Release 1.3.1a3 (WIP)
|
||||
---------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Renamed ``mask`` argument of the ``predict()`` method to ``episode_start`` (used with RNN policies only)
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -20,7 +21,7 @@ Bug Fixes:
|
|||
^^^^^^^^^^
|
||||
- Fixed a bug where ``set_env()`` with ``VecNormalize`` would result in an error with off-policy algorithms (thanks @cleversonahum)
|
||||
- FPS calculation is now performed based on number of steps performed during last ``learn`` call, even when ``reset_num_timesteps`` is set to ``False`` (@kachayev)
|
||||
|
||||
- Fixed evaluation script for recurrent policies (experimental feature in SB3 contrib)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -542,21 +542,24 @@ class BaseAlgorithm(ABC):
|
|||
def predict(
|
||||
self,
|
||||
observation: np.ndarray,
|
||||
state: Optional[np.ndarray] = None,
|
||||
mask: Optional[np.ndarray] = None,
|
||||
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||
episode_start: Optional[np.ndarray] = None,
|
||||
deterministic: bool = False,
|
||||
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
|
||||
"""
|
||||
Get the model's action(s) from an observation
|
||||
Get the policy action from an observation (and optional hidden state).
|
||||
Includes sugar-coating to handle different observations (e.g. normalizing images).
|
||||
|
||||
:param observation: the input observation
|
||||
:param state: The last states (can be None, used in recurrent policies)
|
||||
:param mask: The last masks (can be None, used in recurrent policies)
|
||||
:param state: The last hidden states (can be None, used in recurrent policies)
|
||||
:param episode_start: The last masks (can be None, used in recurrent policies)
|
||||
this correspond to beginning of episodes,
|
||||
where the hidden states of the RNN must be reset.
|
||||
:param deterministic: Whether or not to return deterministic actions.
|
||||
:return: the model's action and the next state
|
||||
:return: the model's action and the next hidden state
|
||||
(used in recurrent policies)
|
||||
"""
|
||||
return self.policy.predict(observation, state, mask, deterministic)
|
||||
return self.policy.predict(observation, state, episode_start, deterministic)
|
||||
|
||||
def set_random_seed(self, seed: Optional[int] = None) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -81,8 +81,9 @@ def evaluate_policy(
|
|||
current_lengths = np.zeros(n_envs, dtype="int")
|
||||
observations = env.reset()
|
||||
states = None
|
||||
episode_starts = np.ones((env.num_envs,), dtype=bool)
|
||||
while (episode_counts < episode_count_targets).any():
|
||||
actions, states = model.predict(observations, state=states, deterministic=deterministic)
|
||||
actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic)
|
||||
observations, rewards, dones, infos = env.step(actions)
|
||||
current_rewards += rewards
|
||||
current_lengths += 1
|
||||
|
|
@ -93,6 +94,7 @@ def evaluate_policy(
|
|||
reward = rewards[i]
|
||||
done = dones[i]
|
||||
info = infos[i]
|
||||
episode_starts[i] = done
|
||||
|
||||
if callback is not None:
|
||||
callback(locals(), globals())
|
||||
|
|
@ -116,8 +118,6 @@ def evaluate_policy(
|
|||
episode_counts[i] += 1
|
||||
current_rewards[i] = 0
|
||||
current_lengths[i] = 0
|
||||
if states is not None:
|
||||
states[i] *= 0
|
||||
|
||||
if render:
|
||||
env.render()
|
||||
|
|
|
|||
|
|
@ -307,26 +307,28 @@ class BasePolicy(BaseModel):
|
|||
def predict(
|
||||
self,
|
||||
observation: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||
state: Optional[np.ndarray] = None,
|
||||
mask: Optional[np.ndarray] = None,
|
||||
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||
episode_start: Optional[np.ndarray] = None,
|
||||
deterministic: bool = False,
|
||||
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
|
||||
"""
|
||||
Get the policy action and state from an observation (and optional state).
|
||||
Get the policy action from an observation (and optional hidden state).
|
||||
Includes sugar-coating to handle different observations (e.g. normalizing images).
|
||||
|
||||
:param observation: the input observation
|
||||
:param state: The last states (can be None, used in recurrent policies)
|
||||
:param mask: The last masks (can be None, used in recurrent policies)
|
||||
:param state: The last hidden states (can be None, used in recurrent policies)
|
||||
:param episode_start: The last masks (can be None, used in recurrent policies)
|
||||
this correspond to beginning of episodes,
|
||||
where the hidden states of the RNN must be reset.
|
||||
:param deterministic: Whether or not to return deterministic actions.
|
||||
:return: the model's action and the next state
|
||||
:return: the model's action and the next hidden state
|
||||
(used in recurrent policies)
|
||||
"""
|
||||
# TODO (GH/1): add support for RNN policies
|
||||
# if state is None:
|
||||
# state = self.initial_state
|
||||
# if mask is None:
|
||||
# mask = [False for _ in range(self.n_envs)]
|
||||
# if episode_start is None:
|
||||
# episode_start = [False for _ in range(self.n_envs)]
|
||||
# Switch to eval mode (this affects batch norm / dropout)
|
||||
self.set_training_mode(False)
|
||||
|
||||
|
|
|
|||
|
|
@ -198,16 +198,16 @@ class DQN(OffPolicyAlgorithm):
|
|||
def predict(
|
||||
self,
|
||||
observation: np.ndarray,
|
||||
state: Optional[np.ndarray] = None,
|
||||
mask: Optional[np.ndarray] = None,
|
||||
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||
episode_start: Optional[np.ndarray] = None,
|
||||
deterministic: bool = False,
|
||||
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
|
||||
"""
|
||||
Overrides the base_class predict function to include epsilon-greedy exploration.
|
||||
|
||||
:param observation: the input observation
|
||||
:param state: The last states (can be None, used in recurrent policies)
|
||||
:param mask: The last masks (can be None, used in recurrent policies)
|
||||
:param episode_start: The last masks (can be None, used in recurrent policies)
|
||||
:param deterministic: Whether or not to return deterministic actions.
|
||||
:return: the model's action and the next state
|
||||
(used in recurrent policies)
|
||||
|
|
@ -222,7 +222,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
else:
|
||||
action = np.array(self.action_space.sample())
|
||||
else:
|
||||
action, state = self.policy.predict(observation, state, mask, deterministic)
|
||||
action, state = self.policy.predict(observation, state, episode_start, deterministic)
|
||||
return action, state
|
||||
|
||||
def learn(
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.3.1a2
|
||||
1.3.1a3
|
||||
|
|
|
|||
Loading…
Reference in a new issue