Switch to Self type

This commit is contained in:
Antonin Raffin 2023-12-07 10:45:00 +01:00
parent 373166d6ac
commit be4e1caa75
No known key found for this signature in database
GPG key ID: B8B48F65CAD6232C
11 changed files with 39 additions and 77 deletions

View file

@ -1,4 +1,4 @@
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
from typing import Any, ClassVar, Dict, Optional, Self, Type, Union
import torch as th
from gymnasium import spaces
@ -10,8 +10,6 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance
SelfA2C = TypeVar("SelfA2C", bound="A2C")
class A2C(OnPolicyAlgorithm):
"""
@ -190,14 +188,14 @@ class A2C(OnPolicyAlgorithm):
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
def learn(
self: SelfA2C,
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 100,
tb_log_name: str = "A2C",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfA2C:
) -> Self:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,

View file

@ -6,7 +6,7 @@ import time
import warnings
from abc import ABC, abstractmethod
from collections import deque
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Self, Tuple, Type, Union
import gymnasium as gym
import numpy as np
@ -41,8 +41,6 @@ from stable_baselines3.common.vec_env import (
)
from stable_baselines3.common.vec_env.patch_gym import _convert_space, _patch_env
SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm")
def maybe_make_env(env: Union[GymEnv, str], verbose: int) -> GymEnv:
"""If env is a string, make the environment; otherwise, return env.
@ -510,14 +508,14 @@ class BaseAlgorithm(ABC):
@abstractmethod
def learn(
self: SelfBaseAlgorithm,
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 100,
tb_log_name: str = "run",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfBaseAlgorithm:
) -> Self:
"""
Return a trained model.
@ -637,7 +635,7 @@ class BaseAlgorithm(ABC):
@classmethod
def load( # noqa: C901
cls: Type[SelfBaseAlgorithm],
cls,
path: Union[str, pathlib.Path, io.BufferedIOBase],
env: Optional[GymEnv] = None,
device: Union[th.device, str] = "auto",
@ -645,7 +643,7 @@ class BaseAlgorithm(ABC):
print_system_info: bool = False,
force_reset: bool = True,
**kwargs,
) -> SelfBaseAlgorithm:
) -> Self:
"""
Load the model from a zip-file.
Warning: ``load`` re-creates the model from scratch, it does not update it in-place!

View file

@ -1,7 +1,7 @@
"""Probability distributions."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
from typing import Any, Dict, List, Optional, Self, Tuple, Union
import numpy as np
import torch as th
@ -11,16 +11,6 @@ from torch.distributions import Bernoulli, Categorical, Normal
from stable_baselines3.common.preprocessing import get_action_dim
SelfDistribution = TypeVar("SelfDistribution", bound="Distribution")
SelfDiagGaussianDistribution = TypeVar("SelfDiagGaussianDistribution", bound="DiagGaussianDistribution")
SelfSquashedDiagGaussianDistribution = TypeVar(
"SelfSquashedDiagGaussianDistribution", bound="SquashedDiagGaussianDistribution"
)
SelfCategoricalDistribution = TypeVar("SelfCategoricalDistribution", bound="CategoricalDistribution")
SelfMultiCategoricalDistribution = TypeVar("SelfMultiCategoricalDistribution", bound="MultiCategoricalDistribution")
SelfBernoulliDistribution = TypeVar("SelfBernoulliDistribution", bound="BernoulliDistribution")
SelfStateDependentNoiseDistribution = TypeVar("SelfStateDependentNoiseDistribution", bound="StateDependentNoiseDistribution")
class Distribution(ABC):
"""Abstract base class for distributions."""
@ -37,7 +27,7 @@ class Distribution(ABC):
concrete classes."""
@abstractmethod
def proba_distribution(self: SelfDistribution, *args, **kwargs) -> SelfDistribution:
def proba_distribution(self, *args, **kwargs) -> Self:
"""Set parameters of the distribution.
:return: self
@ -150,9 +140,7 @@ class DiagGaussianDistribution(Distribution):
log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True)
return mean_actions, log_std
def proba_distribution(
self: SelfDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor
) -> SelfDiagGaussianDistribution:
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Self:
"""
Create the distribution given its parameters (mean, std)
@ -218,9 +206,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
self.epsilon = epsilon
self.gaussian_actions: Optional[th.Tensor] = None
def proba_distribution(
self: SelfSquashedDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor
) -> SelfSquashedDiagGaussianDistribution:
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Self:
super().proba_distribution(mean_actions, log_std)
return self
@ -284,7 +270,7 @@ class CategoricalDistribution(Distribution):
action_logits = nn.Linear(latent_dim, self.action_dim)
return action_logits
def proba_distribution(self: SelfCategoricalDistribution, action_logits: th.Tensor) -> SelfCategoricalDistribution:
def proba_distribution(self, action_logits: th.Tensor) -> Self:
self.distribution = Categorical(logits=action_logits)
return self
@ -336,9 +322,7 @@ class MultiCategoricalDistribution(Distribution):
action_logits = nn.Linear(latent_dim, sum(self.action_dims))
return action_logits
def proba_distribution(
self: SelfMultiCategoricalDistribution, action_logits: th.Tensor
) -> SelfMultiCategoricalDistribution:
def proba_distribution(self, action_logits: th.Tensor) -> Self:
self.distribution = [Categorical(logits=split) for split in th.split(action_logits, list(self.action_dims), dim=1)]
return self
@ -391,7 +375,7 @@ class BernoulliDistribution(Distribution):
action_logits = nn.Linear(latent_dim, self.action_dims)
return action_logits
def proba_distribution(self: SelfBernoulliDistribution, action_logits: th.Tensor) -> SelfBernoulliDistribution:
def proba_distribution(self, action_logits: th.Tensor) -> Self:
self.distribution = Bernoulli(logits=action_logits)
return self
@ -538,9 +522,7 @@ class StateDependentNoiseDistribution(Distribution):
self.sample_weights(log_std)
return mean_actions_net, log_std
def proba_distribution(
self: SelfStateDependentNoiseDistribution, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
) -> SelfStateDependentNoiseDistribution:
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor) -> Self:
"""
Create the distribution given its parameters (mean, std)

View file

@ -4,7 +4,7 @@ import sys
import time
import warnings
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Dict, List, Optional, Self, Tuple, Type, Union
import numpy as np
import torch as th
@ -21,8 +21,6 @@ from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="OffPolicyAlgorithm")
class OffPolicyAlgorithm(BaseAlgorithm):
"""
@ -303,14 +301,14 @@ class OffPolicyAlgorithm(BaseAlgorithm):
)
def learn(
self: SelfOffPolicyAlgorithm,
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
tb_log_name: str = "run",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfOffPolicyAlgorithm:
) -> Self:
total_timesteps, callback = self._setup_learn(
total_timesteps,
callback,

View file

@ -1,6 +1,6 @@
import sys
import time
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Dict, List, Optional, Self, Tuple, Type, Union
import numpy as np
import torch as th
@ -14,8 +14,6 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="OnPolicyAlgorithm")
class OnPolicyAlgorithm(BaseAlgorithm):
"""
@ -251,14 +249,14 @@ class OnPolicyAlgorithm(BaseAlgorithm):
raise NotImplementedError
def learn(
self: SelfOnPolicyAlgorithm,
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
tb_log_name: str = "OnPolicyAlgorithm",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfOnPolicyAlgorithm:
) -> Self:
iteration = 0
total_timesteps, callback = self._setup_learn(

View file

@ -5,7 +5,7 @@ import copy
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Dict, List, Optional, Self, Tuple, Type, Union
import numpy as np
import torch as th
@ -33,8 +33,6 @@ from stable_baselines3.common.torch_layers import (
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
class BaseModel(nn.Module):
"""
@ -164,7 +162,7 @@ class BaseModel(nn.Module):
th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)
@classmethod
def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel:
def load(cls, path: str, device: Union[th.device, str] = "auto") -> Self:
"""
Load model from path.

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Dict, Optional, Self, Tuple, Type, Union
import torch as th
@ -8,8 +8,6 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
from stable_baselines3.td3.policies import TD3Policy
from stable_baselines3.td3.td3 import TD3
SelfDDPG = TypeVar("SelfDDPG", bound="DDPG")
class DDPG(TD3):
"""
@ -112,14 +110,14 @@ class DDPG(TD3):
self._setup_model()
def learn(
self: SelfDDPG,
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
tb_log_name: str = "DDPG",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfDDPG:
) -> Self:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,

View file

@ -1,5 +1,5 @@
import warnings
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, ClassVar, Dict, List, Optional, Self, Tuple, Type, Union
import numpy as np
import torch as th
@ -13,8 +13,6 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork
SelfDQN = TypeVar("SelfDQN", bound="DQN")
class DQN(OffPolicyAlgorithm):
"""
@ -256,14 +254,14 @@ class DQN(OffPolicyAlgorithm):
return action, state
def learn(
self: SelfDQN,
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
tb_log_name: str = "DQN",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfDQN:
) -> Self:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,

View file

@ -1,5 +1,5 @@
import warnings
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
from typing import Any, ClassVar, Dict, Optional, Self, Type, Union
import numpy as np
import torch as th
@ -12,8 +12,6 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
SelfPPO = TypeVar("SelfPPO", bound="PPO")
class PPO(OnPolicyAlgorithm):
"""
@ -304,14 +302,14 @@ class PPO(OnPolicyAlgorithm):
self.logger.record("train/clip_range_vf", clip_range_vf)
def learn(
self: SelfPPO,
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
tb_log_name: str = "PPO",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfPPO:
) -> Self:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,

View file

@ -1,4 +1,4 @@
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, ClassVar, Dict, List, Optional, Self, Tuple, Type, Union
import numpy as np
import torch as th
@ -13,8 +13,6 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.sac.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
SelfSAC = TypeVar("SelfSAC", bound="SAC")
class SAC(OffPolicyAlgorithm):
"""
@ -296,14 +294,14 @@ class SAC(OffPolicyAlgorithm):
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
def learn(
self: SelfSAC,
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
tb_log_name: str = "SAC",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfSAC:
) -> Self:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,

View file

@ -1,4 +1,4 @@
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, ClassVar, Dict, List, Optional, Self, Tuple, Type, Union
import numpy as np
import torch as th
@ -13,8 +13,6 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.td3.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy
SelfTD3 = TypeVar("SelfTD3", bound="TD3")
class TD3(OffPolicyAlgorithm):
"""
@ -211,14 +209,14 @@ class TD3(OffPolicyAlgorithm):
self.logger.record("train/critic_loss", np.mean(critic_losses))
def learn(
self: SelfTD3,
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
tb_log_name: str = "TD3",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfTD3:
) -> Self:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,