Fix Self return type (#1167)

* Fix Self annotation

* Update changelog

* Define type var on top

* ClassSelf to SelfClass

* annotate self

* Revert Running meanstd change

* Revert vecnormalize change (static method rejected)

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Quentin Gallouédec 2022-11-22 13:42:39 +01:00 committed by GitHub
parent d829a1bb04
commit f3abda5cbc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 58 additions and 42 deletions

View file

@ -27,6 +27,7 @@ Bug Fixes:
- Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm``
- Allowed model trained with Python 3.7 to be loaded with Python 3.8+ without the ``custom_objects`` workaround
- Fix type annotation of ``model`` in ``evaluate_policy``
- Fixed ``Self`` return type using ``TypeVar``
Deprecations:
^^^^^^^^^^^^^

View file

@ -9,7 +9,7 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance
A2CSelf = TypeVar("A2CSelf", bound="A2C")
SelfA2C = TypeVar("SelfA2C", bound="A2C")
class A2C(OnPolicyAlgorithm):
@ -181,14 +181,14 @@ class A2C(OnPolicyAlgorithm):
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
def learn(
self: A2CSelf,
self: SelfA2C,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 100,
tb_log_name: str = "A2C",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> A2CSelf:
) -> SelfA2C:
return super().learn(
total_timesteps=total_timesteps,

View file

@ -38,6 +38,8 @@ from stable_baselines3.common.vec_env import (
unwrap_vec_normalize,
)
SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm")
def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymEnv]:
"""If env is a string, make the environment; otherwise, return env.
@ -53,9 +55,6 @@ def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymE
return env
BaseAlgorithmSelf = TypeVar("BaseAlgorithmSelf", bound="BaseAlgorithm")
class BaseAlgorithm(ABC):
"""
The base of RL algorithms
@ -491,14 +490,14 @@ class BaseAlgorithm(ABC):
@abstractmethod
def learn(
self: BaseAlgorithmSelf,
self: SelfBaseAlgorithm,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 100,
tb_log_name: str = "run",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> BaseAlgorithmSelf:
) -> SelfBaseAlgorithm:
"""
Return a trained model.
@ -617,7 +616,7 @@ class BaseAlgorithm(ABC):
@classmethod
def load(
cls: Type[BaseAlgorithmSelf],
cls: Type[SelfBaseAlgorithm],
path: Union[str, pathlib.Path, io.BufferedIOBase],
env: Optional[GymEnv] = None,
device: Union[th.device, str] = "auto",
@ -625,7 +624,7 @@ class BaseAlgorithm(ABC):
print_system_info: bool = False,
force_reset: bool = True,
**kwargs,
) -> BaseAlgorithmSelf:
) -> SelfBaseAlgorithm:
"""
Load the model from a zip-file.
Warning: ``load`` re-creates the model from scratch, it does not update it in-place!

View file

@ -1,7 +1,7 @@
"""Probability distributions."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
import gym
import numpy as np
@ -12,6 +12,16 @@ from torch.distributions import Bernoulli, Categorical, Normal
from stable_baselines3.common.preprocessing import get_action_dim
SelfDistribution = TypeVar("SelfDistribution", bound="Distribution")
SelfDiagGaussianDistribution = TypeVar("SelfDiagGaussianDistribution", bound="DiagGaussianDistribution")
SelfSquashedDiagGaussianDistribution = TypeVar(
"SelfSquashedDiagGaussianDistribution", bound="SquashedDiagGaussianDistribution"
)
SelfCategoricalDistribution = TypeVar("SelfCategoricalDistribution", bound="CategoricalDistribution")
SelfMultiCategoricalDistribution = TypeVar("SelfMultiCategoricalDistribution", bound="MultiCategoricalDistribution")
SelfBernoulliDistribution = TypeVar("SelfBernoulliDistribution", bound="BernoulliDistribution")
SelfStateDependentNoiseDistribution = TypeVar("SelfStateDependentNoiseDistribution", bound="StateDependentNoiseDistribution")
class Distribution(ABC):
"""Abstract base class for distributions."""
@ -28,7 +38,7 @@ class Distribution(ABC):
concrete classes."""
@abstractmethod
def proba_distribution(self, *args, **kwargs) -> "Distribution":
def proba_distribution(self: SelfDistribution, *args, **kwargs) -> SelfDistribution:
"""Set parameters of the distribution.
:return: self
@ -141,7 +151,9 @@ class DiagGaussianDistribution(Distribution):
log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True)
return mean_actions, log_std
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution":
def proba_distribution(
self: SelfDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor
) -> SelfDiagGaussianDistribution:
"""
Create the distribution given its parameters (mean, std)
@ -207,7 +219,9 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
self.epsilon = epsilon
self.gaussian_actions = None
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "SquashedDiagGaussianDistribution":
def proba_distribution(
self: SelfSquashedDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor
) -> SelfSquashedDiagGaussianDistribution:
super().proba_distribution(mean_actions, log_std)
return self
@ -271,7 +285,7 @@ class CategoricalDistribution(Distribution):
action_logits = nn.Linear(latent_dim, self.action_dim)
return action_logits
def proba_distribution(self, action_logits: th.Tensor) -> "CategoricalDistribution":
def proba_distribution(self: SelfCategoricalDistribution, action_logits: th.Tensor) -> SelfCategoricalDistribution:
self.distribution = Categorical(logits=action_logits)
return self
@ -323,7 +337,9 @@ class MultiCategoricalDistribution(Distribution):
action_logits = nn.Linear(latent_dim, sum(self.action_dims))
return action_logits
def proba_distribution(self, action_logits: th.Tensor) -> "MultiCategoricalDistribution":
def proba_distribution(
self: SelfMultiCategoricalDistribution, action_logits: th.Tensor
) -> SelfMultiCategoricalDistribution:
self.distribution = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
return self
@ -376,7 +392,7 @@ class BernoulliDistribution(Distribution):
action_logits = nn.Linear(latent_dim, self.action_dims)
return action_logits
def proba_distribution(self, action_logits: th.Tensor) -> "BernoulliDistribution":
def proba_distribution(self: SelfBernoulliDistribution, action_logits: th.Tensor) -> SelfBernoulliDistribution:
self.distribution = Bernoulli(logits=action_logits)
return self
@ -520,8 +536,8 @@ class StateDependentNoiseDistribution(Distribution):
return mean_actions_net, log_std
def proba_distribution(
self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
) -> "StateDependentNoiseDistribution":
self: SelfStateDependentNoiseDistribution, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
) -> SelfStateDependentNoiseDistribution:
"""
Create the distribution given its parameters (mean, std)

View file

@ -21,7 +21,7 @@ from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
OffPolicyAlgorithmSelf = TypeVar("OffPolicyAlgorithmSelf", bound="OffPolicyAlgorithm")
SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="OffPolicyAlgorithm")
class OffPolicyAlgorithm(BaseAlgorithm):
@ -311,14 +311,14 @@ class OffPolicyAlgorithm(BaseAlgorithm):
)
def learn(
self: OffPolicyAlgorithmSelf,
self: SelfOffPolicyAlgorithm,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
tb_log_name: str = "run",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> OffPolicyAlgorithmSelf:
) -> SelfOffPolicyAlgorithm:
total_timesteps, callback = self._setup_learn(
total_timesteps,

View file

@ -14,7 +14,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
OnPolicyAlgorithmSelf = TypeVar("OnPolicyAlgorithmSelf", bound="OnPolicyAlgorithm")
SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="OnPolicyAlgorithm")
class OnPolicyAlgorithm(BaseAlgorithm):
@ -223,14 +223,14 @@ class OnPolicyAlgorithm(BaseAlgorithm):
raise NotImplementedError
def learn(
self: OnPolicyAlgorithmSelf,
self: SelfOnPolicyAlgorithm,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
tb_log_name: str = "OnPolicyAlgorithm",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> OnPolicyAlgorithmSelf:
) -> SelfOnPolicyAlgorithm:
iteration = 0
total_timesteps, callback = self._setup_learn(

View file

@ -32,7 +32,7 @@ from stable_baselines3.common.torch_layers import (
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
BaseModelSelf = TypeVar("BaseModelSelf", bound="BaseModel")
SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
class BaseModel(nn.Module):
@ -159,7 +159,7 @@ class BaseModel(nn.Module):
th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)
@classmethod
def load(cls: Type[BaseModelSelf], path: str, device: Union[th.device, str] = "auto") -> BaseModelSelf:
def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel:
"""
Load model from path.

View file

@ -8,7 +8,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
from stable_baselines3.td3.policies import TD3Policy
from stable_baselines3.td3.td3 import TD3
DDPGSelf = TypeVar("DDPGSelf", bound="DDPG")
SelfDDPG = TypeVar("SelfDDPG", bound="DDPG")
class DDPG(TD3):
@ -113,14 +113,14 @@ class DDPG(TD3):
self._setup_model()
def learn(
self: DDPGSelf,
self: SelfDDPG,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
tb_log_name: str = "DDPG",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> DDPGSelf:
) -> SelfDDPG:
return super().learn(
total_timesteps=total_timesteps,

View file

@ -14,7 +14,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy
DQNSelf = TypeVar("DQNSelf", bound="DQN")
SelfDQN = TypeVar("SelfDQN", bound="DQN")
class DQN(OffPolicyAlgorithm):
@ -253,14 +253,14 @@ class DQN(OffPolicyAlgorithm):
return action, state
def learn(
self: DQNSelf,
self: SelfDQN,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
tb_log_name: str = "DQN",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> DQNSelf:
) -> SelfDQN:
return super().learn(
total_timesteps=total_timesteps,

View file

@ -11,7 +11,7 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
PPOSelf = TypeVar("PPOSelf", bound="PPO")
SelfPPO = TypeVar("SelfPPO", bound="PPO")
class PPO(OnPolicyAlgorithm):
@ -295,14 +295,14 @@ class PPO(OnPolicyAlgorithm):
self.logger.record("train/clip_range_vf", clip_range_vf)
def learn(
self: PPOSelf,
self: SelfPPO,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
tb_log_name: str = "PPO",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> PPOSelf:
) -> SelfPPO:
return super().learn(
total_timesteps=total_timesteps,

View file

@ -13,7 +13,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
SACSelf = TypeVar("SACSelf", bound="SAC")
SelfSAC = TypeVar("SelfSAC", bound="SAC")
class SAC(OffPolicyAlgorithm):
@ -287,14 +287,14 @@ class SAC(OffPolicyAlgorithm):
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
def learn(
self: SACSelf,
self: SelfSAC,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
tb_log_name: str = "SAC",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SACSelf:
) -> SelfSAC:
return super().learn(
total_timesteps=total_timesteps,

View file

@ -13,7 +13,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy
TD3Self = TypeVar("TD3Self", bound="TD3")
SelfTD3 = TypeVar("SelfTD3", bound="TD3")
class TD3(OffPolicyAlgorithm):
@ -203,14 +203,14 @@ class TD3(OffPolicyAlgorithm):
self.logger.record("train/critic_loss", np.mean(critic_losses))
def learn(
self: TD3Self,
self: SelfTD3,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
tb_log_name: str = "TD3",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> TD3Self:
) -> SelfTD3:
return super().learn(
total_timesteps=total_timesteps,