diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3ef78d5..99fa5b4 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -27,6 +27,7 @@ Bug Fixes: - Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm`` - Allowed model trained with Python 3.7 to be loaded with Python 3.8+ without the ``custom_objects`` workaround - Fix type annotation of ``model`` in ``evaluate_policy`` +- Fixed ``Self`` return type using ``TypeVar`` Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index 24d69a6..972e700 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -9,7 +9,7 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance -A2CSelf = TypeVar("A2CSelf", bound="A2C") +SelfA2C = TypeVar("SelfA2C", bound="A2C") class A2C(OnPolicyAlgorithm): @@ -181,14 +181,14 @@ class A2C(OnPolicyAlgorithm): self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) def learn( - self: A2CSelf, + self: SelfA2C, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 100, tb_log_name: str = "A2C", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> A2CSelf: + ) -> SelfA2C: return super().learn( total_timesteps=total_timesteps, diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index ebfb7de..bb14f6a 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -38,6 +38,8 @@ from stable_baselines3.common.vec_env import ( unwrap_vec_normalize, ) +SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm") + def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymEnv]: """If env is a string, make the environment; otherwise, return env. @@ -53,9 +55,6 @@ def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymE return env -BaseAlgorithmSelf = TypeVar("BaseAlgorithmSelf", bound="BaseAlgorithm") - - class BaseAlgorithm(ABC): """ The base of RL algorithms @@ -491,14 +490,14 @@ class BaseAlgorithm(ABC): @abstractmethod def learn( - self: BaseAlgorithmSelf, + self: SelfBaseAlgorithm, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 100, tb_log_name: str = "run", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> BaseAlgorithmSelf: + ) -> SelfBaseAlgorithm: """ Return a trained model. @@ -617,7 +616,7 @@ class BaseAlgorithm(ABC): @classmethod def load( - cls: Type[BaseAlgorithmSelf], + cls: Type[SelfBaseAlgorithm], path: Union[str, pathlib.Path, io.BufferedIOBase], env: Optional[GymEnv] = None, device: Union[th.device, str] = "auto", @@ -625,7 +624,7 @@ class BaseAlgorithm(ABC): print_system_info: bool = False, force_reset: bool = True, **kwargs, - ) -> BaseAlgorithmSelf: + ) -> SelfBaseAlgorithm: """ Load the model from a zip-file. Warning: ``load`` re-creates the model from scratch, it does not update it in-place! diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index fc48625..63eb475 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -1,7 +1,7 @@ """Probability distributions.""" from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union import gym import numpy as np @@ -12,6 +12,16 @@ from torch.distributions import Bernoulli, Categorical, Normal from stable_baselines3.common.preprocessing import get_action_dim +SelfDistribution = TypeVar("SelfDistribution", bound="Distribution") +SelfDiagGaussianDistribution = TypeVar("SelfDiagGaussianDistribution", bound="DiagGaussianDistribution") +SelfSquashedDiagGaussianDistribution = TypeVar( + "SelfSquashedDiagGaussianDistribution", bound="SquashedDiagGaussianDistribution" +) +SelfCategoricalDistribution = TypeVar("SelfCategoricalDistribution", bound="CategoricalDistribution") +SelfMultiCategoricalDistribution = TypeVar("SelfMultiCategoricalDistribution", bound="MultiCategoricalDistribution") +SelfBernoulliDistribution = TypeVar("SelfBernoulliDistribution", bound="BernoulliDistribution") +SelfStateDependentNoiseDistribution = TypeVar("SelfStateDependentNoiseDistribution", bound="StateDependentNoiseDistribution") + class Distribution(ABC): """Abstract base class for distributions.""" @@ -28,7 +38,7 @@ class Distribution(ABC): concrete classes.""" @abstractmethod - def proba_distribution(self, *args, **kwargs) -> "Distribution": + def proba_distribution(self: SelfDistribution, *args, **kwargs) -> SelfDistribution: """Set parameters of the distribution. :return: self @@ -141,7 +151,9 @@ class DiagGaussianDistribution(Distribution): log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True) return mean_actions, log_std - def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution": + def proba_distribution( + self: SelfDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor + ) -> SelfDiagGaussianDistribution: """ Create the distribution given its parameters (mean, std) @@ -207,7 +219,9 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): self.epsilon = epsilon self.gaussian_actions = None - def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "SquashedDiagGaussianDistribution": + def proba_distribution( + self: SelfSquashedDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor + ) -> SelfSquashedDiagGaussianDistribution: super().proba_distribution(mean_actions, log_std) return self @@ -271,7 +285,7 @@ class CategoricalDistribution(Distribution): action_logits = nn.Linear(latent_dim, self.action_dim) return action_logits - def proba_distribution(self, action_logits: th.Tensor) -> "CategoricalDistribution": + def proba_distribution(self: SelfCategoricalDistribution, action_logits: th.Tensor) -> SelfCategoricalDistribution: self.distribution = Categorical(logits=action_logits) return self @@ -323,7 +337,9 @@ class MultiCategoricalDistribution(Distribution): action_logits = nn.Linear(latent_dim, sum(self.action_dims)) return action_logits - def proba_distribution(self, action_logits: th.Tensor) -> "MultiCategoricalDistribution": + def proba_distribution( + self: SelfMultiCategoricalDistribution, action_logits: th.Tensor + ) -> SelfMultiCategoricalDistribution: self.distribution = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)] return self @@ -376,7 +392,7 @@ class BernoulliDistribution(Distribution): action_logits = nn.Linear(latent_dim, self.action_dims) return action_logits - def proba_distribution(self, action_logits: th.Tensor) -> "BernoulliDistribution": + def proba_distribution(self: SelfBernoulliDistribution, action_logits: th.Tensor) -> SelfBernoulliDistribution: self.distribution = Bernoulli(logits=action_logits) return self @@ -520,8 +536,8 @@ class StateDependentNoiseDistribution(Distribution): return mean_actions_net, log_std def proba_distribution( - self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor - ) -> "StateDependentNoiseDistribution": + self: SelfStateDependentNoiseDistribution, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor + ) -> SelfStateDependentNoiseDistribution: """ Create the distribution given its parameters (mean, std) diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 9c2bda2..634d9e9 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -21,7 +21,7 @@ from stable_baselines3.common.utils import safe_mean, should_collect_more_steps from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.her.her_replay_buffer import HerReplayBuffer -OffPolicyAlgorithmSelf = TypeVar("OffPolicyAlgorithmSelf", bound="OffPolicyAlgorithm") +SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="OffPolicyAlgorithm") class OffPolicyAlgorithm(BaseAlgorithm): @@ -311,14 +311,14 @@ class OffPolicyAlgorithm(BaseAlgorithm): ) def learn( - self: OffPolicyAlgorithmSelf, + self: SelfOffPolicyAlgorithm, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, tb_log_name: str = "run", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> OffPolicyAlgorithmSelf: + ) -> SelfOffPolicyAlgorithm: total_timesteps, callback = self._setup_learn( total_timesteps, diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index a0018b3..35ad2b9 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -14,7 +14,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul from stable_baselines3.common.utils import obs_as_tensor, safe_mean from stable_baselines3.common.vec_env import VecEnv -OnPolicyAlgorithmSelf = TypeVar("OnPolicyAlgorithmSelf", bound="OnPolicyAlgorithm") +SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="OnPolicyAlgorithm") class OnPolicyAlgorithm(BaseAlgorithm): @@ -223,14 +223,14 @@ class OnPolicyAlgorithm(BaseAlgorithm): raise NotImplementedError def learn( - self: OnPolicyAlgorithmSelf, + self: SelfOnPolicyAlgorithm, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, tb_log_name: str = "OnPolicyAlgorithm", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> OnPolicyAlgorithmSelf: + ) -> SelfOnPolicyAlgorithm: iteration = 0 total_timesteps, callback = self._setup_learn( diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 5972fa3..c084872 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -32,7 +32,7 @@ from stable_baselines3.common.torch_layers import ( from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor -BaseModelSelf = TypeVar("BaseModelSelf", bound="BaseModel") +SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel") class BaseModel(nn.Module): @@ -159,7 +159,7 @@ class BaseModel(nn.Module): th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path) @classmethod - def load(cls: Type[BaseModelSelf], path: str, device: Union[th.device, str] = "auto") -> BaseModelSelf: + def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel: """ Load model from path. diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index 993a8c2..40d67b5 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -8,7 +8,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul from stable_baselines3.td3.policies import TD3Policy from stable_baselines3.td3.td3 import TD3 -DDPGSelf = TypeVar("DDPGSelf", bound="DDPG") +SelfDDPG = TypeVar("SelfDDPG", bound="DDPG") class DDPG(TD3): @@ -113,14 +113,14 @@ class DDPG(TD3): self._setup_model() def learn( - self: DDPGSelf, + self: SelfDDPG, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, tb_log_name: str = "DDPG", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> DDPGSelf: + ) -> SelfDDPG: return super().learn( total_timesteps=total_timesteps, diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 9e074a9..e8f4945 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -14,7 +14,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy -DQNSelf = TypeVar("DQNSelf", bound="DQN") +SelfDQN = TypeVar("SelfDQN", bound="DQN") class DQN(OffPolicyAlgorithm): @@ -253,14 +253,14 @@ class DQN(OffPolicyAlgorithm): return action, state def learn( - self: DQNSelf, + self: SelfDQN, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, tb_log_name: str = "DQN", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> DQNSelf: + ) -> SelfDQN: return super().learn( total_timesteps=total_timesteps, diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 5d30569..bd80736 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -11,7 +11,7 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn -PPOSelf = TypeVar("PPOSelf", bound="PPO") +SelfPPO = TypeVar("SelfPPO", bound="PPO") class PPO(OnPolicyAlgorithm): @@ -295,14 +295,14 @@ class PPO(OnPolicyAlgorithm): self.logger.record("train/clip_range_vf", clip_range_vf) def learn( - self: PPOSelf, + self: SelfPPO, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, tb_log_name: str = "PPO", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> PPOSelf: + ) -> SelfPPO: return super().learn( total_timesteps=total_timesteps, diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 85bdf78..1eafec0 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -13,7 +13,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy -SACSelf = TypeVar("SACSelf", bound="SAC") +SelfSAC = TypeVar("SelfSAC", bound="SAC") class SAC(OffPolicyAlgorithm): @@ -287,14 +287,14 @@ class SAC(OffPolicyAlgorithm): self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) def learn( - self: SACSelf, + self: SelfSAC, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, tb_log_name: str = "SAC", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> SACSelf: + ) -> SelfSAC: return super().learn( total_timesteps=total_timesteps, diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index c611481..a8bc3ef 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -13,7 +13,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy -TD3Self = TypeVar("TD3Self", bound="TD3") +SelfTD3 = TypeVar("SelfTD3", bound="TD3") class TD3(OffPolicyAlgorithm): @@ -203,14 +203,14 @@ class TD3(OffPolicyAlgorithm): self.logger.record("train/critic_loss", np.mean(critic_losses)) def learn( - self: TD3Self, + self: SelfTD3, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, tb_log_name: str = "TD3", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> TD3Self: + ) -> SelfTD3: return super().learn( total_timesteps=total_timesteps,