diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 53209cb..3b09fad 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.2.0a10 (WIP) +Release 2.2.0a11 (WIP) -------------------------- **Support for options at reset, bug fixes and better error messages** @@ -62,6 +62,7 @@ Others: - Fixed ``stable_baselines3/common/off_policy_algorithm.py`` type hints - Fixed ``stable_baselines3/common/distributions.py`` type hints - Switched to PyTorch 2.1.0 in the CI (fixes type annotations) +- Fixed ``stable_baselines3/common/policies.py`` type hints Documentation: ^^^^^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index 9c3489a..b15e515 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ exclude = [ "stable_baselines3/common/vec_env/patch_gym.py", "stable_baselines3/common/off_policy_algorithm.py", "stable_baselines3/common/distributions.py", + "stable_baselines3/common/policies.py", ] [tool.mypy] @@ -45,8 +46,7 @@ ignore_missing_imports = true follow_imports = "silent" show_error_codes = true exclude = """(?x)( - stable_baselines3/common/policies.py$ - | stable_baselines3/common/vec_env/__init__.py$ + stable_baselines3/common/vec_env/__init__.py$ | stable_baselines3/common/vec_env/vec_normalize.py$ | tests/test_logger.py$ | tests/test_train_eval_mode.py$ diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 5f57672..0d810ef 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -30,7 +30,7 @@ from stable_baselines3.common.torch_layers import ( NatureCNN, create_mlp, ) -from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.type_aliases import PyTorchObs, Schedule from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel") @@ -119,7 +119,7 @@ class BaseModel(nn.Module): """Helper method to create a features extractor.""" return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs) - def extract_features(self, obs: th.Tensor, features_extractor: BaseFeaturesExtractor) -> th.Tensor: + def extract_features(self, obs: PyTorchObs, features_extractor: BaseFeaturesExtractor) -> th.Tensor: """ Preprocess the observation if needed and extract features. @@ -219,6 +219,9 @@ class BaseModel(nn.Module): """ vectorized_env = False if isinstance(observation, dict): + assert isinstance( + self.observation_space, spaces.Dict + ), f"The observation provided is a dict but the obs space is {self.observation_space}" for key, obs in observation.items(): obs_space = self.observation_space.spaces[key] vectorized_env = vectorized_env or is_vectorized_observation(maybe_transpose(obs, obs_space), obs_space) @@ -228,7 +231,7 @@ class BaseModel(nn.Module): ) return vectorized_env - def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[th.Tensor, bool]: + def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[PyTorchObs, bool]: """ Convert an input observation to a PyTorch tensor that can be fed to a model. Includes sugar-coating to handle different observations (e.g. normalizing images). @@ -239,6 +242,9 @@ class BaseModel(nn.Module): """ vectorized_env = False if isinstance(observation, dict): + assert isinstance( + self.observation_space, spaces.Dict + ), f"The observation provided is a dict but the obs space is {self.observation_space}" # need to copy the dict as the dict in VecFrameStack will become a torch tensor observation = copy.deepcopy(observation) for key, obs in observation.items(): @@ -249,7 +255,7 @@ class BaseModel(nn.Module): obs_ = np.array(obs) vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space) # Add batch dimension if needed - observation[key] = obs_.reshape((-1, *self.observation_space[key].shape)) + observation[key] = obs_.reshape((-1, *self.observation_space[key].shape)) # type: ignore[misc] elif is_image_space(self.observation_space): # Handle the different cases for images @@ -263,10 +269,10 @@ class BaseModel(nn.Module): # Dict obs need to be handled separately vectorized_env = is_vectorized_observation(observation, self.observation_space) # Add batch dimension if needed - observation = observation.reshape((-1, *self.observation_space.shape)) + observation = observation.reshape((-1, *self.observation_space.shape)) # type: ignore[misc] - observation = obs_as_tensor(observation, self.device) - return observation, vectorized_env + obs_tensor = obs_as_tensor(observation, self.device) + return obs_tensor, vectorized_env class BasePolicy(BaseModel, ABC): @@ -308,7 +314,7 @@ class BasePolicy(BaseModel, ABC): module.bias.data.fill_(0.0) @abstractmethod - def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor: """ Get the action according to the policy for a given observation. @@ -354,27 +360,28 @@ class BasePolicy(BaseModel, ABC): "and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api" ) - observation, vectorized_env = self.obs_to_tensor(observation) + obs_tensor, vectorized_env = self.obs_to_tensor(observation) with th.no_grad(): - actions = self._predict(observation, deterministic=deterministic) + actions = self._predict(obs_tensor, deterministic=deterministic) # Convert to numpy, and reshape to the original action shape - actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) + actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc] if isinstance(self.action_space, spaces.Box): if self.squash_output: # Rescale to proper domain when using squashing - actions = self.unscale_action(actions) + actions = self.unscale_action(actions) # type: ignore[assignment, arg-type] else: # Actions could be on arbitrary scale, so clip the actions to avoid # out of bound error (e.g. if sampling from a Gaussian distribution) - actions = np.clip(actions, self.action_space.low, self.action_space.high) + actions = np.clip(actions, self.action_space.low, self.action_space.high) # type: ignore[assignment, arg-type] # Remove batch dimension if needed if not vectorized_env: + assert isinstance(actions, np.ndarray) actions = actions.squeeze(axis=0) - return actions, state + return actions, state # type: ignore[return-value] def scale_action(self, action: np.ndarray) -> np.ndarray: """ @@ -384,6 +391,9 @@ class BasePolicy(BaseModel, ABC): :param action: Action to scale :return: Scaled action """ + assert isinstance( + self.action_space, spaces.Box + ), f"Trying to scale an action using an action space that is not a Box(): {self.action_space}" low, high = self.action_space.low, self.action_space.high return 2.0 * ((action - low) / (high - low)) - 1.0 @@ -394,6 +404,9 @@ class BasePolicy(BaseModel, ABC): :param scaled_action: Action to un-scale """ + assert isinstance( + self.action_space, spaces.Box + ), f"Trying to unscale an action using an action space that is not a Box(): {self.action_space}" low, high = self.action_space.low, self.action_space.high return low + (0.5 * (scaled_action + 1.0) * (high - low)) @@ -522,7 +535,7 @@ class ActorCriticPolicy(BasePolicy): def _get_constructor_parameters(self) -> Dict[str, Any]: data = super()._get_constructor_parameters() - default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) + default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) # type: ignore[arg-type, return-value] data.update( dict( @@ -616,7 +629,7 @@ class ActorCriticPolicy(BasePolicy): module.apply(partial(self.init_weights, gain=gain)) # Setup optimizer with initial learning rate - self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) # type: ignore[call-arg] def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: """ @@ -639,11 +652,11 @@ class ActorCriticPolicy(BasePolicy): distribution = self._get_action_dist_from_latent(latent_pi) actions = distribution.get_actions(deterministic=deterministic) log_prob = distribution.log_prob(actions) - actions = actions.reshape((-1, *self.action_space.shape)) + actions = actions.reshape((-1, *self.action_space.shape)) # type: ignore[misc] return actions, values, log_prob - def extract_features( - self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None + def extract_features( # type: ignore[override] + self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None ) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]: """ Preprocess the observation if needed and extract features. @@ -691,7 +704,7 @@ class ActorCriticPolicy(BasePolicy): else: raise ValueError("Invalid action distribution") - def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor: """ Get the action according to the policy for a given observation. @@ -701,7 +714,7 @@ class ActorCriticPolicy(BasePolicy): """ return self.get_distribution(observation).get_actions(deterministic=deterministic) - def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]: + def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]: """ Evaluate actions according to the current policy, given the observations. @@ -725,7 +738,7 @@ class ActorCriticPolicy(BasePolicy): entropy = distribution.entropy() return values, log_prob, entropy - def get_distribution(self, obs: th.Tensor) -> Distribution: + def get_distribution(self, obs: PyTorchObs) -> Distribution: """ Get the current policy distribution given the observations. @@ -736,7 +749,7 @@ class ActorCriticPolicy(BasePolicy): latent_pi = self.mlp_extractor.forward_actor(features) return self._get_action_dist_from_latent(latent_pi) - def predict_values(self, obs: th.Tensor) -> th.Tensor: + def predict_values(self, obs: PyTorchObs) -> th.Tensor: """ Get the estimated values according to the current policy given the observations. @@ -921,6 +934,8 @@ class ContinuousCritic(BaseModel): between the actor and the critic (this saves computation time) """ + features_extractor: BaseFeaturesExtractor + def __init__( self, observation_space: spaces.Space, @@ -944,10 +959,10 @@ class ContinuousCritic(BaseModel): self.share_features_extractor = share_features_extractor self.n_critics = n_critics - self.q_networks = [] + self.q_networks: List[nn.Module] = [] for idx in range(n_critics): - q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) - q_net = nn.Sequential(*q_net) + q_net_list = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) + q_net = nn.Sequential(*q_net_list) self.add_module(f"qf{idx}", q_net) self.q_networks.append(q_net) diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 2b5251c..d0bfbcd 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -90,7 +90,7 @@ def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> def preprocess_obs( - obs: th.Tensor, + obs: Union[th.Tensor, Dict[str, th.Tensor]], observation_space: spaces.Space, normalize_images: bool = True, ) -> Union[th.Tensor, Dict[str, th.Tensor]]: @@ -105,6 +105,16 @@ def preprocess_obs( (True by default) :return: """ + if isinstance(observation_space, spaces.Dict): + # Do not modify by reference the original observation + assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}" + preprocessed_obs = {} + for key, _obs in obs.items(): + preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images) + return preprocessed_obs # type: ignore[return-value] + + assert isinstance(obs, th.Tensor), f"Expecting a torch Tensor, but got {type(obs)}" + if isinstance(observation_space, spaces.Box): if normalize_images and is_image_space(observation_space): return obs.float() / 255.0 @@ -126,15 +136,6 @@ def preprocess_obs( elif isinstance(observation_space, spaces.MultiBinary): return obs.float() - - elif isinstance(observation_space, spaces.Dict): - # Do not modify by reference the original observation - assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}" - preprocessed_obs = {} - for key, _obs in obs.items(): - preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images) - return preprocessed_obs - else: raise NotImplementedError(f"Preprocessing not implemented for {observation_space}") diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 4a0a878..d75e115 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -20,6 +20,7 @@ AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]] TensorDict = Dict[str, th.Tensor] OptimizerStateDict = Dict[str, Any] MaybeCallback = Union[None, Callable, List["BaseCallback"], "BaseCallback"] +PyTorchObs = Union[th.Tensor, TensorDict] # A schedule takes the remaining progress as input # and ouputs a scalar (e.g. learning rate, clip range, ...) diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index fcdb958..9d2cf94 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -12,7 +12,7 @@ from stable_baselines3.common.torch_layers import ( NatureCNN, create_mlp, ) -from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.type_aliases import PyTorchObs, Schedule class QNetwork(BasePolicy): @@ -56,7 +56,7 @@ class QNetwork(BasePolicy): q_net = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn) self.q_net = nn.Sequential(*q_net) - def forward(self, obs: th.Tensor) -> th.Tensor: + def forward(self, obs: PyTorchObs) -> th.Tensor: """ Predict the q-values. @@ -65,7 +65,7 @@ class QNetwork(BasePolicy): """ return self.q_net(self.extract_features(obs, self.features_extractor)) - def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor: + def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor: q_values = self(observation) # Greedy action action = q_values.argmax(dim=1).reshape(-1) @@ -177,10 +177,10 @@ class DQNPolicy(BasePolicy): net_args = self._update_features_extractor(self.net_args, features_extractor=None) return QNetwork(**net_args).to(self.device) - def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor: + def forward(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor: return self._predict(obs, deterministic=deterministic) - def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor: + def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor: return self.q_net._predict(obs, deterministic=deterministic) def _get_constructor_parameters(self) -> Dict[str, Any]: diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 8902629..97d0ad9 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -15,7 +15,7 @@ from stable_baselines3.common.torch_layers import ( create_mlp, get_actor_critic_arch, ) -from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.type_aliases import PyTorchObs, Schedule # CAP the standard deviation of the actor LOG_STD_MAX = 2 @@ -144,7 +144,7 @@ class Actor(BasePolicy): assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg self.action_dist.sample_weights(self.log_std, batch_size=batch_size) - def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]: + def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]: """ Get the parameters for the action distribution. @@ -164,17 +164,17 @@ class Actor(BasePolicy): log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) return mean_actions, log_std, {} - def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: + def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor: mean_actions, log_std, kwargs = self.get_action_dist_params(obs) # Note: the action is squashed return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs) - def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]: mean_actions, log_std, kwargs = self.get_action_dist_params(obs) # return action and associated log prob return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs) - def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor: return self(observation, deterministic) @@ -346,10 +346,10 @@ class SACPolicy(BasePolicy): critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor) return ContinuousCritic(**critic_kwargs).to(self.device) - def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: + def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor: return self._predict(obs, deterministic=deterministic) - def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor: return self.actor(observation, deterministic) def set_training_mode(self, mode: bool) -> None: diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index 12117df..dda6cb3 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -14,7 +14,7 @@ from stable_baselines3.common.torch_layers import ( create_mlp, get_actor_critic_arch, ) -from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.type_aliases import PyTorchObs, Schedule class Actor(BasePolicy): @@ -77,7 +77,7 @@ class Actor(BasePolicy): features = self.extract_features(obs, self.features_extractor) return self.mu(features) - def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor: # Note: the deterministic deterministic parameter is ignored in the case of TD3. # Predictions are always deterministic. return self(observation) @@ -233,10 +233,10 @@ class TD3Policy(BasePolicy): critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor) return ContinuousCritic(**critic_kwargs).to(self.device) - def forward(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + def forward(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor: return self._predict(observation, deterministic=deterministic) - def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor: # Note: the deterministic deterministic parameter is ignored in the case of TD3. # Predictions are always deterministic. return self.actor(observation) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index b208680..13ce6d7 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.2.0a10 +2.2.0a11