Mirror of https://github.com/saymrwulf/stable-baselines3.git (synced 2026-06-15 01:23:46 +00:00)
Fix Self return type (#1167)
* Fix Self annotation
* Update changelog
* Define type var on top
* ClassSelf to SelfClass
* Annotate self
* Revert Running meanstd change
* Revert vecnormalize change (static method rejected)

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent d829a1bb04
commit f3abda5cbc
12 changed files with 58 additions and 42 deletions
|
|
@ -27,6 +27,7 @@ Bug Fixes:
|
|||
- Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm``
|
||||
- Allowed model trained with Python 3.7 to be loaded with Python 3.8+ without the ``custom_objects`` workaround
|
||||
- Fix type annotation of ``model`` in ``evaluate_policy``
|
||||
- Fixed ``Self`` return type using ``TypeVar``
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP
|
|||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import explained_variance
|
||||
|
||||
A2CSelf = TypeVar("A2CSelf", bound="A2C")
|
||||
SelfA2C = TypeVar("SelfA2C", bound="A2C")
|
||||
|
||||
|
||||
class A2C(OnPolicyAlgorithm):
|
||||
|
|
@ -181,14 +181,14 @@ class A2C(OnPolicyAlgorithm):
|
|||
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
|
||||
|
||||
def learn(
|
||||
self: A2CSelf,
|
||||
self: SelfA2C,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 100,
|
||||
tb_log_name: str = "A2C",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> A2CSelf:
|
||||
) -> SelfA2C:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
|
|
@ -38,6 +38,8 @@ from stable_baselines3.common.vec_env import (
|
|||
unwrap_vec_normalize,
|
||||
)
|
||||
|
||||
SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm")
|
||||
|
||||
|
||||
def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymEnv]:
|
||||
"""If env is a string, make the environment; otherwise, return env.
|
||||
|
|
@ -53,9 +55,6 @@ def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymE
|
|||
return env
|
||||
|
||||
|
||||
BaseAlgorithmSelf = TypeVar("BaseAlgorithmSelf", bound="BaseAlgorithm")
|
||||
|
||||
|
||||
class BaseAlgorithm(ABC):
|
||||
"""
|
||||
The base of RL algorithms
|
||||
|
|
@ -491,14 +490,14 @@ class BaseAlgorithm(ABC):
|
|||
|
||||
@abstractmethod
|
||||
def learn(
|
||||
self: BaseAlgorithmSelf,
|
||||
self: SelfBaseAlgorithm,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 100,
|
||||
tb_log_name: str = "run",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> BaseAlgorithmSelf:
|
||||
) -> SelfBaseAlgorithm:
|
||||
"""
|
||||
Return a trained model.
|
||||
|
||||
|
|
@ -617,7 +616,7 @@ class BaseAlgorithm(ABC):
|
|||
|
||||
@classmethod
|
||||
def load(
|
||||
cls: Type[BaseAlgorithmSelf],
|
||||
cls: Type[SelfBaseAlgorithm],
|
||||
path: Union[str, pathlib.Path, io.BufferedIOBase],
|
||||
env: Optional[GymEnv] = None,
|
||||
device: Union[th.device, str] = "auto",
|
||||
|
|
@ -625,7 +624,7 @@ class BaseAlgorithm(ABC):
|
|||
print_system_info: bool = False,
|
||||
force_reset: bool = True,
|
||||
**kwargs,
|
||||
) -> BaseAlgorithmSelf:
|
||||
) -> SelfBaseAlgorithm:
|
||||
"""
|
||||
Load the model from a zip-file.
|
||||
Warning: ``load`` re-creates the model from scratch, it does not update it in-place!
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Probability distributions."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -12,6 +12,16 @@ from torch.distributions import Bernoulli, Categorical, Normal
|
|||
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
|
||||
SelfDistribution = TypeVar("SelfDistribution", bound="Distribution")
|
||||
SelfDiagGaussianDistribution = TypeVar("SelfDiagGaussianDistribution", bound="DiagGaussianDistribution")
|
||||
SelfSquashedDiagGaussianDistribution = TypeVar(
|
||||
"SelfSquashedDiagGaussianDistribution", bound="SquashedDiagGaussianDistribution"
|
||||
)
|
||||
SelfCategoricalDistribution = TypeVar("SelfCategoricalDistribution", bound="CategoricalDistribution")
|
||||
SelfMultiCategoricalDistribution = TypeVar("SelfMultiCategoricalDistribution", bound="MultiCategoricalDistribution")
|
||||
SelfBernoulliDistribution = TypeVar("SelfBernoulliDistribution", bound="BernoulliDistribution")
|
||||
SelfStateDependentNoiseDistribution = TypeVar("SelfStateDependentNoiseDistribution", bound="StateDependentNoiseDistribution")
|
||||
|
||||
|
||||
class Distribution(ABC):
|
||||
"""Abstract base class for distributions."""
|
||||
|
|
@ -28,7 +38,7 @@ class Distribution(ABC):
|
|||
concrete classes."""
|
||||
|
||||
@abstractmethod
|
||||
def proba_distribution(self, *args, **kwargs) -> "Distribution":
|
||||
def proba_distribution(self: SelfDistribution, *args, **kwargs) -> SelfDistribution:
|
||||
"""Set parameters of the distribution.
|
||||
|
||||
:return: self
|
||||
|
|
@ -141,7 +151,9 @@ class DiagGaussianDistribution(Distribution):
|
|||
log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True)
|
||||
return mean_actions, log_std
|
||||
|
||||
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution":
|
||||
def proba_distribution(
|
||||
self: SelfDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor
|
||||
) -> SelfDiagGaussianDistribution:
|
||||
"""
|
||||
Create the distribution given its parameters (mean, std)
|
||||
|
||||
|
|
@ -207,7 +219,9 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
|||
self.epsilon = epsilon
|
||||
self.gaussian_actions = None
|
||||
|
||||
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "SquashedDiagGaussianDistribution":
|
||||
def proba_distribution(
|
||||
self: SelfSquashedDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor
|
||||
) -> SelfSquashedDiagGaussianDistribution:
|
||||
super().proba_distribution(mean_actions, log_std)
|
||||
return self
|
||||
|
||||
|
|
@ -271,7 +285,7 @@ class CategoricalDistribution(Distribution):
|
|||
action_logits = nn.Linear(latent_dim, self.action_dim)
|
||||
return action_logits
|
||||
|
||||
def proba_distribution(self, action_logits: th.Tensor) -> "CategoricalDistribution":
|
||||
def proba_distribution(self: SelfCategoricalDistribution, action_logits: th.Tensor) -> SelfCategoricalDistribution:
|
||||
self.distribution = Categorical(logits=action_logits)
|
||||
return self
|
||||
|
||||
|
|
@ -323,7 +337,9 @@ class MultiCategoricalDistribution(Distribution):
|
|||
action_logits = nn.Linear(latent_dim, sum(self.action_dims))
|
||||
return action_logits
|
||||
|
||||
def proba_distribution(self, action_logits: th.Tensor) -> "MultiCategoricalDistribution":
|
||||
def proba_distribution(
|
||||
self: SelfMultiCategoricalDistribution, action_logits: th.Tensor
|
||||
) -> SelfMultiCategoricalDistribution:
|
||||
self.distribution = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
|
||||
return self
|
||||
|
||||
|
|
@ -376,7 +392,7 @@ class BernoulliDistribution(Distribution):
|
|||
action_logits = nn.Linear(latent_dim, self.action_dims)
|
||||
return action_logits
|
||||
|
||||
def proba_distribution(self, action_logits: th.Tensor) -> "BernoulliDistribution":
|
||||
def proba_distribution(self: SelfBernoulliDistribution, action_logits: th.Tensor) -> SelfBernoulliDistribution:
|
||||
self.distribution = Bernoulli(logits=action_logits)
|
||||
return self
|
||||
|
||||
|
|
@ -520,8 +536,8 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
return mean_actions_net, log_std
|
||||
|
||||
def proba_distribution(
|
||||
self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
|
||||
) -> "StateDependentNoiseDistribution":
|
||||
self: SelfStateDependentNoiseDistribution, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
|
||||
) -> SelfStateDependentNoiseDistribution:
|
||||
"""
|
||||
Create the distribution given its parameters (mean, std)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
|
|||
from stable_baselines3.common.vec_env import VecEnv
|
||||
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
|
||||
|
||||
OffPolicyAlgorithmSelf = TypeVar("OffPolicyAlgorithmSelf", bound="OffPolicyAlgorithm")
|
||||
SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="OffPolicyAlgorithm")
|
||||
|
||||
|
||||
class OffPolicyAlgorithm(BaseAlgorithm):
|
||||
|
|
@ -311,14 +311,14 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
)
|
||||
|
||||
def learn(
|
||||
self: OffPolicyAlgorithmSelf,
|
||||
self: SelfOffPolicyAlgorithm,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
tb_log_name: str = "run",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> OffPolicyAlgorithmSelf:
|
||||
) -> SelfOffPolicyAlgorithm:
|
||||
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
total_timesteps,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
|
|||
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
|
||||
OnPolicyAlgorithmSelf = TypeVar("OnPolicyAlgorithmSelf", bound="OnPolicyAlgorithm")
|
||||
SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="OnPolicyAlgorithm")
|
||||
|
||||
|
||||
class OnPolicyAlgorithm(BaseAlgorithm):
|
||||
|
|
@ -223,14 +223,14 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
raise NotImplementedError
|
||||
|
||||
def learn(
|
||||
self: OnPolicyAlgorithmSelf,
|
||||
self: SelfOnPolicyAlgorithm,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
tb_log_name: str = "OnPolicyAlgorithm",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> OnPolicyAlgorithmSelf:
|
||||
) -> SelfOnPolicyAlgorithm:
|
||||
iteration = 0
|
||||
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from stable_baselines3.common.torch_layers import (
|
|||
from stable_baselines3.common.type_aliases import Schedule
|
||||
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
|
||||
|
||||
BaseModelSelf = TypeVar("BaseModelSelf", bound="BaseModel")
|
||||
SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
|
||||
|
||||
|
||||
class BaseModel(nn.Module):
|
||||
|
|
@ -159,7 +159,7 @@ class BaseModel(nn.Module):
|
|||
th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)
|
||||
|
||||
@classmethod
|
||||
def load(cls: Type[BaseModelSelf], path: str, device: Union[th.device, str] = "auto") -> BaseModelSelf:
|
||||
def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel:
|
||||
"""
|
||||
Load model from path.
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
|
|||
from stable_baselines3.td3.policies import TD3Policy
|
||||
from stable_baselines3.td3.td3 import TD3
|
||||
|
||||
DDPGSelf = TypeVar("DDPGSelf", bound="DDPG")
|
||||
SelfDDPG = TypeVar("SelfDDPG", bound="DDPG")
|
||||
|
||||
|
||||
class DDPG(TD3):
|
||||
|
|
@ -113,14 +113,14 @@ class DDPG(TD3):
|
|||
self._setup_model()
|
||||
|
||||
def learn(
|
||||
self: DDPGSelf,
|
||||
self: SelfDDPG,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
tb_log_name: str = "DDPG",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> DDPGSelf:
|
||||
) -> SelfDDPG:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
|
|||
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update
|
||||
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy
|
||||
|
||||
DQNSelf = TypeVar("DQNSelf", bound="DQN")
|
||||
SelfDQN = TypeVar("SelfDQN", bound="DQN")
|
||||
|
||||
|
||||
class DQN(OffPolicyAlgorithm):
|
||||
|
|
@ -253,14 +253,14 @@ class DQN(OffPolicyAlgorithm):
|
|||
return action, state
|
||||
|
||||
def learn(
|
||||
self: DQNSelf,
|
||||
self: SelfDQN,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
tb_log_name: str = "DQN",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> DQNSelf:
|
||||
) -> SelfDQN:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP
|
|||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
|
||||
|
||||
PPOSelf = TypeVar("PPOSelf", bound="PPO")
|
||||
SelfPPO = TypeVar("SelfPPO", bound="PPO")
|
||||
|
||||
|
||||
class PPO(OnPolicyAlgorithm):
|
||||
|
|
@ -295,14 +295,14 @@ class PPO(OnPolicyAlgorithm):
|
|||
self.logger.record("train/clip_range_vf", clip_range_vf)
|
||||
|
||||
def learn(
|
||||
self: PPOSelf,
|
||||
self: SelfPPO,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
tb_log_name: str = "PPO",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> PPOSelf:
|
||||
) -> SelfPPO:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
|
|||
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
||||
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
|
||||
|
||||
SACSelf = TypeVar("SACSelf", bound="SAC")
|
||||
SelfSAC = TypeVar("SelfSAC", bound="SAC")
|
||||
|
||||
|
||||
class SAC(OffPolicyAlgorithm):
|
||||
|
|
@ -287,14 +287,14 @@ class SAC(OffPolicyAlgorithm):
|
|||
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
||||
|
||||
def learn(
|
||||
self: SACSelf,
|
||||
self: SelfSAC,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
tb_log_name: str = "SAC",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SACSelf:
|
||||
) -> SelfSAC:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
|
|||
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
||||
from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy
|
||||
|
||||
TD3Self = TypeVar("TD3Self", bound="TD3")
|
||||
SelfTD3 = TypeVar("SelfTD3", bound="TD3")
|
||||
|
||||
|
||||
class TD3(OffPolicyAlgorithm):
|
||||
|
|
@ -203,14 +203,14 @@ class TD3(OffPolicyAlgorithm):
|
|||
self.logger.record("train/critic_loss", np.mean(critic_losses))
|
||||
|
||||
def learn(
|
||||
self: TD3Self,
|
||||
self: SelfTD3,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
tb_log_name: str = "TD3",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> TD3Self:
|
||||
) -> SelfTD3:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
Loading…
Reference in a new issue