Fix policies type annotations (#1735)

This commit is contained in:
Antonin RAFFIN 2023-11-06 18:35:28 +01:00 committed by GitHub
parent a35c08c0d6
commit d671402c93
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 74 additions and 56 deletions

View file

@ -3,7 +3,7 @@
Changelog
==========
Release 2.2.0a10 (WIP)
Release 2.2.0a11 (WIP)
--------------------------
**Support for options at reset, bug fixes and better error messages**
@ -62,6 +62,7 @@ Others:
- Fixed ``stable_baselines3/common/off_policy_algorithm.py`` type hints
- Fixed ``stable_baselines3/common/distributions.py`` type hints
- Switched to PyTorch 2.1.0 in the CI (fixes type annotations)
- Fixed ``stable_baselines3/common/policies.py`` type hints
Documentation:
^^^^^^^^^^^^^^

View file

@ -38,6 +38,7 @@ exclude = [
"stable_baselines3/common/vec_env/patch_gym.py",
"stable_baselines3/common/off_policy_algorithm.py",
"stable_baselines3/common/distributions.py",
"stable_baselines3/common/policies.py",
]
[tool.mypy]
@ -45,8 +46,7 @@ ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
stable_baselines3/common/policies.py$
| stable_baselines3/common/vec_env/__init__.py$
stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/vec_normalize.py$
| tests/test_logger.py$
| tests/test_train_eval_mode.py$

View file

@ -30,7 +30,7 @@ from stable_baselines3.common.torch_layers import (
NatureCNN,
create_mlp,
)
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
@ -119,7 +119,7 @@ class BaseModel(nn.Module):
"""Helper method to create a features extractor."""
return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
def extract_features(self, obs: th.Tensor, features_extractor: BaseFeaturesExtractor) -> th.Tensor:
def extract_features(self, obs: PyTorchObs, features_extractor: BaseFeaturesExtractor) -> th.Tensor:
"""
Preprocess the observation if needed and extract features.
@ -219,6 +219,9 @@ class BaseModel(nn.Module):
"""
vectorized_env = False
if isinstance(observation, dict):
assert isinstance(
self.observation_space, spaces.Dict
), f"The observation provided is a dict but the obs space is {self.observation_space}"
for key, obs in observation.items():
obs_space = self.observation_space.spaces[key]
vectorized_env = vectorized_env or is_vectorized_observation(maybe_transpose(obs, obs_space), obs_space)
@ -228,7 +231,7 @@ class BaseModel(nn.Module):
)
return vectorized_env
def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[th.Tensor, bool]:
def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[PyTorchObs, bool]:
"""
Convert an input observation to a PyTorch tensor that can be fed to a model.
Includes sugar-coating to handle different observations (e.g. normalizing images).
@ -239,6 +242,9 @@ class BaseModel(nn.Module):
"""
vectorized_env = False
if isinstance(observation, dict):
assert isinstance(
self.observation_space, spaces.Dict
), f"The observation provided is a dict but the obs space is {self.observation_space}"
# need to copy the dict as the dict in VecFrameStack will become a torch tensor
observation = copy.deepcopy(observation)
for key, obs in observation.items():
@ -249,7 +255,7 @@ class BaseModel(nn.Module):
obs_ = np.array(obs)
vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space)
# Add batch dimension if needed
observation[key] = obs_.reshape((-1, *self.observation_space[key].shape))
observation[key] = obs_.reshape((-1, *self.observation_space[key].shape)) # type: ignore[misc]
elif is_image_space(self.observation_space):
# Handle the different cases for images
@ -263,10 +269,10 @@ class BaseModel(nn.Module):
# Dict obs need to be handled separately
vectorized_env = is_vectorized_observation(observation, self.observation_space)
# Add batch dimension if needed
observation = observation.reshape((-1, *self.observation_space.shape))
observation = observation.reshape((-1, *self.observation_space.shape)) # type: ignore[misc]
observation = obs_as_tensor(observation, self.device)
return observation, vectorized_env
obs_tensor = obs_as_tensor(observation, self.device)
return obs_tensor, vectorized_env
class BasePolicy(BaseModel, ABC):
@ -308,7 +314,7 @@ class BasePolicy(BaseModel, ABC):
module.bias.data.fill_(0.0)
@abstractmethod
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
"""
Get the action according to the policy for a given observation.
@ -354,27 +360,28 @@ class BasePolicy(BaseModel, ABC):
"and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
)
observation, vectorized_env = self.obs_to_tensor(observation)
obs_tensor, vectorized_env = self.obs_to_tensor(observation)
with th.no_grad():
actions = self._predict(observation, deterministic=deterministic)
actions = self._predict(obs_tensor, deterministic=deterministic)
# Convert to numpy, and reshape to the original action shape
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape))
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc]
if isinstance(self.action_space, spaces.Box):
if self.squash_output:
# Rescale to proper domain when using squashing
actions = self.unscale_action(actions)
actions = self.unscale_action(actions) # type: ignore[assignment, arg-type]
else:
# Actions could be on arbitrary scale, so clip the actions to avoid
# out of bound error (e.g. if sampling from a Gaussian distribution)
actions = np.clip(actions, self.action_space.low, self.action_space.high)
actions = np.clip(actions, self.action_space.low, self.action_space.high) # type: ignore[assignment, arg-type]
# Remove batch dimension if needed
if not vectorized_env:
assert isinstance(actions, np.ndarray)
actions = actions.squeeze(axis=0)
return actions, state
return actions, state # type: ignore[return-value]
def scale_action(self, action: np.ndarray) -> np.ndarray:
"""
@ -384,6 +391,9 @@ class BasePolicy(BaseModel, ABC):
:param action: Action to scale
:return: Scaled action
"""
assert isinstance(
self.action_space, spaces.Box
), f"Trying to scale an action using an action space that is not a Box(): {self.action_space}"
low, high = self.action_space.low, self.action_space.high
return 2.0 * ((action - low) / (high - low)) - 1.0
@ -394,6 +404,9 @@ class BasePolicy(BaseModel, ABC):
:param scaled_action: Action to un-scale
"""
assert isinstance(
self.action_space, spaces.Box
), f"Trying to unscale an action using an action space that is not a Box(): {self.action_space}"
low, high = self.action_space.low, self.action_space.high
return low + (0.5 * (scaled_action + 1.0) * (high - low))
@ -522,7 +535,7 @@ class ActorCriticPolicy(BasePolicy):
def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters()
default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None)
default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) # type: ignore[arg-type, return-value]
data.update(
dict(
@ -616,7 +629,7 @@ class ActorCriticPolicy(BasePolicy):
module.apply(partial(self.init_weights, gain=gain))
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) # type: ignore[call-arg]
def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
@ -639,11 +652,11 @@ class ActorCriticPolicy(BasePolicy):
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
actions = actions.reshape((-1, *self.action_space.shape))
actions = actions.reshape((-1, *self.action_space.shape)) # type: ignore[misc]
return actions, values, log_prob
def extract_features(
self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None
def extract_features( # type: ignore[override]
self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
"""
Preprocess the observation if needed and extract features.
@ -691,7 +704,7 @@ class ActorCriticPolicy(BasePolicy):
else:
raise ValueError("Invalid action distribution")
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
"""
Get the action according to the policy for a given observation.
@ -701,7 +714,7 @@ class ActorCriticPolicy(BasePolicy):
"""
return self.get_distribution(observation).get_actions(deterministic=deterministic)
def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
"""
Evaluate actions according to the current policy,
given the observations.
@ -725,7 +738,7 @@ class ActorCriticPolicy(BasePolicy):
entropy = distribution.entropy()
return values, log_prob, entropy
def get_distribution(self, obs: th.Tensor) -> Distribution:
def get_distribution(self, obs: PyTorchObs) -> Distribution:
"""
Get the current policy distribution given the observations.
@ -736,7 +749,7 @@ class ActorCriticPolicy(BasePolicy):
latent_pi = self.mlp_extractor.forward_actor(features)
return self._get_action_dist_from_latent(latent_pi)
def predict_values(self, obs: th.Tensor) -> th.Tensor:
def predict_values(self, obs: PyTorchObs) -> th.Tensor:
"""
Get the estimated values according to the current policy given the observations.
@ -921,6 +934,8 @@ class ContinuousCritic(BaseModel):
between the actor and the critic (this saves computation time)
"""
features_extractor: BaseFeaturesExtractor
def __init__(
self,
observation_space: spaces.Space,
@ -944,10 +959,10 @@ class ContinuousCritic(BaseModel):
self.share_features_extractor = share_features_extractor
self.n_critics = n_critics
self.q_networks = []
self.q_networks: List[nn.Module] = []
for idx in range(n_critics):
q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
q_net = nn.Sequential(*q_net)
q_net_list = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
q_net = nn.Sequential(*q_net_list)
self.add_module(f"qf{idx}", q_net)
self.q_networks.append(q_net)

View file

@ -90,7 +90,7 @@ def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) ->
def preprocess_obs(
obs: th.Tensor,
obs: Union[th.Tensor, Dict[str, th.Tensor]],
observation_space: spaces.Space,
normalize_images: bool = True,
) -> Union[th.Tensor, Dict[str, th.Tensor]]:
@ -105,6 +105,16 @@ def preprocess_obs(
(True by default)
:return:
"""
if isinstance(observation_space, spaces.Dict):
# Do not modify by reference the original observation
assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
preprocessed_obs = {}
for key, _obs in obs.items():
preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
return preprocessed_obs # type: ignore[return-value]
assert isinstance(obs, th.Tensor), f"Expecting a torch Tensor, but got {type(obs)}"
if isinstance(observation_space, spaces.Box):
if normalize_images and is_image_space(observation_space):
return obs.float() / 255.0
@ -126,15 +136,6 @@ def preprocess_obs(
elif isinstance(observation_space, spaces.MultiBinary):
return obs.float()
elif isinstance(observation_space, spaces.Dict):
# Do not modify by reference the original observation
assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
preprocessed_obs = {}
for key, _obs in obs.items():
preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
return preprocessed_obs
else:
raise NotImplementedError(f"Preprocessing not implemented for {observation_space}")

View file

@ -20,6 +20,7 @@ AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]]
TensorDict = Dict[str, th.Tensor]
OptimizerStateDict = Dict[str, Any]
MaybeCallback = Union[None, Callable, List["BaseCallback"], "BaseCallback"]
PyTorchObs = Union[th.Tensor, TensorDict]
# A schedule takes the remaining progress as input
# and ouputs a scalar (e.g. learning rate, clip range, ...)

View file

@ -12,7 +12,7 @@ from stable_baselines3.common.torch_layers import (
NatureCNN,
create_mlp,
)
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
class QNetwork(BasePolicy):
@ -56,7 +56,7 @@ class QNetwork(BasePolicy):
q_net = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn)
self.q_net = nn.Sequential(*q_net)
def forward(self, obs: th.Tensor) -> th.Tensor:
def forward(self, obs: PyTorchObs) -> th.Tensor:
"""
Predict the q-values.
@ -65,7 +65,7 @@ class QNetwork(BasePolicy):
"""
return self.q_net(self.extract_features(obs, self.features_extractor))
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor:
q_values = self(observation)
# Greedy action
action = q_values.argmax(dim=1).reshape(-1)
@ -177,10 +177,10 @@ class DQNPolicy(BasePolicy):
net_args = self._update_features_extractor(self.net_args, features_extractor=None)
return QNetwork(**net_args).to(self.device)
def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
def forward(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)
def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
return self.q_net._predict(obs, deterministic=deterministic)
def _get_constructor_parameters(self) -> Dict[str, Any]:

View file

@ -15,7 +15,7 @@ from stable_baselines3.common.torch_layers import (
create_mlp,
get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
# CAP the standard deviation of the actor
LOG_STD_MAX = 2
@ -144,7 +144,7 @@ class Actor(BasePolicy):
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
"""
Get the parameters for the action distribution.
@ -164,17 +164,17 @@ class Actor(BasePolicy):
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
return mean_actions, log_std, {}
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# Note: the action is squashed
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# return action and associated log prob
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
return self(observation, deterministic)
@ -346,10 +346,10 @@ class SACPolicy(BasePolicy):
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
return ContinuousCritic(**critic_kwargs).to(self.device)
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
return self.actor(observation, deterministic)
def set_training_mode(self, mode: bool) -> None:

View file

@ -14,7 +14,7 @@ from stable_baselines3.common.torch_layers import (
create_mlp,
get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
class Actor(BasePolicy):
@ -77,7 +77,7 @@ class Actor(BasePolicy):
features = self.extract_features(obs, self.features_extractor)
return self.mu(features)
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
# Note: the deterministic deterministic parameter is ignored in the case of TD3.
# Predictions are always deterministic.
return self(observation)
@ -233,10 +233,10 @@ class TD3Policy(BasePolicy):
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
return ContinuousCritic(**critic_kwargs).to(self.device)
def forward(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
def forward(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
return self._predict(observation, deterministic=deterministic)
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
# Note: the deterministic deterministic parameter is ignored in the case of TD3.
# Predictions are always deterministic.
return self.actor(observation)

View file

@ -1 +1 @@
2.2.0a10
2.2.0a11