mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-11 00:49:42 +00:00
Switch to Self type
This commit is contained in:
parent
373166d6ac
commit
be4e1caa75
11 changed files with 39 additions and 77 deletions
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
|
||||
from typing import Any, ClassVar, Dict, Optional, Self, Type, Union
|
||||
|
||||
import torch as th
|
||||
from gymnasium import spaces
|
||||
|
|
@ -10,8 +10,6 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP
|
|||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import explained_variance
|
||||
|
||||
SelfA2C = TypeVar("SelfA2C", bound="A2C")
|
||||
|
||||
|
||||
class A2C(OnPolicyAlgorithm):
|
||||
"""
|
||||
|
|
@ -190,14 +188,14 @@ class A2C(OnPolicyAlgorithm):
|
|||
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
|
||||
|
||||
def learn(
|
||||
self: SelfA2C,
|
||||
self,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 100,
|
||||
tb_log_name: str = "A2C",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SelfA2C:
|
||||
) -> Self:
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import time
|
|||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Self, Tuple, Type, Union
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
|
@ -41,8 +41,6 @@ from stable_baselines3.common.vec_env import (
|
|||
)
|
||||
from stable_baselines3.common.vec_env.patch_gym import _convert_space, _patch_env
|
||||
|
||||
SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm")
|
||||
|
||||
|
||||
def maybe_make_env(env: Union[GymEnv, str], verbose: int) -> GymEnv:
|
||||
"""If env is a string, make the environment; otherwise, return env.
|
||||
|
|
@ -510,14 +508,14 @@ class BaseAlgorithm(ABC):
|
|||
|
||||
@abstractmethod
|
||||
def learn(
|
||||
self: SelfBaseAlgorithm,
|
||||
self,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 100,
|
||||
tb_log_name: str = "run",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SelfBaseAlgorithm:
|
||||
) -> Self:
|
||||
"""
|
||||
Return a trained model.
|
||||
|
||||
|
|
@ -637,7 +635,7 @@ class BaseAlgorithm(ABC):
|
|||
|
||||
@classmethod
|
||||
def load( # noqa: C901
|
||||
cls: Type[SelfBaseAlgorithm],
|
||||
cls,
|
||||
path: Union[str, pathlib.Path, io.BufferedIOBase],
|
||||
env: Optional[GymEnv] = None,
|
||||
device: Union[th.device, str] = "auto",
|
||||
|
|
@ -645,7 +643,7 @@ class BaseAlgorithm(ABC):
|
|||
print_system_info: bool = False,
|
||||
force_reset: bool = True,
|
||||
**kwargs,
|
||||
) -> SelfBaseAlgorithm:
|
||||
) -> Self:
|
||||
"""
|
||||
Load the model from a zip-file.
|
||||
Warning: ``load`` re-creates the model from scratch, it does not update it in-place!
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Probability distributions."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
|
||||
from typing import Any, Dict, List, Optional, Self, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -11,16 +11,6 @@ from torch.distributions import Bernoulli, Categorical, Normal
|
|||
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
|
||||
SelfDistribution = TypeVar("SelfDistribution", bound="Distribution")
|
||||
SelfDiagGaussianDistribution = TypeVar("SelfDiagGaussianDistribution", bound="DiagGaussianDistribution")
|
||||
SelfSquashedDiagGaussianDistribution = TypeVar(
|
||||
"SelfSquashedDiagGaussianDistribution", bound="SquashedDiagGaussianDistribution"
|
||||
)
|
||||
SelfCategoricalDistribution = TypeVar("SelfCategoricalDistribution", bound="CategoricalDistribution")
|
||||
SelfMultiCategoricalDistribution = TypeVar("SelfMultiCategoricalDistribution", bound="MultiCategoricalDistribution")
|
||||
SelfBernoulliDistribution = TypeVar("SelfBernoulliDistribution", bound="BernoulliDistribution")
|
||||
SelfStateDependentNoiseDistribution = TypeVar("SelfStateDependentNoiseDistribution", bound="StateDependentNoiseDistribution")
|
||||
|
||||
|
||||
class Distribution(ABC):
|
||||
"""Abstract base class for distributions."""
|
||||
|
|
@ -37,7 +27,7 @@ class Distribution(ABC):
|
|||
concrete classes."""
|
||||
|
||||
@abstractmethod
|
||||
def proba_distribution(self: SelfDistribution, *args, **kwargs) -> SelfDistribution:
|
||||
def proba_distribution(self, *args, **kwargs) -> Self:
|
||||
"""Set parameters of the distribution.
|
||||
|
||||
:return: self
|
||||
|
|
@ -150,9 +140,7 @@ class DiagGaussianDistribution(Distribution):
|
|||
log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True)
|
||||
return mean_actions, log_std
|
||||
|
||||
def proba_distribution(
|
||||
self: SelfDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor
|
||||
) -> SelfDiagGaussianDistribution:
|
||||
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Self:
|
||||
"""
|
||||
Create the distribution given its parameters (mean, std)
|
||||
|
||||
|
|
@ -218,9 +206,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
|||
self.epsilon = epsilon
|
||||
self.gaussian_actions: Optional[th.Tensor] = None
|
||||
|
||||
def proba_distribution(
|
||||
self: SelfSquashedDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor
|
||||
) -> SelfSquashedDiagGaussianDistribution:
|
||||
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Self:
|
||||
super().proba_distribution(mean_actions, log_std)
|
||||
return self
|
||||
|
||||
|
|
@ -284,7 +270,7 @@ class CategoricalDistribution(Distribution):
|
|||
action_logits = nn.Linear(latent_dim, self.action_dim)
|
||||
return action_logits
|
||||
|
||||
def proba_distribution(self: SelfCategoricalDistribution, action_logits: th.Tensor) -> SelfCategoricalDistribution:
|
||||
def proba_distribution(self, action_logits: th.Tensor) -> Self:
|
||||
self.distribution = Categorical(logits=action_logits)
|
||||
return self
|
||||
|
||||
|
|
@ -336,9 +322,7 @@ class MultiCategoricalDistribution(Distribution):
|
|||
action_logits = nn.Linear(latent_dim, sum(self.action_dims))
|
||||
return action_logits
|
||||
|
||||
def proba_distribution(
|
||||
self: SelfMultiCategoricalDistribution, action_logits: th.Tensor
|
||||
) -> SelfMultiCategoricalDistribution:
|
||||
def proba_distribution(self, action_logits: th.Tensor) -> Self:
|
||||
self.distribution = [Categorical(logits=split) for split in th.split(action_logits, list(self.action_dims), dim=1)]
|
||||
return self
|
||||
|
||||
|
|
@ -391,7 +375,7 @@ class BernoulliDistribution(Distribution):
|
|||
action_logits = nn.Linear(latent_dim, self.action_dims)
|
||||
return action_logits
|
||||
|
||||
def proba_distribution(self: SelfBernoulliDistribution, action_logits: th.Tensor) -> SelfBernoulliDistribution:
|
||||
def proba_distribution(self, action_logits: th.Tensor) -> Self:
|
||||
self.distribution = Bernoulli(logits=action_logits)
|
||||
return self
|
||||
|
||||
|
|
@ -538,9 +522,7 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
self.sample_weights(log_std)
|
||||
return mean_actions_net, log_std
|
||||
|
||||
def proba_distribution(
|
||||
self: SelfStateDependentNoiseDistribution, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
|
||||
) -> SelfStateDependentNoiseDistribution:
|
||||
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor) -> Self:
|
||||
"""
|
||||
Create the distribution given its parameters (mean, std)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import sys
|
|||
import time
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, Dict, List, Optional, Self, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -21,8 +21,6 @@ from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
|
|||
from stable_baselines3.common.vec_env import VecEnv
|
||||
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
|
||||
|
||||
SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="OffPolicyAlgorithm")
|
||||
|
||||
|
||||
class OffPolicyAlgorithm(BaseAlgorithm):
|
||||
"""
|
||||
|
|
@ -303,14 +301,14 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
)
|
||||
|
||||
def learn(
|
||||
self: SelfOffPolicyAlgorithm,
|
||||
self,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
tb_log_name: str = "run",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SelfOffPolicyAlgorithm:
|
||||
) -> Self:
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
total_timesteps,
|
||||
callback,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, Dict, List, Optional, Self, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -14,8 +14,6 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
|
|||
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
|
||||
SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="OnPolicyAlgorithm")
|
||||
|
||||
|
||||
class OnPolicyAlgorithm(BaseAlgorithm):
|
||||
"""
|
||||
|
|
@ -251,14 +249,14 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
raise NotImplementedError
|
||||
|
||||
def learn(
|
||||
self: SelfOnPolicyAlgorithm,
|
||||
self,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
tb_log_name: str = "OnPolicyAlgorithm",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SelfOnPolicyAlgorithm:
|
||||
) -> Self:
|
||||
iteration = 0
|
||||
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import copy
|
|||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, Dict, List, Optional, Self, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -33,8 +33,6 @@ from stable_baselines3.common.torch_layers import (
|
|||
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
|
||||
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
|
||||
|
||||
SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
|
||||
|
||||
|
||||
class BaseModel(nn.Module):
|
||||
"""
|
||||
|
|
@ -164,7 +162,7 @@ class BaseModel(nn.Module):
|
|||
th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)
|
||||
|
||||
@classmethod
|
||||
def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel:
|
||||
def load(cls, path: str, device: Union[th.device, str] = "auto") -> Self:
|
||||
"""
|
||||
Load model from path.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, Dict, Optional, Self, Tuple, Type, Union
|
||||
|
||||
import torch as th
|
||||
|
||||
|
|
@ -8,8 +8,6 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
|
|||
from stable_baselines3.td3.policies import TD3Policy
|
||||
from stable_baselines3.td3.td3 import TD3
|
||||
|
||||
SelfDDPG = TypeVar("SelfDDPG", bound="DDPG")
|
||||
|
||||
|
||||
class DDPG(TD3):
|
||||
"""
|
||||
|
|
@ -112,14 +110,14 @@ class DDPG(TD3):
|
|||
self._setup_model()
|
||||
|
||||
def learn(
|
||||
self: SelfDDPG,
|
||||
self,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
tb_log_name: str = "DDPG",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SelfDDPG:
|
||||
) -> Self:
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import warnings
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Self, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -13,8 +13,6 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
|
|||
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
|
||||
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork
|
||||
|
||||
SelfDQN = TypeVar("SelfDQN", bound="DQN")
|
||||
|
||||
|
||||
class DQN(OffPolicyAlgorithm):
|
||||
"""
|
||||
|
|
@ -256,14 +254,14 @@ class DQN(OffPolicyAlgorithm):
|
|||
return action, state
|
||||
|
||||
def learn(
|
||||
self: SelfDQN,
|
||||
self,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
tb_log_name: str = "DQN",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SelfDQN:
|
||||
) -> Self:
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import warnings
|
||||
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
|
||||
from typing import Any, ClassVar, Dict, Optional, Self, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -12,8 +12,6 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP
|
|||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
|
||||
|
||||
SelfPPO = TypeVar("SelfPPO", bound="PPO")
|
||||
|
||||
|
||||
class PPO(OnPolicyAlgorithm):
|
||||
"""
|
||||
|
|
@ -304,14 +302,14 @@ class PPO(OnPolicyAlgorithm):
|
|||
self.logger.record("train/clip_range_vf", clip_range_vf)
|
||||
|
||||
def learn(
|
||||
self: SelfPPO,
|
||||
self,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
tb_log_name: str = "PPO",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SelfPPO:
|
||||
) -> Self:
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Self, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -13,8 +13,6 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
|
|||
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
||||
from stable_baselines3.sac.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
|
||||
|
||||
SelfSAC = TypeVar("SelfSAC", bound="SAC")
|
||||
|
||||
|
||||
class SAC(OffPolicyAlgorithm):
|
||||
"""
|
||||
|
|
@ -296,14 +294,14 @@ class SAC(OffPolicyAlgorithm):
|
|||
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
||||
|
||||
def learn(
|
||||
self: SelfSAC,
|
||||
self,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
tb_log_name: str = "SAC",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SelfSAC:
|
||||
) -> Self:
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Self, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -13,8 +13,6 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
|
|||
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
||||
from stable_baselines3.td3.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy
|
||||
|
||||
SelfTD3 = TypeVar("SelfTD3", bound="TD3")
|
||||
|
||||
|
||||
class TD3(OffPolicyAlgorithm):
|
||||
"""
|
||||
|
|
@ -211,14 +209,14 @@ class TD3(OffPolicyAlgorithm):
|
|||
self.logger.record("train/critic_loss", np.mean(critic_losses))
|
||||
|
||||
def learn(
|
||||
self: SelfTD3,
|
||||
self,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
tb_log_name: str = "TD3",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SelfTD3:
|
||||
) -> Self:
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
|
|
|
|||
Loading…
Reference in a new issue