From be4e1caa75a96a8f894bfceaf832e1106cf14c61 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 7 Dec 2023 10:45:00 +0100 Subject: [PATCH] Switch to Self type --- stable_baselines3/a2c/a2c.py | 8 ++--- stable_baselines3/common/base_class.py | 12 +++---- stable_baselines3/common/distributions.py | 34 +++++-------------- .../common/off_policy_algorithm.py | 8 ++--- .../common/on_policy_algorithm.py | 8 ++--- stable_baselines3/common/policies.py | 6 ++-- stable_baselines3/ddpg/ddpg.py | 8 ++--- stable_baselines3/dqn/dqn.py | 8 ++--- stable_baselines3/ppo/ppo.py | 8 ++--- stable_baselines3/sac/sac.py | 8 ++--- stable_baselines3/td3/td3.py | 8 ++--- 11 files changed, 39 insertions(+), 77 deletions(-) diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index 718571f..71bf9b9 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, Optional, Self, Type, Union import torch as th from gymnasium import spaces @@ -10,8 +10,6 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance -SelfA2C = TypeVar("SelfA2C", bound="A2C") - class A2C(OnPolicyAlgorithm): """ @@ -190,14 +188,14 @@ class A2C(OnPolicyAlgorithm): self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) def learn( - self: SelfA2C, + self, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 100, tb_log_name: str = "A2C", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> SelfA2C: + ) -> Self: return super().learn( total_timesteps=total_timesteps, callback=callback, diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 5e87599..9abea24 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -6,7 +6,7 @@ import time import warnings from abc import ABC, abstractmethod from collections import deque -from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, Iterable, List, Optional, Self, Tuple, Type, Union import gymnasium as gym import numpy as np @@ -41,8 +41,6 @@ from stable_baselines3.common.vec_env import ( ) from stable_baselines3.common.vec_env.patch_gym import _convert_space, _patch_env -SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm") - def maybe_make_env(env: Union[GymEnv, str], verbose: int) -> GymEnv: """If env is a string, make the environment; otherwise, return env. @@ -510,14 +508,14 @@ class BaseAlgorithm(ABC): @abstractmethod def learn( - self: SelfBaseAlgorithm, + self, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 100, tb_log_name: str = "run", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> SelfBaseAlgorithm: + ) -> Self: """ Return a trained model. @@ -637,7 +635,7 @@ class BaseAlgorithm(ABC): @classmethod def load( # noqa: C901 - cls: Type[SelfBaseAlgorithm], + cls, path: Union[str, pathlib.Path, io.BufferedIOBase], env: Optional[GymEnv] = None, device: Union[th.device, str] = "auto", @@ -645,7 +643,7 @@ class BaseAlgorithm(ABC): print_system_info: bool = False, force_reset: bool = True, **kwargs, - ) -> SelfBaseAlgorithm: + ) -> Self: """ Load the model from a zip-file. Warning: ``load`` re-creates the model from scratch, it does not update it in-place! diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 149345d..229eda2 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -1,7 +1,7 @@ """Probability distributions.""" from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union +from typing import Any, Dict, List, Optional, Self, Tuple, Union import numpy as np import torch as th @@ -11,16 +11,6 @@ from torch.distributions import Bernoulli, Categorical, Normal from stable_baselines3.common.preprocessing import get_action_dim -SelfDistribution = TypeVar("SelfDistribution", bound="Distribution") -SelfDiagGaussianDistribution = TypeVar("SelfDiagGaussianDistribution", bound="DiagGaussianDistribution") -SelfSquashedDiagGaussianDistribution = TypeVar( - "SelfSquashedDiagGaussianDistribution", bound="SquashedDiagGaussianDistribution" -) -SelfCategoricalDistribution = TypeVar("SelfCategoricalDistribution", bound="CategoricalDistribution") -SelfMultiCategoricalDistribution = TypeVar("SelfMultiCategoricalDistribution", bound="MultiCategoricalDistribution") -SelfBernoulliDistribution = TypeVar("SelfBernoulliDistribution", bound="BernoulliDistribution") -SelfStateDependentNoiseDistribution = TypeVar("SelfStateDependentNoiseDistribution", bound="StateDependentNoiseDistribution") - class Distribution(ABC): """Abstract base class for distributions.""" @@ -37,7 +27,7 @@ class Distribution(ABC): concrete classes.""" @abstractmethod - def proba_distribution(self: SelfDistribution, *args, **kwargs) -> SelfDistribution: + def proba_distribution(self, *args, **kwargs) -> Self: """Set parameters of the distribution. :return: self @@ -150,9 +140,7 @@ class DiagGaussianDistribution(Distribution): log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True) return mean_actions, log_std - def proba_distribution( - self: SelfDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor - ) -> SelfDiagGaussianDistribution: + def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Self: """ Create the distribution given its parameters (mean, std) @@ -218,9 +206,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): self.epsilon = epsilon self.gaussian_actions: Optional[th.Tensor] = None - def proba_distribution( - self: SelfSquashedDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor - ) -> SelfSquashedDiagGaussianDistribution: + def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Self: super().proba_distribution(mean_actions, log_std) return self @@ -284,7 +270,7 @@ class CategoricalDistribution(Distribution): action_logits = nn.Linear(latent_dim, self.action_dim) return action_logits - def proba_distribution(self: SelfCategoricalDistribution, action_logits: th.Tensor) -> SelfCategoricalDistribution: + def proba_distribution(self, action_logits: th.Tensor) -> Self: self.distribution = Categorical(logits=action_logits) return self @@ -336,9 +322,7 @@ class MultiCategoricalDistribution(Distribution): action_logits = nn.Linear(latent_dim, sum(self.action_dims)) return action_logits - def proba_distribution( - self: SelfMultiCategoricalDistribution, action_logits: th.Tensor - ) -> SelfMultiCategoricalDistribution: + def proba_distribution(self, action_logits: th.Tensor) -> Self: self.distribution = [Categorical(logits=split) for split in th.split(action_logits, list(self.action_dims), dim=1)] return self @@ -391,7 +375,7 @@ class BernoulliDistribution(Distribution): action_logits = nn.Linear(latent_dim, self.action_dims) return action_logits - def proba_distribution(self: SelfBernoulliDistribution, action_logits: th.Tensor) -> SelfBernoulliDistribution: + def proba_distribution(self, action_logits: th.Tensor) -> Self: self.distribution = Bernoulli(logits=action_logits) return self @@ -538,9 +522,7 @@ class StateDependentNoiseDistribution(Distribution): self.sample_weights(log_std) return mean_actions_net, log_std - def proba_distribution( - self: SelfStateDependentNoiseDistribution, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor - ) -> SelfStateDependentNoiseDistribution: + def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor) -> Self: """ Create the distribution given its parameters (mean, std) diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index c460d02..8eba91d 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -4,7 +4,7 @@ import sys import time import warnings from copy import deepcopy -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Dict, List, Optional, Self, Tuple, Type, Union import numpy as np import torch as th @@ -21,8 +21,6 @@ from stable_baselines3.common.utils import safe_mean, should_collect_more_steps from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.her.her_replay_buffer import HerReplayBuffer -SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="OffPolicyAlgorithm") - class OffPolicyAlgorithm(BaseAlgorithm): """ @@ -303,14 +301,14 @@ class OffPolicyAlgorithm(BaseAlgorithm): ) def learn( - self: SelfOffPolicyAlgorithm, + self, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, tb_log_name: str = "run", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> SelfOffPolicyAlgorithm: + ) -> Self: total_timesteps, callback = self._setup_learn( total_timesteps, callback, diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index ddd0f8d..f65cb78 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -1,6 +1,6 @@ import sys import time -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Dict, List, Optional, Self, Tuple, Type, Union import numpy as np import torch as th @@ -14,8 +14,6 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul from stable_baselines3.common.utils import obs_as_tensor, safe_mean from stable_baselines3.common.vec_env import VecEnv -SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="OnPolicyAlgorithm") - class OnPolicyAlgorithm(BaseAlgorithm): """ @@ -251,14 +249,14 @@ class OnPolicyAlgorithm(BaseAlgorithm): raise NotImplementedError def learn( - self: SelfOnPolicyAlgorithm, + self, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, tb_log_name: str = "OnPolicyAlgorithm", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> SelfOnPolicyAlgorithm: + ) -> Self: iteration = 0 total_timesteps, callback = self._setup_learn( diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 50be01c..acddb46 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -5,7 +5,7 @@ import copy import warnings from abc import ABC, abstractmethod from functools import partial -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Dict, List, Optional, Self, Tuple, Type, Union import numpy as np import torch as th @@ -33,8 +33,6 @@ from stable_baselines3.common.torch_layers import ( from stable_baselines3.common.type_aliases import PyTorchObs, Schedule from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor -SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel") - class BaseModel(nn.Module): """ @@ -164,7 +162,7 @@ class BaseModel(nn.Module): th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path) @classmethod - def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel: + def load(cls, path: str, device: Union[th.device, str] = "auto") -> Self: """ Load model from path. diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index c311b23..5a4091d 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Optional, Self, Tuple, Type, Union import torch as th @@ -8,8 +8,6 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul from stable_baselines3.td3.policies import TD3Policy from stable_baselines3.td3.td3 import TD3 -SelfDDPG = TypeVar("SelfDDPG", bound="DDPG") - class DDPG(TD3): """ @@ -112,14 +110,14 @@ class DDPG(TD3): self._setup_model() def learn( - self: SelfDDPG, + self, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, tb_log_name: str = "DDPG", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> SelfDDPG: + ) -> Self: return super().learn( total_timesteps=total_timesteps, callback=callback, diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 42e3d0d..9b952af 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, List, Optional, Self, Tuple, Type, Union import numpy as np import torch as th @@ -13,8 +13,6 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork -SelfDQN = TypeVar("SelfDQN", bound="DQN") - class DQN(OffPolicyAlgorithm): """ @@ -256,14 +254,14 @@ class DQN(OffPolicyAlgorithm): return action, state def learn( - self: SelfDQN, + self, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, tb_log_name: str = "DQN", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> SelfDQN: + ) -> Self: return super().learn( total_timesteps=total_timesteps, callback=callback, diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index ea7cf5e..2dffd95 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, Optional, Self, Type, Union import numpy as np import torch as th @@ -12,8 +12,6 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn -SelfPPO = TypeVar("SelfPPO", bound="PPO") - class PPO(OnPolicyAlgorithm): """ @@ -304,14 +302,14 @@ class PPO(OnPolicyAlgorithm): self.logger.record("train/clip_range_vf", clip_range_vf) def learn( - self: SelfPPO, + self, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, tb_log_name: str = "PPO", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> SelfPPO: + ) -> Self: return super().learn( total_timesteps=total_timesteps, callback=callback, diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index bf0fa50..854542b 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, List, Optional, Self, Tuple, Type, Union import numpy as np import torch as th @@ -13,8 +13,6 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from stable_baselines3.sac.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy -SelfSAC = TypeVar("SelfSAC", bound="SAC") - class SAC(OffPolicyAlgorithm): """ @@ -296,14 +294,14 @@ class SAC(OffPolicyAlgorithm): self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) def learn( - self: SelfSAC, + self, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, tb_log_name: str = "SAC", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> SelfSAC: + ) -> Self: return super().learn( total_timesteps=total_timesteps, callback=callback, diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index a06ce67..346c9ed 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, List, Optional, Self, Tuple, Type, Union import numpy as np import torch as th @@ -13,8 +13,6 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from stable_baselines3.td3.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy -SelfTD3 = TypeVar("SelfTD3", bound="TD3") - class TD3(OffPolicyAlgorithm): """ @@ -211,14 +209,14 @@ class TD3(OffPolicyAlgorithm): self.logger.record("train/critic_loss", np.mean(critic_losses)) def learn( - self: SelfTD3, + self, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, tb_log_name: str = "TD3", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> SelfTD3: + ) -> Self: return super().learn( total_timesteps=total_timesteps, callback=callback,