Drop python 3.8 and add python 3.12 support (#2041)

* Drop python 3.8 support, add python 3.12 support

* Upgrade to python 3.9 syntax

* Fixes for Numpy v2

* Fix doc warning
Antonin RAFFIN 2024-11-18 15:40:36 +01:00 committed by GitHub
parent 020ee42f4d
commit daaebd0a52
66 changed files with 530 additions and 483 deletions

View file

@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
include:
# Default version
- gymnasium-version: "1.0.0"
@ -48,7 +48,8 @@ jobs:
- name: Install specific version of gym
run: |
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
# Only run for python 3.10, downgrade gym to 0.29.1
uv pip install --system "numpy<2"
# Only run for python 3.10, downgrade gym to 0.29.1, numpy<2
if: matrix.gymnasium-version != '1.0.0'
- name: Lint with ruff
run: |
@ -62,8 +63,6 @@ jobs:
- name: Type check
run: |
make type
# Do not run for python 3.8 (mypy internal error)
if: matrix.python-version != '3.8'
- name: Test with pytest
run: |
make pytest
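
The matrix above drops 3.8, adds 3.12, and pins numpy<2 only for the job that downgrades gymnasium to 0.29.1. As a hedged sketch (illustrative Python, not part of this workflow), the same constraints expressed as runtime checks:

    import sys

    import numpy as np

    # Python 3.9 is now the floor: the CI matrix runs 3.9-3.12.
    assert sys.version_info >= (3, 9), "Stable-Baselines3 now requires Python 3.9+"

    # gymnasium 0.29.1 predates NumPy 2.0, hence the "numpy<2" pin in that job.
    print("NumPy major version:", int(np.__version__.split(".")[0]))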

View file

@ -100,10 +100,10 @@ It provides a minimal number of features compared to SB3 but can be much faster
## Installation
**Note:** Stable-Baselines3 supports PyTorch >= 1.13
**Note:** Stable-Baselines3 supports PyTorch >= 2.3
### Prerequisites
Stable Baselines3 requires Python 3.8+.
Stable Baselines3 requires Python 3.9+.
#### Windows

View file

@ -12,7 +12,7 @@ dependencies:
- cloudpickle
- opencv-python-headless
- pandas
- numpy>=1.20,<2.0
- numpy>=1.20,<3.0
- matplotlib
- sphinx>=5,<9
- sphinx_rtd_theme>=1.3.0

View file

@ -14,7 +14,6 @@
import datetime
import os
import sys
from typing import Dict
# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
# PyEnchant.
@ -151,7 +150,7 @@ htmlhelp_basename = "StableBaselines3doc"
# -- Options for LaTeX output ------------------------------------------------
latex_elements: Dict[str, str] = {
latex_elements: dict[str, str] = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',

View file

@ -7,7 +7,7 @@ Installation
Prerequisites
-------------
Stable-Baselines3 requires python 3.8+ and PyTorch >= 1.13
Stable-Baselines3 requires python 3.9+ and PyTorch >= 2.3
Windows
~~~~~~~

View file

@ -20,6 +20,8 @@ RL Baselines3 Zoo provides a collection of pre-trained agents, scripts for train
SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
SBX (SB3 + Jax): https://github.com/araffin/sbx
Main Features
--------------

View file

@ -3,6 +3,41 @@
Changelog
==========
Release 2.5.0a0 (WIP)
--------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Increased minimum required version of PyTorch to 2.3.0
- Removed support for Python 3.8
New Features:
^^^^^^^^^^^^^
- Added support for NumPy v2.0: ``VecNormalize`` now casts normalized rewards to float32, and the bit flipping env was updated to avoid overflow issues (see the sketch after this list)
- Added official support for Python 3.12
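
A minimal sketch of the reward cast mentioned in the first bullet, assuming a ``VecNormalize``-style clipped normalization (``normalize_reward_sketch`` is a made-up name, not the actual SB3 method):

    import numpy as np

    def normalize_reward_sketch(reward: np.ndarray, var: np.float64, epsilon: float = 1e-8) -> np.ndarray:
        # Under NumPy 2.0 (NEP 50), dividing a float32 array by a float64
        # scalar promotes the result to float64, so the sketch casts the
        # normalized reward back to float32 as the changelog entry describes.
        return np.clip(reward / np.sqrt(var + epsilon), -10.0, 10.0).astype(np.float32)

    rewards = np.ones(4, dtype=np.float32)
    print(normalize_reward_sketch(rewards, np.float64(4.0)).dtype)  # float32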
Bug Fixes:
^^^^^^^^^^
`SB3-Contrib`_
^^^^^^^^^^^^^^
`RL Zoo`_
^^^^^^^^^
`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
Documentation:
^^^^^^^^^^^^^^
Release 2.4.0 (2024-11-18)
--------------------------

View file

@ -1,8 +1,8 @@
[tool.ruff]
# Same as Black.
line-length = 127
# Assume Python 3.8
target-version = "py38"
# Assume Python 3.9
target-version = "py39"
[tool.ruff.lint]
# See https://beta.ruff.rs/docs/rules/

View file

@ -77,8 +77,8 @@ setup(
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
install_requires=[
"gymnasium>=0.29.1,<1.1.0",
"numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
"torch>=1.13",
"numpy>=1.20,<3.0",
"torch>=2.3,<3.0",
# For saving models
"cloudpickle",
# For reading logs
@ -135,7 +135,7 @@ setup(
long_description=long_description,
long_description_content_type="text/markdown",
version=__version__,
python_requires=">=3.8",
python_requires=">=3.9",
# PyPI package information.
project_urls={
"Code": "https://github.com/DLR-RM/stable-baselines3",
@ -147,10 +147,10 @@ setup(
},
classifiers=[
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
)

View file

@ -1,4 +1,4 @@
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
from typing import Any, ClassVar, Optional, TypeVar, Union
import torch as th
from gymnasium import spaces
@ -57,7 +57,7 @@ class A2C(OnPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
"MlpPolicy": ActorCriticPolicy,
"CnnPolicy": ActorCriticCnnPolicy,
"MultiInputPolicy": MultiInputActorCriticPolicy,
@ -65,7 +65,7 @@ class A2C(OnPolicyAlgorithm):
def __init__(
self,
policy: Union[str, Type[ActorCriticPolicy]],
policy: Union[str, type[ActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 7e-4,
n_steps: int = 5,
@ -78,12 +78,12 @@ class A2C(OnPolicyAlgorithm):
use_rms_prop: bool = True,
use_sde: bool = False,
sde_sample_freq: int = -1,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
rollout_buffer_class: Optional[type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[dict[str, Any]] = None,
normalize_advantage: bool = False,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
policy_kwargs: Optional[dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
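
The pattern in this hunk repeats across nearly every file in the commit: with Python 3.8 dropped, PEP 585 built-in generics (dict, list, tuple, type) replace the deprecated typing aliases. A minimal before/after sketch (names are illustrative, not SB3 API):

    from typing import Optional  # Optional/Union still live in typing until 3.10's "X | Y"

    # Python 3.8 style, removed by this commit:
    #   from typing import Dict, List, Tuple, Type
    #   def make(policy: Type[object], kwargs: Optional[Dict[str, int]] = None) -> Tuple[List[int], ...]: ...

    # Python 3.9 style: builtins are directly subscriptable.
    def make(policy: type[object], kwargs: Optional[dict[str, int]] = None) -> tuple[list[int], ...]:
        return (list((kwargs or {}).values()),)

    print(make(object))  # ([],)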

View file

@ -1,4 +1,4 @@
from typing import Dict, SupportsFloat
from typing import SupportsFloat
import gymnasium as gym
import numpy as np
@ -64,7 +64,7 @@ class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
assert noops > 0
obs = np.zeros(0)
info: Dict = {}
info: dict = {}
for _ in range(noops):
obs, _, terminated, truncated, info = self.env.step(self.noop_action)
if terminated or truncated:

View file

@ -6,7 +6,8 @@ import time
import warnings
from abc import ABC, abstractmethod
from collections import deque
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
from collections.abc import Iterable
from typing import Any, ClassVar, Optional, TypeVar, Union
import gymnasium as gym
import numpy as np
@ -94,7 +95,7 @@ class BaseAlgorithm(ABC):
"""
# Policy aliases (see _get_policy_from_name())
policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {}
policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {}
policy: BasePolicy
observation_space: spaces.Space
action_space: spaces.Space
@ -104,10 +105,10 @@ class BaseAlgorithm(ABC):
def __init__(
self,
policy: Union[str, Type[BasePolicy]],
policy: Union[str, type[BasePolicy]],
env: Union[GymEnv, str, None],
learning_rate: Union[float, Schedule],
policy_kwargs: Optional[Dict[str, Any]] = None,
policy_kwargs: Optional[dict[str, Any]] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
verbose: int = 0,
@ -117,7 +118,7 @@ class BaseAlgorithm(ABC):
seed: Optional[int] = None,
use_sde: bool = False,
sde_sample_freq: int = -1,
supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
) -> None:
if isinstance(policy, str):
self.policy_class = self._get_policy_from_name(policy)
@ -141,10 +142,10 @@ class BaseAlgorithm(ABC):
self.start_time = 0.0
self.learning_rate = learning_rate
self.tensorboard_log = tensorboard_log
self._last_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
self._last_obs = None # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]]
self._last_episode_starts = None # type: Optional[np.ndarray]
# When using VecNormalize:
self._last_original_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
self._last_original_obs = None # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]]
self._episode_num = 0
# Used for gSDE only
self.use_sde = use_sde
@ -283,7 +284,7 @@ class BaseAlgorithm(ABC):
"""
self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)
def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
def _update_learning_rate(self, optimizers: Union[list[th.optim.Optimizer], th.optim.Optimizer]) -> None:
"""
Update the optimizers learning rate using the current learning rate schedule
and the current progress remaining (from 1 to 0).
@ -299,7 +300,7 @@ class BaseAlgorithm(ABC):
for optimizer in optimizers:
update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining))
def _excluded_save_params(self) -> List[str]:
def _excluded_save_params(self) -> list[str]:
"""
Returns the names of the parameters that should be excluded from being
saved by pickling. E.g. replay buffers are skipped by default
@ -320,7 +321,7 @@ class BaseAlgorithm(ABC):
"_custom_logger",
]
def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
def _get_policy_from_name(self, policy_name: str) -> type[BasePolicy]:
"""
Get a policy class from its name representation.
@ -337,7 +338,7 @@ class BaseAlgorithm(ABC):
else:
raise ValueError(f"Policy {policy_name} unknown")
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
"""
Get the name of the torch variables that will be saved with
PyTorch ``th.save``, ``th.load`` and ``state_dicts`` instead of the default
@ -387,7 +388,7 @@ class BaseAlgorithm(ABC):
reset_num_timesteps: bool = True,
tb_log_name: str = "run",
progress_bar: bool = False,
) -> Tuple[int, BaseCallback]:
) -> tuple[int, BaseCallback]:
"""
Initialize different variables needed for training.
@ -435,7 +436,7 @@ class BaseAlgorithm(ABC):
return total_timesteps, callback
def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
def _update_info_buffer(self, infos: list[dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
"""
Retrieve reward, episode length, episode success and update the buffer
if using Monitor wrapper or a GoalEnv.
@ -535,11 +536,11 @@ class BaseAlgorithm(ABC):
def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
observation: Union[np.ndarray, dict[str, np.ndarray]],
state: Optional[tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
"""
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
@ -640,11 +641,11 @@ class BaseAlgorithm(ABC):
@classmethod
def load( # noqa: C901
cls: Type[SelfBaseAlgorithm],
cls: type[SelfBaseAlgorithm],
path: Union[str, pathlib.Path, io.BufferedIOBase],
env: Optional[GymEnv] = None,
device: Union[th.device, str] = "auto",
custom_objects: Optional[Dict[str, Any]] = None,
custom_objects: Optional[dict[str, Any]] = None,
print_system_info: bool = False,
force_reset: bool = True,
**kwargs,
@ -800,7 +801,7 @@ class BaseAlgorithm(ABC):
model.policy.reset_noise() # type: ignore[operator]
return model
def get_parameters(self) -> Dict[str, Dict]:
def get_parameters(self) -> dict[str, dict]:
"""
Return the parameters of the agent. This includes parameters from different networks, e.g.
critics (value functions) and policies (pi functions).

View file

@ -1,6 +1,7 @@
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from collections.abc import Generator
from typing import Any, Optional, Union
import numpy as np
import torch as th
@ -36,7 +37,7 @@ class BaseBuffer(ABC):
"""
observation_space: spaces.Space
obs_shape: Tuple[int, ...]
obs_shape: tuple[int, ...]
def __init__(
self,
@ -140,9 +141,9 @@ class BaseBuffer(ABC):
@staticmethod
def _normalize_obs(
obs: Union[np.ndarray, Dict[str, np.ndarray]],
obs: Union[np.ndarray, dict[str, np.ndarray]],
env: Optional[VecNormalize] = None,
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
) -> Union[np.ndarray, dict[str, np.ndarray]]:
if env is not None:
return env.normalize_obs(obs)
return obs
@ -250,7 +251,7 @@ class ReplayBuffer(BaseBuffer):
action: np.ndarray,
reward: np.ndarray,
done: np.ndarray,
infos: List[Dict[str, Any]],
infos: list[dict[str, Any]],
) -> None:
# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
@ -538,9 +539,9 @@ class DictReplayBuffer(ReplayBuffer):
"""
observation_space: spaces.Dict
obs_shape: Dict[str, Tuple[int, ...]] # type: ignore[assignment]
observations: Dict[str, np.ndarray] # type: ignore[assignment]
next_observations: Dict[str, np.ndarray] # type: ignore[assignment]
obs_shape: dict[str, tuple[int, ...]] # type: ignore[assignment]
observations: dict[str, np.ndarray] # type: ignore[assignment]
next_observations: dict[str, np.ndarray] # type: ignore[assignment]
def __init__(
self,
@ -609,12 +610,12 @@ class DictReplayBuffer(ReplayBuffer):
def add( # type: ignore[override]
self,
obs: Dict[str, np.ndarray],
next_obs: Dict[str, np.ndarray],
obs: dict[str, np.ndarray],
next_obs: dict[str, np.ndarray],
action: np.ndarray,
reward: np.ndarray,
done: np.ndarray,
infos: List[Dict[str, Any]],
infos: list[dict[str, Any]],
) -> None:
# Copy to avoid modification by reference
for key in self.observations.keys():
@ -718,8 +719,8 @@ class DictRolloutBuffer(RolloutBuffer):
"""
observation_space: spaces.Dict
obs_shape: Dict[str, Tuple[int, ...]] # type: ignore[assignment]
observations: Dict[str, np.ndarray] # type: ignore[assignment]
obs_shape: dict[str, tuple[int, ...]] # type: ignore[assignment]
observations: dict[str, np.ndarray] # type: ignore[assignment]
def __init__(
self,
@ -757,7 +758,7 @@ class DictRolloutBuffer(RolloutBuffer):
def add( # type: ignore[override]
self,
obs: Dict[str, np.ndarray],
obs: dict[str, np.ndarray],
action: np.ndarray,
reward: np.ndarray,
episode_start: np.ndarray,

View file

@ -1,7 +1,7 @@
import os
import warnings
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import gymnasium as gym
import numpy as np
@ -45,8 +45,8 @@ class BaseCallback(ABC):
# n_envs * n times env.step() was called
self.num_timesteps = 0 # type: int
self.verbose = verbose
self.locals: Dict[str, Any] = {}
self.globals: Dict[str, Any] = {}
self.locals: dict[str, Any] = {}
self.globals: dict[str, Any] = {}
# Sometimes, for event callback, it is useful
# to have access to the parent object
self.parent = None # type: Optional[BaseCallback]
@ -75,7 +75,7 @@ class BaseCallback(ABC):
def _init_callback(self) -> None:
pass
def on_training_start(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
def on_training_start(self, locals_: dict[str, Any], globals_: dict[str, Any]) -> None:
# Those are reference and will be updated automatically
self.locals = locals_
self.globals = globals_
@ -125,7 +125,7 @@ class BaseCallback(ABC):
def _on_rollout_end(self) -> None:
pass
def update_locals(self, locals_: Dict[str, Any]) -> None:
def update_locals(self, locals_: dict[str, Any]) -> None:
"""
Update the references to the local variables.
@ -134,7 +134,7 @@ class BaseCallback(ABC):
self.locals.update(locals_)
self.update_child_locals(locals_)
def update_child_locals(self, locals_: Dict[str, Any]) -> None:
def update_child_locals(self, locals_: dict[str, Any]) -> None:
"""
Update the references to the local variables on sub callbacks.
@ -177,7 +177,7 @@ class EventCallback(BaseCallback):
def _on_step(self) -> bool:
return True
def update_child_locals(self, locals_: Dict[str, Any]) -> None:
def update_child_locals(self, locals_: dict[str, Any]) -> None:
"""
Update the references to the local variables.
@ -195,7 +195,7 @@ class CallbackList(BaseCallback):
sequentially.
"""
def __init__(self, callbacks: List[BaseCallback]):
def __init__(self, callbacks: list[BaseCallback]):
super().__init__()
assert isinstance(callbacks, list)
self.callbacks = callbacks
@ -231,7 +231,7 @@ class CallbackList(BaseCallback):
for callback in self.callbacks:
callback.on_training_end()
def update_child_locals(self, locals_: Dict[str, Any]) -> None:
def update_child_locals(self, locals_: dict[str, Any]) -> None:
"""
Update the references to the local variables.
@ -328,7 +328,7 @@ class ConvertCallback(BaseCallback):
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
def __init__(self, callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], bool]], verbose: int = 0):
def __init__(self, callback: Optional[Callable[[dict[str, Any], dict[str, Any]], bool]], verbose: int = 0):
super().__init__(verbose)
self.callback = callback
@ -405,12 +405,12 @@ class EvalCallback(EventCallback):
if log_path is not None:
log_path = os.path.join(log_path, "evaluations")
self.log_path = log_path
self.evaluations_results: List[List[float]] = []
self.evaluations_timesteps: List[int] = []
self.evaluations_length: List[List[int]] = []
self.evaluations_results: list[list[float]] = []
self.evaluations_timesteps: list[int] = []
self.evaluations_length: list[list[int]] = []
# For computing success rate
self._is_success_buffer: List[bool] = []
self.evaluations_successes: List[List[bool]] = []
self._is_success_buffer: list[bool] = []
self.evaluations_successes: list[list[bool]] = []
def _init_callback(self) -> None:
# Does not work in some corner cases, where the wrapper is not the same
@ -427,7 +427,7 @@ class EvalCallback(EventCallback):
if self.callback_on_new_best is not None:
self.callback_on_new_best.init_callback(self.model)
def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
def _log_success_callback(self, locals_: dict[str, Any], globals_: dict[str, Any]) -> None:
"""
Callback passed to the ``evaluate_policy`` function
in order to log the success rate (when applicable),
@ -530,7 +530,7 @@ class EvalCallback(EventCallback):
return continue_training
def update_child_locals(self, locals_: Dict[str, Any]) -> None:
def update_child_locals(self, locals_: dict[str, Any]) -> None:
"""
Update the references to the local variables.

View file

@ -1,7 +1,7 @@
"""Probability distributions."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
from typing import Any, Optional, TypeVar, Union
import numpy as np
import torch as th
@ -30,7 +30,7 @@ class Distribution(ABC):
self.distribution = None
@abstractmethod
def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, Tuple[nn.Module, nn.Parameter]]:
def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, tuple[nn.Module, nn.Parameter]]:
"""Create the layers and parameters that represent the distribution.
Subclasses must define this, but the arguments and return type vary between
@ -98,7 +98,7 @@ class Distribution(ABC):
"""
@abstractmethod
def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]:
def log_prob_from_params(self, *args, **kwargs) -> tuple[th.Tensor, th.Tensor]:
"""
Returns samples and the associated log probabilities
from the probability distribution given its parameters.
@ -135,7 +135,7 @@ class DiagGaussianDistribution(Distribution):
self.mean_actions = None
self.log_std = None
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]:
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> tuple[nn.Module, nn.Parameter]:
"""
Create the layers and parameter that represent the distribution:
one output will be the mean of the Gaussian, the other parameter will be the
@ -190,7 +190,7 @@ class DiagGaussianDistribution(Distribution):
self.proba_distribution(mean_actions, log_std)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
"""
Compute the log probability of taking an action
given the distribution parameters.
@ -254,7 +254,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
# Squash the output
return th.tanh(self.gaussian_actions)
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
action = self.actions_from_params(mean_actions, log_std)
log_prob = self.log_prob(action, self.gaussian_actions)
return action, log_prob
@ -305,7 +305,7 @@ class CategoricalDistribution(Distribution):
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob
@ -318,7 +318,7 @@ class MultiCategoricalDistribution(Distribution):
:param action_dims: List of sizes of discrete action spaces
"""
def __init__(self, action_dims: List[int]):
def __init__(self, action_dims: list[int]):
super().__init__()
self.action_dims = action_dims
@ -362,7 +362,7 @@ class MultiCategoricalDistribution(Distribution):
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob
@ -412,7 +412,7 @@ class BernoulliDistribution(Distribution):
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob
@ -513,7 +513,7 @@ class StateDependentNoiseDistribution(Distribution):
def proba_distribution_net(
self, latent_dim: int, log_std_init: float = -2.0, latent_sde_dim: Optional[int] = None
) -> Tuple[nn.Module, nn.Parameter]:
) -> tuple[nn.Module, nn.Parameter]:
"""
Create the layers and parameter that represent the distribution:
one output will be the deterministic action, the other parameter will be the
@ -611,7 +611,7 @@ class StateDependentNoiseDistribution(Distribution):
def log_prob_from_params(
self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
) -> Tuple[th.Tensor, th.Tensor]:
) -> tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(mean_actions, log_std, latent_sde)
log_prob = self.log_prob(actions)
return actions, log_prob
@ -661,7 +661,7 @@ class TanhBijector:
def make_proba_distribution(
action_space: spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
action_space: spaces.Space, use_sde: bool = False, dist_kwargs: Optional[dict[str, Any]] = None
) -> Distribution:
"""
Return an instance of Distribution for the correct type of action space

View file

@ -1,5 +1,5 @@
import warnings
from typing import Any, Dict, Union
from typing import Any, Union
import gymnasium as gym
import numpy as np
@ -172,10 +172,10 @@ def _check_goal_env_obs(obs: dict, observation_space: spaces.Dict, method_name:
def _check_goal_env_compute_reward(
obs: Dict[str, Union[np.ndarray, int]],
obs: dict[str, Union[np.ndarray, int]],
env: gym.Env,
reward: float,
info: Dict[str, Any],
info: dict[str, Any],
) -> None:
"""
Check that reward is computed with `compute_reward`

View file

@ -1,5 +1,5 @@
import os
from typing import Any, Callable, Dict, Optional, Type, Union
from typing import Any, Callable, Optional, Union
import gymnasium as gym
@ -9,7 +9,7 @@ from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
from stable_baselines3.common.vec_env.patch_gym import _patch_env
def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]:
def unwrap_wrapper(env: gym.Env, wrapper_class: type[gym.Wrapper]) -> Optional[gym.Wrapper]:
"""
Retrieve a ``VecEnvWrapper`` object by recursively searching.
@ -25,7 +25,7 @@ def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[g
return None
def is_wrapped(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> bool:
def is_wrapped(env: gym.Env, wrapper_class: type[gym.Wrapper]) -> bool:
"""
Check if a given environment has been wrapped with a given wrapper.
@ -43,11 +43,11 @@ def make_vec_env(
start_index: int = 0,
monitor_dir: Optional[str] = None,
wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
env_kwargs: Optional[Dict[str, Any]] = None,
vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
vec_env_kwargs: Optional[Dict[str, Any]] = None,
monitor_kwargs: Optional[Dict[str, Any]] = None,
wrapper_kwargs: Optional[Dict[str, Any]] = None,
env_kwargs: Optional[dict[str, Any]] = None,
vec_env_cls: Optional[type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
vec_env_kwargs: Optional[dict[str, Any]] = None,
monitor_kwargs: Optional[dict[str, Any]] = None,
wrapper_kwargs: Optional[dict[str, Any]] = None,
) -> VecEnv:
"""
Create a wrapped, monitored ``VecEnv``.
@ -134,11 +134,11 @@ def make_atari_env(
seed: Optional[int] = None,
start_index: int = 0,
monitor_dir: Optional[str] = None,
wrapper_kwargs: Optional[Dict[str, Any]] = None,
env_kwargs: Optional[Dict[str, Any]] = None,
vec_env_cls: Optional[Union[Type[DummyVecEnv], Type[SubprocVecEnv]]] = None,
vec_env_kwargs: Optional[Dict[str, Any]] = None,
monitor_kwargs: Optional[Dict[str, Any]] = None,
wrapper_kwargs: Optional[dict[str, Any]] = None,
env_kwargs: Optional[dict[str, Any]] = None,
vec_env_cls: Optional[Union[type[DummyVecEnv], type[SubprocVecEnv]]] = None,
vec_env_kwargs: Optional[dict[str, Any]] = None,
monitor_kwargs: Optional[dict[str, Any]] = None,
) -> VecEnv:
"""
Create a wrapped, monitored VecEnv for Atari.

View file

@ -1,5 +1,5 @@
from collections import OrderedDict
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Optional, Union
import numpy as np
from gymnasium import Env, spaces
@ -75,14 +75,17 @@ class BitFlippingEnv(Env):
:param state:
:return:
"""
if self.discrete_obs_space:
# Convert from int8 to int32 for NumPy 2.0
state = state.astype(np.int32)
# The internal state is the binary representation of the
# observed one
return int(sum(state[i] * 2**i for i in range(len(state))))
if self.image_obs_space:
size = np.prod(self.image_shape)
image = np.concatenate((state * 255, np.zeros(size - len(state), dtype=np.uint8)))
image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8)))
return image.reshape(self.image_shape).astype(np.uint8)
return state
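
A short sketch of why both casts above are needed: NumPy 2.0 (NEP 50) raises OverflowError where 1.x silently upcast Python int scalars that exceed the array dtype (illustrative, not part of the diff):

    import numpy as np

    state = np.ones(8, dtype=np.int8)

    # NumPy 1.x value-based casting quietly upcast int8 * 128; NumPy 2.0
    # raises "Python integer 128 out of bounds for int8" instead.
    try:
        state[7] * 2**7
    except OverflowError as exc:
        print(exc)

    # Casting first, as the env now does, keeps the weighted sum safe:
    print(int(sum(state.astype(np.int32)[i] * 2**i for i in range(8))))  # 255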
@ -163,7 +166,7 @@ class BitFlippingEnv(Env):
}
)
def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]:
def _get_obs(self) -> dict[str, Union[int, np.ndarray]]:
"""
Helper to create the observation.
@ -178,8 +181,8 @@ class BitFlippingEnv(Env):
)
def reset(
self, *, seed: Optional[int] = None, options: Optional[Dict] = None
) -> Tuple[Dict[str, Union[int, np.ndarray]], Dict]:
self, *, seed: Optional[int] = None, options: Optional[dict] = None
) -> tuple[dict[str, Union[int, np.ndarray]], dict]:
if seed is not None:
self._obs_space.seed(seed)
self.current_step = 0
@ -207,7 +210,7 @@ class BitFlippingEnv(Env):
return obs, reward, terminated, truncated, info
def compute_reward(
self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[Dict[str, Any]]
self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[dict[str, Any]]
) -> np.float32:
# As we are using a vectorized version, we need to keep track of the `batch_size`
if isinstance(achieved_goal, int):

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union
from typing import Any, Generic, Optional, TypeVar, Union
import gymnasium as gym
import numpy as np
@ -34,7 +34,7 @@ class IdentityEnv(gym.Env, Generic[T]):
self.num_resets = -1 # Becomes 0 after __init__ exits.
self.reset()
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[T, Dict]:
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[T, dict]:
if seed is not None:
super().reset(seed=seed)
self.current_step = 0
@ -42,7 +42,7 @@ class IdentityEnv(gym.Env, Generic[T]):
self._choose_next_state()
return self.state, {}
def step(self, action: T) -> Tuple[T, float, bool, bool, Dict[str, Any]]:
def step(self, action: T) -> tuple[T, float, bool, bool, dict[str, Any]]:
reward = self._get_reward(action)
self._choose_next_state()
self.current_step += 1
@ -74,7 +74,7 @@ class IdentityEnvBox(IdentityEnv[np.ndarray]):
super().__init__(ep_length=ep_length, space=space)
self.eps = eps
def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
reward = self._get_reward(action)
self._choose_next_state()
self.current_step += 1
@ -142,7 +142,7 @@ class FakeImageEnv(gym.Env):
self.ep_length = 10
self.current_step = 0
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]:
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[np.ndarray, dict]:
if seed is not None:
super().reset(seed=seed)
self.current_step = 0

View file

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple, Union
from typing import Optional, Union
import gymnasium as gym
import numpy as np
@ -73,7 +73,7 @@ class SimpleMultiObsEnv(gym.Env):
self.init_possible_transitions()
self.num_col = num_col
self.state_mapping: List[Dict[str, np.ndarray]] = []
self.state_mapping: list[dict[str, np.ndarray]] = []
self.init_state_mapping(num_col, num_row)
self.max_state = len(self.state_mapping) - 1
@ -94,7 +94,7 @@ class SimpleMultiObsEnv(gym.Env):
for j in range(num_row):
self.state_mapping.append({"vec": col_vecs[i], "img": row_imgs[j].reshape(self.img_size)})
def get_state_mapping(self) -> Dict[str, np.ndarray]:
def get_state_mapping(self) -> dict[str, np.ndarray]:
"""
Uses the state to get the observation mapping.
@ -166,7 +166,7 @@ class SimpleMultiObsEnv(gym.Env):
"""
print(self.log)
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[Dict[str, np.ndarray], Dict]:
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[dict[str, np.ndarray], dict]:
"""
Resets the environment state and step count and returns reset observation.

View file

@ -1,5 +1,5 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
import gymnasium as gym
import numpy as np
@ -14,11 +14,11 @@ def evaluate_policy(
n_eval_episodes: int = 10,
deterministic: bool = True,
render: bool = False,
callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
callback: Optional[Callable[[dict[str, Any], dict[str, Any]], None]] = None,
reward_threshold: Optional[float] = None,
return_episode_rewards: bool = False,
warn: bool = True,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
) -> Union[tuple[float, float], tuple[list[float], list[int]]]:
"""
Runs policy for ``n_eval_episodes`` episodes and returns average reward.
If a vector env is passed in, this divides the episodes to evaluate onto the

View file

@ -5,8 +5,9 @@ import sys
import tempfile
import warnings
from collections import defaultdict
from collections.abc import Mapping, Sequence
from io import TextIOBase
from typing import Any, Dict, List, Mapping, Optional, Sequence, TextIO, Tuple, Union
from typing import Any, Optional, TextIO, Union
import matplotlib.figure
import numpy as np
@ -114,7 +115,7 @@ class KVWriter:
Key Value writer
"""
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
"""
Write a dictionary to file
@ -136,7 +137,7 @@ class SeqWriter:
sequence writer
"""
def write_sequence(self, sequence: List[str]) -> None:
def write_sequence(self, sequence: list[str]) -> None:
"""
write_sequence an array to file
@ -172,7 +173,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
else:
raise ValueError(f"Expected file or str, got {filename_or_file}")
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
# Create strings for printing
key2str = {}
tag = ""
@ -244,7 +245,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
string = string[: self.max_length - 3] + "..."
return string
def write_sequence(self, sequence: List[str]) -> None:
def write_sequence(self, sequence: list[str]) -> None:
for i, elem in enumerate(sequence):
self.file.write(elem)
if i < len(sequence) - 1: # add space unless this is the last one
@ -260,7 +261,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
self.file.close()
def filter_excluded_keys(key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], _format: str) -> Dict[str, Any]:
def filter_excluded_keys(key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], _format: str) -> dict[str, Any]:
"""
Filters the keys specified by ``key_exclude`` for the specified format
@ -286,7 +287,7 @@ class JSONOutputFormat(KVWriter):
def __init__(self, filename: str):
self.file = open(filename, "w")
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
def cast_to_json_serializable(value: Any):
if isinstance(value, Video):
raise FormatUnsupportedError(["json"], "video")
@ -328,12 +329,12 @@ class CSVOutputFormat(KVWriter):
"""
def __init__(self, filename: str):
self.file = open(filename, "w+t")
self.keys: List[str] = []
self.file = open(filename, "w+")
self.keys: list[str] = []
self.separator = ","
self.quotechar = '"'
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
# Add our current row to the history
key_values = filter_excluded_keys(key_values, key_excluded, "csv")
extra_keys = key_values.keys() - self.keys
@ -399,7 +400,7 @@ class TensorBoardOutputFormat(KVWriter):
self.writer = SummaryWriter(log_dir=folder)
self._is_closed = False
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
assert not self._is_closed, "The SummaryWriter was closed, please re-create one."
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
if excluded is not None and "tensorboard" in excluded:
@ -481,16 +482,16 @@ class Logger:
:param output_formats: the list of output formats
"""
def __init__(self, folder: Optional[str], output_formats: List[KVWriter]):
self.name_to_value: Dict[str, float] = defaultdict(float) # values this iteration
self.name_to_count: Dict[str, int] = defaultdict(int)
self.name_to_excluded: Dict[str, Tuple[str, ...]] = {}
def __init__(self, folder: Optional[str], output_formats: list[KVWriter]):
self.name_to_value: dict[str, float] = defaultdict(float) # values this iteration
self.name_to_count: dict[str, int] = defaultdict(int)
self.name_to_excluded: dict[str, tuple[str, ...]] = {}
self.level = INFO
self.dir = folder
self.output_formats = output_formats
@staticmethod
def to_tuple(string_or_tuple: Optional[Union[str, Tuple[str, ...]]]) -> Tuple[str, ...]:
def to_tuple(string_or_tuple: Optional[Union[str, tuple[str, ...]]]) -> tuple[str, ...]:
"""
Helper function to convert str to tuple of str.
"""
@ -500,7 +501,7 @@ class Logger:
return string_or_tuple
return (string_or_tuple,)
def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
def record(self, key: str, value: Any, exclude: Optional[Union[str, tuple[str, ...]]] = None) -> None:
"""
Log a value of some diagnostic
Call this once for each diagnostic quantity, each iteration
@ -513,7 +514,7 @@ class Logger:
self.name_to_value[key] = value
self.name_to_excluded[key] = self.to_tuple(exclude)
def record_mean(self, key: str, value: Optional[float], exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
def record_mean(self, key: str, value: Optional[float], exclude: Optional[Union[str, tuple[str, ...]]] = None) -> None:
"""
The same as record(), but if called many times, values averaged.
@ -624,7 +625,7 @@ class Logger:
# Misc
# ----------------------------------------
def _do_log(self, args: Tuple[Any, ...]) -> None:
def _do_log(self, args: tuple[Any, ...]) -> None:
"""
log to the requested format outputs
@ -635,7 +636,7 @@ class Logger:
_format.write_sequence(list(map(str, args)))
def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] = None) -> Logger:
def configure(folder: Optional[str] = None, format_strings: Optional[list[str]] = None) -> Logger:
"""
Configure the current logger.

View file

@ -5,7 +5,7 @@ import json
import os
import time
from glob import glob
from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, Union
from typing import Any, Optional, SupportsFloat, Union
import gymnasium as gym
import pandas
@ -33,8 +33,8 @@ class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
env: gym.Env,
filename: Optional[str] = None,
allow_early_resets: bool = True,
reset_keywords: Tuple[str, ...] = (),
info_keywords: Tuple[str, ...] = (),
reset_keywords: tuple[str, ...] = (),
info_keywords: tuple[str, ...] = (),
override_existing: bool = True,
):
super().__init__(env=env)
@ -52,16 +52,16 @@ class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
self.reset_keywords = reset_keywords
self.info_keywords = info_keywords
self.allow_early_resets = allow_early_resets
self.rewards: List[float] = []
self.rewards: list[float] = []
self.needs_reset = True
self.episode_returns: List[float] = []
self.episode_lengths: List[int] = []
self.episode_times: List[float] = []
self.episode_returns: list[float] = []
self.episode_lengths: list[int] = []
self.episode_times: list[float] = []
self.total_steps = 0
# extra info about the current episode, that was passed in during reset()
self.current_reset_info: Dict[str, Any] = {}
self.current_reset_info: dict[str, Any] = {}
def reset(self, **kwargs) -> Tuple[ObsType, Dict[str, Any]]:
def reset(self, **kwargs) -> tuple[ObsType, dict[str, Any]]:
"""
Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True
@ -82,7 +82,7 @@ class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
self.current_reset_info[key] = value
return self.env.reset(**kwargs)
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""
Step the environment with the given action
@ -126,7 +126,7 @@ class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
"""
return self.total_steps
def get_episode_rewards(self) -> List[float]:
def get_episode_rewards(self) -> list[float]:
"""
Returns the rewards of all the episodes
@ -134,7 +134,7 @@ class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
"""
return self.episode_returns
def get_episode_lengths(self) -> List[int]:
def get_episode_lengths(self) -> list[int]:
"""
Returns the number of timesteps of all the episodes
@ -142,7 +142,7 @@ class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
"""
return self.episode_lengths
def get_episode_times(self) -> List[float]:
def get_episode_times(self) -> list[float]:
"""
Returns the runtime in seconds of all the episodes
@ -175,8 +175,8 @@ class ResultsWriter:
def __init__(
self,
filename: str = "",
header: Optional[Dict[str, Union[float, str]]] = None,
extra_keys: Tuple[str, ...] = (),
header: Optional[dict[str, Union[float, str]]] = None,
extra_keys: tuple[str, ...] = (),
override_existing: bool = True,
):
if header is None:
@ -200,7 +200,7 @@ class ResultsWriter:
self.file_handler.flush()
def write_row(self, epinfo: Dict[str, float]) -> None:
def write_row(self, epinfo: dict[str, float]) -> None:
"""
Write row of monitor data to csv log file.
@ -217,7 +217,7 @@ class ResultsWriter:
self.file_handler.close()
def get_monitor_files(path: str) -> List[str]:
def get_monitor_files(path: str) -> list[str]:
"""
get all the monitor files in the given path

View file

@ -1,6 +1,7 @@
import copy
from abc import ABC, abstractmethod
from typing import Iterable, List, Optional
from collections.abc import Iterable
from typing import Optional
import numpy as np
from numpy.typing import DTypeLike
@ -153,11 +154,11 @@ class VectorizedActionNoise(ActionNoise):
self._base_noise = base_noise
@property
def noises(self) -> List[ActionNoise]:
def noises(self) -> list[ActionNoise]:
return self._noises
@noises.setter
def noises(self, noises: List[ActionNoise]) -> None:
def noises(self, noises: list[ActionNoise]) -> None:
noises = list(noises) # raises TypeError if not iterable
assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}."

View file

@ -4,7 +4,7 @@ import sys
import time
import warnings
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Optional, TypeVar, Union
import numpy as np
import torch as th
@ -79,7 +79,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
def __init__(
self,
policy: Union[str, Type[BasePolicy]],
policy: Union[str, type[BasePolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule],
buffer_size: int = 1_000_000, # 1e6
@ -87,13 +87,13 @@ class OffPolicyAlgorithm(BaseAlgorithm):
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = (1, "step"),
train_freq: Union[int, tuple[int, str]] = (1, "step"),
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[dict[str, Any]] = None,
optimize_memory_usage: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
policy_kwargs: Optional[dict[str, Any]] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
verbose: int = 0,
@ -105,7 +105,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False,
sde_support: bool = True,
supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
):
super().__init__(
policy=policy,
@ -256,7 +256,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
reset_num_timesteps: bool = True,
tb_log_name: str = "run",
progress_bar: bool = False,
) -> Tuple[int, BaseCallback]:
) -> tuple[int, BaseCallback]:
"""
cf `BaseAlgorithm`.
"""
@ -362,7 +362,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
learning_starts: int,
action_noise: Optional[ActionNoise] = None,
n_envs: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
) -> tuple[np.ndarray, np.ndarray]:
"""
Sample an action according to the exploration policy.
This is either done by sampling the probability distribution of the policy,
@ -442,10 +442,10 @@ class OffPolicyAlgorithm(BaseAlgorithm):
self,
replay_buffer: ReplayBuffer,
buffer_action: np.ndarray,
new_obs: Union[np.ndarray, Dict[str, np.ndarray]],
new_obs: Union[np.ndarray, dict[str, np.ndarray]],
reward: np.ndarray,
dones: np.ndarray,
infos: List[Dict[str, Any]],
infos: list[dict[str, Any]],
) -> None:
"""
Store transition in the replay buffer.

View file

@ -1,7 +1,7 @@
import sys
import time
import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Optional, TypeVar, Union
import numpy as np
import torch as th
@ -60,7 +60,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
def __init__(
self,
policy: Union[str, Type[ActorCriticPolicy]],
policy: Union[str, type[ActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule],
n_steps: int,
@ -71,17 +71,17 @@ class OnPolicyAlgorithm(BaseAlgorithm):
max_grad_norm: float,
use_sde: bool,
sde_sample_freq: int,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
rollout_buffer_class: Optional[type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[dict[str, Any]] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
monitor_wrapper: bool = True,
policy_kwargs: Optional[Dict[str, Any]] = None,
policy_kwargs: Optional[dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
):
super().__init__(
policy=policy,
@ -339,7 +339,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
return self
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
state_dicts = ["policy", "policy.optimizer"]
return state_dicts, []

View file

@ -5,7 +5,7 @@ import copy
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Optional, TypeVar, Union
import numpy as np
import torch as th
@ -64,12 +64,12 @@ class BaseModel(nn.Module):
self,
observation_space: spaces.Space,
action_space: spaces.Space,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
features_extractor: Optional[BaseFeaturesExtractor] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
):
super().__init__()
@ -95,9 +95,9 @@ class BaseModel(nn.Module):
def _update_features_extractor(
self,
net_kwargs: Dict[str, Any],
net_kwargs: dict[str, Any],
features_extractor: Optional[BaseFeaturesExtractor] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Update the network keyword arguments and create a new features extractor object if needed.
If a ``features_extractor`` object is passed, then it will be shared.
@ -130,7 +130,7 @@ class BaseModel(nn.Module):
preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
return features_extractor(preprocessed_obs)
def _get_constructor_parameters(self) -> Dict[str, Any]:
def _get_constructor_parameters(self) -> dict[str, Any]:
"""
Get data that need to be saved in order to re-create the model when loading it from disk.
@ -164,7 +164,7 @@ class BaseModel(nn.Module):
th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)
@classmethod
def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel:
def load(cls: type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel:
"""
Load model from path.
@ -210,7 +210,7 @@ class BaseModel(nn.Module):
"""
self.train(mode)
def is_vectorized_observation(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> bool:
def is_vectorized_observation(self, observation: Union[np.ndarray, dict[str, np.ndarray]]) -> bool:
"""
Check whether or not the observation is vectorized,
apply transposition to image (so that they are channel-first) if needed.
@ -233,7 +233,7 @@ class BaseModel(nn.Module):
)
return vectorized_env
def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[PyTorchObs, bool]:
def obs_to_tensor(self, observation: Union[np.ndarray, dict[str, np.ndarray]]) -> tuple[PyTorchObs, bool]:
"""
Convert an input observation to a PyTorch tensor that can be fed to a model.
Includes sugar-coating to handle different observations (e.g. normalizing images).
@ -330,11 +330,11 @@ class BasePolicy(BaseModel, ABC):
def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
observation: Union[np.ndarray, dict[str, np.ndarray]],
state: Optional[tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
"""
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
@ -450,20 +450,20 @@ class ActorCriticPolicy(BasePolicy):
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
activation_fn: type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
):
if optimizer_kwargs is None:
optimizer_kwargs = {}
@ -534,7 +534,7 @@ class ActorCriticPolicy(BasePolicy):
self._build(lr_schedule)
def _get_constructor_parameters(self) -> Dict[str, Any]:
def _get_constructor_parameters(self) -> dict[str, Any]:
data = super()._get_constructor_parameters()
default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) # type: ignore[arg-type, return-value]
@ -633,7 +633,7 @@ class ActorCriticPolicy(BasePolicy):
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) # type: ignore[call-arg]
def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
def forward(self, obs: th.Tensor, deterministic: bool = False) -> tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
Forward pass in all the networks (actor and critic)
@ -659,7 +659,7 @@ class ActorCriticPolicy(BasePolicy):
def extract_features( # type: ignore[override]
self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
) -> Union[th.Tensor, tuple[th.Tensor, th.Tensor]]:
"""
Preprocess the observation if needed and extract features.
@ -716,7 +716,7 @@ class ActorCriticPolicy(BasePolicy):
"""
return self.get_distribution(observation).get_actions(deterministic=deterministic)
def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
"""
Evaluate actions according to the current policy,
given the observations.
@ -800,20 +800,20 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
activation_fn: type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
):
super().__init__(
observation_space,
@ -873,20 +873,20 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
observation_space: spaces.Dict,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
activation_fn: type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
):
super().__init__(
observation_space,
@ -942,10 +942,10 @@ class ContinuousCritic(BaseModel):
self,
observation_space: spaces.Space,
action_space: spaces.Box,
net_arch: List[int],
net_arch: list[int],
features_extractor: BaseFeaturesExtractor,
features_dim: int,
activation_fn: Type[nn.Module] = nn.ReLU,
activation_fn: type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
n_critics: int = 2,
share_features_extractor: bool = True,
@ -961,14 +961,14 @@ class ContinuousCritic(BaseModel):
self.share_features_extractor = share_features_extractor
self.n_critics = n_critics
self.q_networks: List[nn.Module] = []
self.q_networks: list[nn.Module] = []
for idx in range(n_critics):
q_net_list = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
q_net = nn.Sequential(*q_net_list)
self.add_module(f"qf{idx}", q_net)
self.q_networks.append(q_net)
def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
def forward(self, obs: th.Tensor, actions: th.Tensor) -> tuple[th.Tensor, ...]:
# Learn the features extractor using the policy loss only
# when the features_extractor is shared with the actor
with th.set_grad_enabled(not self.share_features_extractor):

View file

@ -1,5 +1,5 @@
import warnings
from typing import Dict, Tuple, Union
from typing import Union
import numpy as np
import torch as th
@ -90,10 +90,10 @@ def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) ->
def preprocess_obs(
obs: Union[th.Tensor, Dict[str, th.Tensor]],
obs: Union[th.Tensor, dict[str, th.Tensor]],
observation_space: spaces.Space,
normalize_images: bool = True,
) -> Union[th.Tensor, Dict[str, th.Tensor]]:
) -> Union[th.Tensor, dict[str, th.Tensor]]:
"""
Preprocess observation to be fed to a neural network.
For images, it normalizes the values by dividing them by 255 (to have values in [0, 1])
@ -107,7 +107,7 @@ def preprocess_obs(
"""
if isinstance(observation_space, spaces.Dict):
# Do not modify by reference the original observation
assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
assert isinstance(obs, dict), f"Expected dict, got {type(obs)}"
preprocessed_obs = {}
for key, _obs in obs.items():
preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
@ -142,7 +142,7 @@ def preprocess_obs(
def get_obs_shape(
observation_space: spaces.Space,
) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]:
) -> Union[tuple[int, ...], dict[str, tuple[int, ...]]]:
"""
Get the shape of the observation (useful for the buffers).
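For reference, a minimal usage sketch of the two helpers above (not part of the diff; the `pos` key is made up):

    import torch as th
    from gymnasium import spaces
    from stable_baselines3.common.preprocessing import get_obs_shape, preprocess_obs

    space = spaces.Dict({"pos": spaces.Box(-1, 1, shape=(3,))})
    get_obs_shape(space)        # {'pos': (3,)}
    obs = {"pos": th.zeros(1, 3)}
    preprocess_obs(obs, space)  # dict of float tensors; image values would be scaled to [0, 1]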

View file

@ -1,4 +1,4 @@
from typing import Callable, List, Optional, Tuple
from typing import Callable, Optional
import numpy as np
import pandas as pd
@ -29,7 +29,7 @@ def rolling_window(array: np.ndarray, window: int) -> np.ndarray:
return np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides)
def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> Tuple[np.ndarray, np.ndarray]:
def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> tuple[np.ndarray, np.ndarray]:
"""
Apply a function to the rolling window of 2 arrays
@ -44,7 +44,7 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callabl
return var_1[window - 1 :], function_on_var2
def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> tuple[np.ndarray, np.ndarray]:
"""
Decompose a data frame variable into x and y arrays
@ -69,7 +69,7 @@ def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray
def plot_curves(
xy_list: List[Tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: Tuple[int, int] = (8, 2)
xy_list: list[tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: tuple[int, int] = (8, 2)
) -> None:
"""
Plot the curves
@ -99,7 +99,7 @@ def plot_curves(
def plot_results(
dirs: List[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2)
dirs: list[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: tuple[int, int] = (8, 2)
) -> None:
"""
Plot the results using csv files from ``Monitor`` wrapper.
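A usage sketch for the plotting helpers above (assumes `./log` holds `monitor.csv` files written by the `Monitor` wrapper):

    from stable_baselines3.common import results_plotter

    results_plotter.plot_results(["./log"], 1_000_000, results_plotter.X_TIMESTEPS, "CartPole")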

View file

@ -1,10 +1,8 @@
from typing import Tuple
import numpy as np
class RunningMeanStd:
def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
def __init__(self, epsilon: float = 1e-4, shape: tuple[int, ...] = ()):
"""
Calculates the running mean and std of a data stream
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
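Usage sketch (not part of the diff): statistics are updated batch by batch and can be read back at any time.

    import numpy as np
    from stable_baselines3.common.running_mean_std import RunningMeanStd

    rms = RunningMeanStd(shape=())    # scalar stream
    rms.update(np.array([1.0, 2.0]))
    rms.update(np.array([3.0, 4.0]))
    print(rms.mean, rms.var)          # running stats over all batches seen so far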

View file

@ -12,7 +12,7 @@ import pathlib
import pickle
import warnings
import zipfile
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Optional, Union
import cloudpickle
import torch as th
@ -73,7 +73,7 @@ def is_json_serializable(item: Any) -> bool:
return json_serializable
def data_to_json(data: Dict[str, Any]) -> str:
def data_to_json(data: dict[str, Any]) -> str:
"""
Turn data (class parameters) into a JSON string for storing
@ -128,7 +128,7 @@ def data_to_json(data: Dict[str, Any]) -> str:
return json_string
def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
def json_to_data(json_string: str, custom_objects: Optional[dict[str, Any]] = None) -> dict[str, Any]:
"""
Turn JSON serialization of class-parameters back into dictionary.
@ -293,9 +293,9 @@ def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: O
def save_to_zip_file(
save_path: Union[str, pathlib.Path, io.BufferedIOBase],
data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
pytorch_variables: Optional[Dict[str, Any]] = None,
data: Optional[dict[str, Any]] = None,
params: Optional[dict[str, Any]] = None,
pytorch_variables: Optional[dict[str, Any]] = None,
verbose: int = 0,
) -> None:
"""
@ -376,11 +376,11 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: in
def load_from_zip_file(
load_path: Union[str, pathlib.Path, io.BufferedIOBase],
load_data: bool = True,
custom_objects: Optional[Dict[str, Any]] = None,
custom_objects: Optional[dict[str, Any]] = None,
device: Union[th.device, str] = "auto",
verbose: int = 0,
print_system_info: bool = False,
) -> Tuple[Optional[Dict[str, Any]], TensorDict, Optional[TensorDict]]:
) -> tuple[Optional[dict[str, Any]], TensorDict, Optional[TensorDict]]:
"""
Load model data from a .zip archive
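These helpers back the public `save()`/`load()` API; a minimal round trip looks like this (sketch, not part of the diff):

    from stable_baselines3 import PPO

    model = PPO("MlpPolicy", "CartPole-v1")
    model.save("ppo_cartpole")        # save_to_zip_file: data (JSON) + params (state dicts)
    model = PPO.load("ppo_cartpole")  # load_from_zip_file + json_to_data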

View file

@ -1,4 +1,5 @@
from typing import Any, Callable, Dict, Iterable, Optional
from collections.abc import Iterable
from typing import Any, Callable, Optional
import torch
from torch.optim import Optimizer
@ -67,7 +68,7 @@ class RMSpropTFLike(Optimizer):
defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
super().__init__(params, defaults)
def __setstate__(self, state: Dict[str, Any]) -> None:
def __setstate__(self, state: dict[str, Any]) -> None:
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("momentum", 0)
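`RMSpropTFLike` is typically injected through `policy_kwargs`, e.g. to reproduce Stable-Baselines2 optimizer behaviour (sketch; the hyperparameters are illustrative):

    from stable_baselines3 import A2C
    from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike

    model = A2C(
        "MlpPolicy", "CartPole-v1",
        policy_kwargs=dict(optimizer_class=RMSpropTFLike, optimizer_kwargs=dict(eps=1e-5)),
    )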

View file

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple, Type, Union
from typing import Optional, Union
import gymnasium as gym
import torch as th
@ -110,13 +110,13 @@ class NatureCNN(BaseFeaturesExtractor):
def create_mlp(
input_dim: int,
output_dim: int,
net_arch: List[int],
activation_fn: Type[nn.Module] = nn.ReLU,
net_arch: list[int],
activation_fn: type[nn.Module] = nn.ReLU,
squash_output: bool = False,
with_bias: bool = True,
pre_linear_modules: Optional[List[Type[nn.Module]]] = None,
post_linear_modules: Optional[List[Type[nn.Module]]] = None,
) -> List[nn.Module]:
pre_linear_modules: Optional[list[type[nn.Module]]] = None,
post_linear_modules: Optional[list[type[nn.Module]]] = None,
) -> list[nn.Module]:
"""
Create a multi layer perceptron (MLP), which is
a collection of fully-connected layers each followed by an activation function.
@ -211,14 +211,14 @@ class MlpExtractor(nn.Module):
def __init__(
self,
feature_dim: int,
net_arch: Union[List[int], Dict[str, List[int]]],
activation_fn: Type[nn.Module],
net_arch: Union[list[int], dict[str, list[int]]],
activation_fn: type[nn.Module],
device: Union[th.device, str] = "auto",
) -> None:
super().__init__()
device = get_device(device)
policy_net: List[nn.Module] = []
value_net: List[nn.Module] = []
policy_net: list[nn.Module] = []
value_net: list[nn.Module] = []
last_layer_dim_pi = feature_dim
last_layer_dim_vf = feature_dim
@ -249,7 +249,7 @@ class MlpExtractor(nn.Module):
self.policy_net = nn.Sequential(*policy_net).to(device)
self.value_net = nn.Sequential(*value_net).to(device)
def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
def forward(self, features: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
"""
:return: latent_policy, latent_value of the specified network.
If all layers are shared, then ``latent_policy == latent_value``
@ -288,7 +288,7 @@ class CombinedExtractor(BaseFeaturesExtractor):
# TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
super().__init__(observation_space, features_dim=1)
extractors: Dict[str, nn.Module] = {}
extractors: dict[str, nn.Module] = {}
total_concat_size = 0
for key, subspace in observation_space.spaces.items():
@ -313,7 +313,7 @@ class CombinedExtractor(BaseFeaturesExtractor):
return th.cat(encoded_tensor_list, dim=1)
def get_actor_critic_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> Tuple[List[int], List[int]]:
def get_actor_critic_arch(net_arch: Union[list[int], dict[str, list[int]]]) -> tuple[list[int], list[int]]:
"""
Get the actor and critic network architectures for off-policy actor-critic algorithms (SAC, TD3, DDPG).
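Sketch of `create_mlp` with the signature shown above (dimensions made up):

    import torch.nn as nn
    from stable_baselines3.common.torch_layers import create_mlp

    # 8 inputs -> two hidden layers of 64 units with ReLU -> 2 outputs
    layers = create_mlp(input_dim=8, output_dim=2, net_arch=[64, 64])
    q_net = nn.Sequential(*layers)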

View file

@ -1,7 +1,7 @@
"""Common aliases for type hints"""
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Protocol, SupportsFloat, Union
import gymnasium as gym
import numpy as np
@ -13,14 +13,14 @@ if TYPE_CHECKING:
from stable_baselines3.common.vec_env import VecEnv
GymEnv = Union[gym.Env, "VecEnv"]
GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
GymResetReturn = Tuple[GymObs, Dict]
AtariResetReturn = Tuple[np.ndarray, Dict[str, Any]]
GymStepReturn = Tuple[GymObs, float, bool, bool, Dict]
AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]]
TensorDict = Dict[str, th.Tensor]
OptimizerStateDict = Dict[str, Any]
MaybeCallback = Union[None, Callable, List["BaseCallback"], "BaseCallback"]
GymObs = Union[tuple, dict[str, Any], np.ndarray, int]
GymResetReturn = tuple[GymObs, dict]
AtariResetReturn = tuple[np.ndarray, dict[str, Any]]
GymStepReturn = tuple[GymObs, float, bool, bool, dict]
AtariStepReturn = tuple[np.ndarray, SupportsFloat, bool, bool, dict[str, Any]]
TensorDict = dict[str, th.Tensor]
OptimizerStateDict = dict[str, Any]
MaybeCallback = Union[None, Callable, list["BaseCallback"], "BaseCallback"]
PyTorchObs = Union[th.Tensor, TensorDict]
# A schedule takes the remaining progress as input
@ -81,11 +81,11 @@ class TrainFreq(NamedTuple):
class PolicyPredictor(Protocol):
def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
observation: Union[np.ndarray, dict[str, np.ndarray]],
state: Optional[tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
"""
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
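A `Schedule` is just `Callable[[float], float]`, taking the remaining progress (1 at the start of training, 0 at the end). A common sketch, e.g. a linearly decaying learning rate:

    from stable_baselines3 import PPO
    from stable_baselines3.common.type_aliases import Schedule

    def linear_schedule(initial_value: float) -> Schedule:
        return lambda progress_remaining: progress_remaining * initial_value

    model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(3e-4))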

View file

@ -4,8 +4,9 @@ import platform
import random
import re
from collections import deque
from collections.abc import Iterable
from itertools import zip_longest
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Optional, Union
import cloudpickle
import gymnasium as gym
@ -415,7 +416,7 @@ def safe_mean(arr: Union[np.ndarray, list, deque]) -> float:
return np.nan if len(arr) == 0 else float(np.mean(arr)) # type: ignore[arg-type]
def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> List[th.Tensor]:
def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> list[th.Tensor]:
"""
Extract parameters from the state dict of ``model``
if the name contains one of the strings in ``included_names``.
@ -473,7 +474,7 @@ def polyak_update(
th.add(target_param.data, param.data, alpha=tau, out=target_param.data)
def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.device) -> Union[th.Tensor, TensorDict]:
def obs_as_tensor(obs: Union[np.ndarray, dict[str, np.ndarray]], device: th.device) -> Union[th.Tensor, TensorDict]:
"""
Moves the observation to the given device.
@ -517,7 +518,7 @@ def should_collect_more_steps(
)
def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
"""
Retrieve system and python env info for the current system.
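Sketch of `polyak_update`, the soft update used for target networks (tensors made up):

    import torch as th
    from stable_baselines3.common.utils import polyak_update

    params = [th.ones(2)]
    target_params = [th.zeros(2)]
    polyak_update(params, target_params, tau=0.005)
    # in-place: target <- tau * source + (1 - tau) * target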

View file

@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Optional, Type, TypeVar
from typing import Optional, TypeVar
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
@ -16,7 +16,7 @@ from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper)
def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]:
def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]:
"""
Retrieve a ``VecEnvWrapper`` object by recursively searching.
@ -42,7 +42,7 @@ def unwrap_vec_normalize(env: VecEnv) -> Optional[VecNormalize]:
return unwrap_vec_wrapper(env, VecNormalize)
def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: type[VecEnvWrapper]) -> bool:
"""
Check if an environment is already wrapped in a given ``VecEnvWrapper``.
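Usage sketch (not part of the diff):

    import gymnasium as gym
    from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize, unwrap_vec_normalize

    venv = VecNormalize(DummyVecEnv([lambda: gym.make("Pendulum-v1")]))
    assert unwrap_vec_normalize(venv) is venv  # found by the recursive search above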

View file

@ -1,8 +1,9 @@
import inspect
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union
from typing import Any, Optional, Union
import cloudpickle
import gymnasium as gym
@ -14,10 +15,10 @@ from gymnasium import spaces
VecEnvIndices = Union[None, int, Iterable[int]]
# VecEnvObs is what is returned by the reset() method
# it contains the observation for each env
VecEnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]]
VecEnvObs = Union[np.ndarray, dict[str, np.ndarray], tuple[np.ndarray, ...]]
# VecEnvStepReturn is what is returned by the step() method
# it contains the observation, reward, done, info for each env
VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]]
VecEnvStepReturn = tuple[VecEnvObs, np.ndarray, np.ndarray, list[dict]]
def tile_images(images_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover
@ -65,11 +66,11 @@ class VecEnv(ABC):
self.observation_space = observation_space
self.action_space = action_space
# store info returned by the reset method
self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
self.reset_infos: list[dict[str, Any]] = [{} for _ in range(num_envs)]
# seeds to be used in the next call to env.reset()
self._seeds: List[Optional[int]] = [None for _ in range(num_envs)]
self._seeds: list[Optional[int]] = [None for _ in range(num_envs)]
# options to be used in the next call to env.reset()
self._options: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
self._options: list[dict[str, Any]] = [{} for _ in range(num_envs)]
try:
render_modes = self.get_attr("render_mode")
@ -147,7 +148,7 @@ class VecEnv(ABC):
raise NotImplementedError()
@abstractmethod
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
"""
Return attribute from vectorized environment.
@ -170,7 +171,7 @@ class VecEnv(ABC):
raise NotImplementedError()
@abstractmethod
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
"""
Call instance methods of vectorized environments.
@ -183,7 +184,7 @@ class VecEnv(ABC):
raise NotImplementedError()
@abstractmethod
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
"""
Check if environments are wrapped with a given wrapper.
@ -292,7 +293,7 @@ class VecEnv(ABC):
self._seeds = [seed + idx for idx in range(self.num_envs)]
return self._seeds
def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None:
def set_options(self, options: Optional[Union[list[dict], dict]] = None) -> None:
"""
Set environment options for all environments.
If a dict is passed instead of a list, the same options will be used for all environments.
@ -379,7 +380,7 @@ class VecEnvWrapper(VecEnv):
def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
return self.venv.seed(seed)
def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None:
def set_options(self, options: Optional[Union[list[dict], dict]] = None) -> None:
return self.venv.set_options(options)
def close(self) -> None:
@ -391,16 +392,16 @@ class VecEnvWrapper(VecEnv):
def get_images(self) -> Sequence[Optional[np.ndarray]]:
return self.venv.get_images()
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
return self.venv.get_attr(attr_name, indices)
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
return self.venv.set_attr(attr_name, value, indices)
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
return self.venv.env_is_wrapped(wrapper_class, indices=indices)
def __getattr__(self, name: str) -> Any:
@ -419,7 +420,7 @@ class VecEnvWrapper(VecEnv):
return self.getattr_recursive(name)
def _get_all_attributes(self) -> Dict[str, Any]:
def _get_all_attributes(self) -> dict[str, Any]:
"""Get all (inherited) instance and class attributes
:return: all_attributes
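Sketch of the `set_options` semantics documented above (the `difficulty` key is hypothetical; options are consumed by the next `reset()`):

    import gymnasium as gym
    from stable_baselines3.common.vec_env import DummyVecEnv

    venv = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
    venv.set_options({"difficulty": 1})                       # same dict for all envs
    venv.set_options([{"difficulty": 1}, {"difficulty": 2}])  # or one dict per env
    obs = venv.reset()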

View file

@ -1,7 +1,8 @@
import warnings
from collections import OrderedDict
from collections.abc import Sequence
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Sequence, Type
from typing import Any, Callable, Optional
import gymnasium as gym
import numpy as np
@ -26,7 +27,7 @@ class DummyVecEnv(VecEnv):
actions: np.ndarray
def __init__(self, env_fns: List[Callable[[], gym.Env]]):
def __init__(self, env_fns: list[Callable[[], gym.Env]]):
self.envs = [_patch_env(fn()) for fn in env_fns]
if len(set([id(env.unwrapped) for env in self.envs])) != len(self.envs):
raise ValueError(
@ -46,7 +47,7 @@ class DummyVecEnv(VecEnv):
self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])) for k in self.keys])
self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
self.buf_infos: List[Dict[str, Any]] = [{} for _ in range(self.num_envs)]
self.buf_infos: list[dict[str, Any]] = [{} for _ in range(self.num_envs)]
self.metadata = env.metadata
def step_async(self, actions: np.ndarray) -> None:
@ -112,7 +113,7 @@ class DummyVecEnv(VecEnv):
def _obs_from_buf(self) -> VecEnvObs:
return dict_to_obs(self.observation_space, deepcopy(self.buf_obs))
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
"""Return attribute from vectorized environment (see base class)."""
target_envs = self._get_target_envs(indices)
return [env_i.get_wrapper_attr(attr_name) for env_i in target_envs]
@ -123,12 +124,12 @@ class DummyVecEnv(VecEnv):
for env_i in target_envs:
setattr(env_i, attr_name, value)
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
"""Call instance methods of vectorized environments."""
target_envs = self._get_target_envs(indices)
return [env_i.get_wrapper_attr(method_name)(*method_args, **method_kwargs) for env_i in target_envs]
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
"""Check if worker environments are wrapped with a given wrapper"""
target_envs = self._get_target_envs(indices)
# Import here to avoid a circular import
@ -136,6 +137,6 @@ class DummyVecEnv(VecEnv):
return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs]
def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
def _get_target_envs(self, indices: VecEnvIndices) -> list[gym.Env]:
indices = self._get_indices(indices)
return [self.envs[i] for i in indices]
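Construction sketch; each callable must build a fresh env, otherwise the `ValueError` above is raised:

    import gymnasium as gym
    from stable_baselines3.common.vec_env import DummyVecEnv

    venv = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
    obs = venv.reset()  # shape (4, 4) for CartPole: n_envs x obs_dim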

View file

@ -1,12 +1,13 @@
import warnings
from typing import Any, Dict, Generic, List, Mapping, Optional, Tuple, TypeVar, Union
from collections.abc import Mapping
from typing import Any, Generic, Optional, TypeVar, Union
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
TObs = TypeVar("TObs", np.ndarray, Dict[str, np.ndarray])
TObs = TypeVar("TObs", np.ndarray, dict[str, np.ndarray])
class StackedObservations(Generic[TObs]):
@ -66,7 +67,7 @@ class StackedObservations(Generic[TObs]):
@staticmethod
def compute_stacking(
n_stack: int, observation_space: spaces.Box, channels_order: Optional[str] = None
) -> Tuple[bool, int, Tuple[int, ...], int]:
) -> tuple[bool, int, tuple[int, ...], int]:
"""
Calculates the parameters in order to stack observations
@ -119,8 +120,8 @@ class StackedObservations(Generic[TObs]):
self,
observations: TObs,
dones: np.ndarray,
infos: List[Dict[str, Any]],
) -> Tuple[TObs, List[Dict[str, Any]]]:
infos: list[dict[str, Any]],
) -> tuple[TObs, list[dict[str, Any]]]:
"""
Add the observations to the stack and use the dones to update the infos.

View file

@ -1,6 +1,7 @@
import multiprocessing as mp
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
import gymnasium as gym
import numpy as np
@ -26,7 +27,7 @@ def _worker(
parent_remote.close()
env = _patch_env(env_fn_wrapper.var())
reset_info: Optional[Dict[str, Any]] = {}
reset_info: Optional[dict[str, Any]] = {}
while True:
try:
cmd, data = remote.recv()
@ -91,7 +92,7 @@ class SubprocVecEnv(VecEnv):
Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
"""
def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[str] = None):
def __init__(self, env_fns: list[Callable[[], gym.Env]], start_method: Optional[str] = None):
self.waiting = False
self.closed = False
n_envs = len(env_fns)
@ -164,7 +165,7 @@ class SubprocVecEnv(VecEnv):
outputs = [pipe.recv() for pipe in self.remotes]
return outputs
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
"""Return attribute from vectorized environment (see base class)."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
@ -179,21 +180,21 @@ class SubprocVecEnv(VecEnv):
for remote in target_remotes:
remote.recv()
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
"""Call instance methods of vectorized environments."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(("env_method", (method_name, method_args, method_kwargs)))
return [remote.recv() for remote in target_remotes]
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
"""Check if worker environments are wrapped with a given wrapper"""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(("is_wrapped", wrapper_class))
return [remote.recv() for remote in target_remotes]
def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]:
def _get_target_remotes(self, indices: VecEnvIndices) -> list[Any]:
"""
Get the connection object needed to communicate with the wanted
envs that are in subprocesses.
@ -205,7 +206,7 @@ class SubprocVecEnv(VecEnv):
return [self.remotes[i] for i in indices]
def _stack_obs(obs_list: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
def _stack_obs(obs_list: Union[list[VecEnvObs], tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
"""
Stack observations (convert from a list of single env obs to a stack of obs),
depending on the observation space.
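Usage sketch; with the 'spawn'/'forkserver' start methods, multiprocessing requires the entry-point guard:

    import gymnasium as gym
    from stable_baselines3.common.vec_env import SubprocVecEnv

    def make_env():
        return gym.make("CartPole-v1")

    if __name__ == "__main__":
        venv = SubprocVecEnv([make_env for _ in range(4)], start_method="spawn")
        obs = venv.reset()
        venv.close()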

View file

@ -2,7 +2,7 @@
Helpers for dealing with vectorized environments.
"""
from typing import Any, Dict, List, Tuple
from typing import Any
import numpy as np
from gymnasium import spaces
@ -11,7 +11,7 @@ from stable_baselines3.common.preprocessing import check_for_nested_spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
def dict_to_obs(obs_space: spaces.Space, obs_dict: dict[Any, np.ndarray]) -> VecEnvObs:
"""
Convert an internal representation raw_obs into the appropriate type
specified by space.
@ -32,7 +32,7 @@ def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> Vec
return obs_dict[None]
def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]:
def obs_space_info(obs_space: spaces.Space) -> tuple[list[str], dict[Any, tuple[int, ...]], dict[Any, np.dtype]]:
"""
Get dict-structured information about a gym.Space.

View file

@ -1,5 +1,4 @@
import warnings
from typing import List, Tuple
import numpy as np
from gymnasium import spaces
@ -48,7 +47,7 @@ class VecCheckNan(VecEnvWrapper):
self._observations = observations
return observations
def check_array_value(self, name: str, value: np.ndarray) -> List[Tuple[str, str]]:
def check_array_value(self, name: str, value: np.ndarray) -> list[tuple[str, str]]:
"""
Check for inf and NaN for a single numpy array.

View file

@ -1,4 +1,5 @@
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
from collections.abc import Mapping
from typing import Any, Optional, Union
import numpy as np
from gymnasium import spaces
@ -29,17 +30,17 @@ class VecFrameStack(VecEnvWrapper):
def step_wait(
self,
) -> Tuple[
Union[np.ndarray, Dict[str, np.ndarray]],
) -> tuple[
Union[np.ndarray, dict[str, np.ndarray]],
np.ndarray,
np.ndarray,
List[Dict[str, Any]],
list[dict[str, Any]],
]:
observations, rewards, dones, infos = self.venv.step_wait()
observations, infos = self.stacked_obs.update(observations, dones, infos) # type: ignore[arg-type]
return observations, rewards, dones, infos
def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
def reset(self) -> Union[np.ndarray, dict[str, np.ndarray]]:
"""
Reset all environments
"""

View file

@ -1,6 +1,6 @@
import time
import warnings
from typing import Optional, Tuple
from typing import Optional
import numpy as np
@ -27,7 +27,7 @@ class VecMonitor(VecEnvWrapper):
self,
venv: VecEnv,
filename: Optional[str] = None,
info_keywords: Tuple[str, ...] = (),
info_keywords: tuple[str, ...] = (),
):
# Avoid circular import
from stable_baselines3.common.monitor import Monitor, ResultsWriter

View file

@ -1,7 +1,7 @@
import inspect
import pickle
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
import numpy as np
from gymnasium import spaces
@ -29,8 +29,8 @@ class VecNormalize(VecEnvWrapper):
If not specified, all keys will be normalized.
"""
obs_spaces: Dict[str, spaces.Space]
old_obs: Union[np.ndarray, Dict[str, np.ndarray]]
obs_spaces: dict[str, spaces.Space]
old_obs: Union[np.ndarray, dict[str, np.ndarray]]
def __init__(
self,
@ -42,7 +42,7 @@ class VecNormalize(VecEnvWrapper):
clip_reward: float = 10.0,
gamma: float = 0.99,
epsilon: float = 1e-8,
norm_obs_keys: Optional[List[str]] = None,
norm_obs_keys: Optional[list[str]] = None,
):
VecEnvWrapper.__init__(self, venv)
@ -125,7 +125,7 @@ class VecNormalize(VecEnvWrapper):
f"not {self.observation_space}"
)
def __getstate__(self) -> Dict[str, Any]:
def __getstate__(self) -> dict[str, Any]:
"""
Gets state for pickling.
@ -138,7 +138,7 @@ class VecNormalize(VecEnvWrapper):
del state["returns"]
return state
def __setstate__(self, state: Dict[str, Any]) -> None:
def __setstate__(self, state: dict[str, Any]) -> None:
"""
Restores pickled state.
@ -229,7 +229,7 @@ class VecNormalize(VecEnvWrapper):
"""
return (obs * np.sqrt(obs_rms.var + self.epsilon)) + obs_rms.mean
def normalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
def normalize_obs(self, obs: Union[np.ndarray, dict[str, np.ndarray]]) -> Union[np.ndarray, dict[str, np.ndarray]]:
"""
Normalize observations using this VecNormalize's observation statistics.
Calling this method does not update statistics.
@ -254,9 +254,11 @@ class VecNormalize(VecEnvWrapper):
"""
if self.norm_reward:
reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
return reward
# Note: we cast to float32, the dtype SB3 uses for rewards (cf. ``buf_rews``).
# The cast is needed because `RunningMeanStd` keeps its statistics in float64,
# which would otherwise upcast the normalized reward
return reward.astype(np.float32)
def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
def unnormalize_obs(self, obs: Union[np.ndarray, dict[str, np.ndarray]]) -> Union[np.ndarray, dict[str, np.ndarray]]:
# Avoid modifying by reference the original object
obs_ = deepcopy(obs)
if self.norm_obs:
@ -274,7 +276,7 @@ class VecNormalize(VecEnvWrapper):
return reward * np.sqrt(self.ret_rms.var + self.epsilon)
return reward
def get_original_obs(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
def get_original_obs(self) -> Union[np.ndarray, dict[str, np.ndarray]]:
"""
Returns an unnormalized version of the observations from the most recent
step or reset.
@ -287,7 +289,7 @@ class VecNormalize(VecEnvWrapper):
"""
return self.old_reward.copy()
def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
def reset(self) -> Union[np.ndarray, dict[str, np.ndarray]]:
"""
Reset all environments
:return: first observation of the episode
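Typical lifecycle sketch, including persisting the running statistics (paths illustrative):

    import gymnasium as gym
    from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

    venv = VecNormalize(DummyVecEnv([lambda: gym.make("Pendulum-v1")]), norm_obs=True, norm_reward=True)
    # ... train a model on venv ...
    venv.save("vec_normalize.pkl")
    eval_env = VecNormalize.load("vec_normalize.pkl", DummyVecEnv([lambda: gym.make("Pendulum-v1")]))
    eval_env.training = False     # freeze statistics
    eval_env.norm_reward = False  # report raw rewards at eval time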

View file

@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Dict, Union
from typing import Union
import numpy as np
from gymnasium import spaces
@ -73,7 +73,7 @@ class VecTransposeImage(VecEnvWrapper):
return np.transpose(image, (2, 0, 1))
return np.transpose(image, (0, 3, 1, 2))
def transpose_observations(self, observations: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]:
def transpose_observations(self, observations: Union[np.ndarray, dict]) -> Union[np.ndarray, dict]:
"""
Transpose (if needed) and return new observations.
@ -106,7 +106,7 @@ class VecTransposeImage(VecEnvWrapper):
assert isinstance(observations, (np.ndarray, dict))
return self.transpose_observations(observations), rewards, dones, infos
def reset(self) -> Union[np.ndarray, Dict]:
def reset(self) -> Union[np.ndarray, dict]:
"""
Reset all environments
"""

View file

@ -1,6 +1,6 @@
import os
import os.path
from typing import Callable, List
from typing import Callable
import numpy as np
from gymnasium import error, logger
@ -109,7 +109,7 @@ class VecVideoRecorder(VecEnvWrapper):
assert self.recording, "Cannot capture a frame, recording wasn't started."
frame = self.env.render()
if isinstance(frame, List):
if isinstance(frame, list):
frame = frame[-1]
if isinstance(frame, np.ndarray):

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Optional, TypeVar, Union
import torch as th
@ -55,7 +55,7 @@ class DDPG(TD3):
def __init__(
self,
policy: Union[str, Type[TD3Policy]],
policy: Union[str, type[TD3Policy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 1e-3,
buffer_size: int = 1_000_000, # 1e6
@ -63,14 +63,14 @@ class DDPG(TD3):
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 1,
train_freq: Union[int, tuple[int, str]] = 1,
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[dict[str, Any]] = None,
optimize_memory_usage: bool = False,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
policy_kwargs: Optional[dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",

View file

@ -1,5 +1,5 @@
import warnings
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, ClassVar, Optional, TypeVar, Union
import numpy as np
import torch as th
@ -62,7 +62,7 @@ class DQN(OffPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
"MlpPolicy": MlpPolicy,
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
@ -75,7 +75,7 @@ class DQN(OffPolicyAlgorithm):
def __init__(
self,
policy: Union[str, Type[DQNPolicy]],
policy: Union[str, type[DQNPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 1e-4,
buffer_size: int = 1_000_000, # 1e6
@ -83,10 +83,10 @@ class DQN(OffPolicyAlgorithm):
batch_size: int = 32,
tau: float = 1.0,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 4,
train_freq: Union[int, tuple[int, str]] = 4,
gradient_steps: int = 1,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[dict[str, Any]] = None,
optimize_memory_usage: bool = False,
target_update_interval: int = 10000,
exploration_fraction: float = 0.1,
@ -95,7 +95,7 @@ class DQN(OffPolicyAlgorithm):
max_grad_norm: float = 10,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
policy_kwargs: Optional[dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
@ -227,11 +227,11 @@ class DQN(OffPolicyAlgorithm):
def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
observation: Union[np.ndarray, dict[str, np.ndarray]],
state: Optional[tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
"""
Overrides the base_class predict function to include epsilon-greedy exploration.
@ -273,10 +273,10 @@ class DQN(OffPolicyAlgorithm):
progress_bar=progress_bar,
)
def _excluded_save_params(self) -> List[str]:
def _excluded_save_params(self) -> list[str]:
return [*super()._excluded_save_params(), "q_net", "q_net_target"]
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
state_dicts = ["policy", "policy.optimizer"]
return state_dicts, []

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Type
from typing import Any, Optional
import torch as th
from gymnasium import spaces
@ -35,8 +35,8 @@ class QNetwork(BasePolicy):
action_space: spaces.Discrete,
features_extractor: BaseFeaturesExtractor,
features_dim: int,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
net_arch: Optional[list[int]] = None,
activation_fn: type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
) -> None:
super().__init__(
@ -71,7 +71,7 @@ class QNetwork(BasePolicy):
action = q_values.argmax(dim=1).reshape(-1)
return action
def _get_constructor_parameters(self) -> Dict[str, Any]:
def _get_constructor_parameters(self) -> dict[str, Any]:
data = super()._get_constructor_parameters()
data.update(
@ -113,13 +113,13 @@ class DQNPolicy(BasePolicy):
observation_space: spaces.Space,
action_space: spaces.Discrete,
lr_schedule: Schedule,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
net_arch: Optional[list[int]] = None,
activation_fn: type[nn.Module] = nn.ReLU,
features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
) -> None:
super().__init__(
observation_space,
@ -183,7 +183,7 @@ class DQNPolicy(BasePolicy):
def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
return self.q_net._predict(obs, deterministic=deterministic)
def _get_constructor_parameters(self) -> Dict[str, Any]:
def _get_constructor_parameters(self) -> dict[str, Any]:
data = super()._get_constructor_parameters()
data.update(
@ -237,13 +237,13 @@ class CnnPolicy(DQNPolicy):
observation_space: spaces.Space,
action_space: spaces.Discrete,
lr_schedule: Schedule,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
net_arch: Optional[list[int]] = None,
activation_fn: type[nn.Module] = nn.ReLU,
features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
) -> None:
super().__init__(
observation_space,
@ -282,13 +282,13 @@ class MultiInputPolicy(DQNPolicy):
observation_space: spaces.Dict,
action_space: spaces.Discrete,
lr_schedule: Schedule,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
net_arch: Optional[list[int]] = None,
activation_fn: type[nn.Module] = nn.ReLU,
features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
) -> None:
super().__init__(
observation_space,

View file

@ -1,6 +1,6 @@
import copy
import warnings
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
import numpy as np
import torch as th
@ -98,7 +98,7 @@ class HerReplayBuffer(DictReplayBuffer):
self.ep_length = np.zeros((self.buffer_size, self.n_envs), dtype=np.int64)
self._current_ep_start = np.zeros(self.n_envs, dtype=np.int64)
def __getstate__(self) -> Dict[str, Any]:
def __getstate__(self) -> dict[str, Any]:
"""
Gets state for pickling.
@ -109,7 +109,7 @@ class HerReplayBuffer(DictReplayBuffer):
del state["env"]
return state
def __setstate__(self, state: Dict[str, Any]) -> None:
def __setstate__(self, state: dict[str, Any]) -> None:
"""
Restores pickled state.
@ -134,12 +134,12 @@ class HerReplayBuffer(DictReplayBuffer):
def add( # type: ignore[override]
self,
obs: Dict[str, np.ndarray],
next_obs: Dict[str, np.ndarray],
obs: dict[str, np.ndarray],
next_obs: dict[str, np.ndarray],
action: np.ndarray,
reward: np.ndarray,
done: np.ndarray,
infos: List[Dict[str, Any]],
infos: list[dict[str, Any]],
) -> None:
# When the buffer is full, we overwrite old episodes. As soon as we start
# overwriting an old episode, we want the whole old episode to be deleted

View file

@ -1,5 +1,5 @@
import warnings
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
from typing import Any, ClassVar, Optional, TypeVar, Union
import numpy as np
import torch as th
@ -71,7 +71,7 @@ class PPO(OnPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
"MlpPolicy": ActorCriticPolicy,
"CnnPolicy": ActorCriticCnnPolicy,
"MultiInputPolicy": MultiInputActorCriticPolicy,
@ -79,7 +79,7 @@ class PPO(OnPolicyAlgorithm):
def __init__(
self,
policy: Union[str, Type[ActorCriticPolicy]],
policy: Union[str, type[ActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 3e-4,
n_steps: int = 2048,
@ -95,12 +95,12 @@ class PPO(OnPolicyAlgorithm):
max_grad_norm: float = 0.5,
use_sde: bool = False,
sde_sample_freq: int = -1,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
rollout_buffer_class: Optional[type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[dict[str, Any]] = None,
target_kl: Optional[float] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
policy_kwargs: Optional[dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Optional, Union
import torch as th
from gymnasium import spaces
@ -51,10 +51,10 @@ class Actor(BasePolicy):
self,
observation_space: spaces.Space,
action_space: spaces.Box,
net_arch: List[int],
net_arch: list[int],
features_extractor: nn.Module,
features_dim: int,
activation_fn: Type[nn.Module] = nn.ReLU,
activation_fn: type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
full_std: bool = True,
@ -102,7 +102,7 @@ class Actor(BasePolicy):
self.mu = nn.Linear(last_layer_dim, action_dim)
self.log_std = nn.Linear(last_layer_dim, action_dim) # type: ignore[assignment]
def _get_constructor_parameters(self) -> Dict[str, Any]:
def _get_constructor_parameters(self) -> dict[str, Any]:
data = super()._get_constructor_parameters()
data.update(
@ -144,7 +144,7 @@ class Actor(BasePolicy):
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
def get_action_dist_params(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tensor, dict[str, th.Tensor]]:
"""
Get the parameters for the action distribution.
@ -169,7 +169,7 @@ class Actor(BasePolicy):
# Note: the action is squashed
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]:
def action_log_prob(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tensor]:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# return action and associated log prob
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
@ -216,17 +216,17 @@ class SACPolicy(BasePolicy):
observation_space: spaces.Space,
action_space: spaces.Box,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
activation_fn: type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
use_expln: bool = False,
clip_mean: float = 2.0,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = False,
):
@ -309,7 +309,7 @@ class SACPolicy(BasePolicy):
# Target networks should always be in eval mode
self.critic_target.set_training_mode(False)
def _get_constructor_parameters(self) -> Dict[str, Any]:
def _get_constructor_parameters(self) -> dict[str, Any]:
data = super()._get_constructor_parameters()
data.update(
@ -400,17 +400,17 @@ class CnnPolicy(SACPolicy):
observation_space: spaces.Space,
action_space: spaces.Box,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
activation_fn: type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
use_expln: bool = False,
clip_mean: float = 2.0,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = False,
):
@ -466,17 +466,17 @@ class MultiInputPolicy(SACPolicy):
observation_space: spaces.Space,
action_space: spaces.Box,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
activation_fn: type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
use_expln: bool = False,
clip_mean: float = 2.0,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = False,
):

View file

@ -1,4 +1,4 @@
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, ClassVar, Optional, TypeVar, Union
import numpy as np
import torch as th
@ -77,7 +77,7 @@ class SAC(OffPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
"MlpPolicy": MlpPolicy,
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
@ -89,7 +89,7 @@ class SAC(OffPolicyAlgorithm):
def __init__(
self,
policy: Union[str, Type[SACPolicy]],
policy: Union[str, type[SACPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 3e-4,
buffer_size: int = 1_000_000, # 1e6
@ -97,11 +97,11 @@ class SAC(OffPolicyAlgorithm):
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 1,
train_freq: Union[int, tuple[int, str]] = 1,
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[dict[str, Any]] = None,
optimize_memory_usage: bool = False,
ent_coef: Union[str, float] = "auto",
target_update_interval: int = 1,
@ -111,7 +111,7 @@ class SAC(OffPolicyAlgorithm):
use_sde_at_warmup: bool = False,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
policy_kwargs: Optional[dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
@ -313,10 +313,10 @@ class SAC(OffPolicyAlgorithm):
progress_bar=progress_bar,
)
def _excluded_save_params(self) -> List[str]:
def _excluded_save_params(self) -> list[str]:
return super()._excluded_save_params() + ["actor", "critic", "critic_target"] # noqa: RUF005
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
if self.ent_coef_optimizer is not None:
saved_pytorch_variables = ["log_ent_coef"]

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Optional, Union
import torch as th
from gymnasium import spaces
@ -36,10 +36,10 @@ class Actor(BasePolicy):
self,
observation_space: spaces.Space,
action_space: spaces.Box,
net_arch: List[int],
net_arch: list[int],
features_extractor: nn.Module,
features_dim: int,
activation_fn: Type[nn.Module] = nn.ReLU,
activation_fn: type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
):
super().__init__(
@ -59,7 +59,7 @@ class Actor(BasePolicy):
# Deterministic action
self.mu = nn.Sequential(*actor_net)
def _get_constructor_parameters(self) -> Dict[str, Any]:
def _get_constructor_parameters(self) -> dict[str, Any]:
data = super()._get_constructor_parameters()
data.update(
@ -116,13 +116,13 @@ class TD3Policy(BasePolicy):
observation_space: spaces.Space,
action_space: spaces.Box,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
activation_fn: type[nn.Module] = nn.ReLU,
features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = False,
):
@ -207,7 +207,7 @@ class TD3Policy(BasePolicy):
self.actor_target.set_training_mode(False)
self.critic_target.set_training_mode(False)
def _get_constructor_parameters(self) -> Dict[str, Any]:
def _get_constructor_parameters(self) -> dict[str, Any]:
data = super()._get_constructor_parameters()
data.update(
@ -285,13 +285,13 @@ class CnnPolicy(TD3Policy):
observation_space: spaces.Space,
action_space: spaces.Box,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
activation_fn: type[nn.Module] = nn.ReLU,
features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = False,
):
@ -339,13 +339,13 @@ class MultiInputPolicy(TD3Policy):
observation_space: spaces.Dict,
action_space: spaces.Box,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
activation_fn: type[nn.Module] = nn.ReLU,
features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = False,
):

View file

@ -1,4 +1,4 @@
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, ClassVar, Optional, TypeVar, Union
import numpy as np
import torch as th
@ -65,7 +65,7 @@ class TD3(OffPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
"MlpPolicy": MlpPolicy,
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
@ -78,7 +78,7 @@ class TD3(OffPolicyAlgorithm):
def __init__(
self,
policy: Union[str, Type[TD3Policy]],
policy: Union[str, type[TD3Policy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 1e-3,
buffer_size: int = 1_000_000, # 1e6
@ -86,18 +86,18 @@ class TD3(OffPolicyAlgorithm):
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 1,
train_freq: Union[int, tuple[int, str]] = 1,
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[dict[str, Any]] = None,
optimize_memory_usage: bool = False,
policy_delay: int = 2,
target_policy_noise: float = 0.2,
target_noise_clip: float = 0.5,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
policy_kwargs: Optional[dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
@ -228,9 +228,9 @@ class TD3(OffPolicyAlgorithm):
progress_bar=progress_bar,
)
def _excluded_save_params(self) -> List[str]:
def _excluded_save_params(self) -> list[str]:
return super()._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"] # noqa: RUF005
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
return state_dicts, []

View file

@ -1 +1 @@
2.4.0
2.5.0a0

View file

@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Optional
import gymnasium as gym
import numpy as np
@ -72,7 +72,7 @@ class DummyDictEnv(gym.Env):
terminated = truncated = False
return self.observation_space.sample(), reward, terminated, truncated, {}
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
if seed is not None:
self.observation_space.seed(seed)
return self.observation_space.sample(), {}

View file

@ -1,5 +1,4 @@
from copy import deepcopy
from typing import Tuple
import gymnasium as gym
import numpy as np
@ -55,7 +54,7 @@ def test_squashed_gaussian(model_class):
@pytest.fixture()
def dummy_model_distribution_obs_and_actions() -> Tuple[A2C, np.ndarray, np.ndarray]:
def dummy_model_distribution_obs_and_actions() -> tuple[A2C, np.ndarray, np.ndarray]:
"""
Fixture creating a Pendulum-v1 gym env and an A2C model, then sampling 10 random observations and actions from the env.
:return: A2C model, random observations, random actions

View file

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Tuple
from typing import Any, Optional
import gymnasium as gym
import numpy as np
@@ -135,7 +135,7 @@ def test_check_env_detailed_error(obs_tuple, method):
class TestEnv(gym.Env):
action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
return wrong_obs if method == "reset" else good_obs, {}
def step(self, action):
@@ -162,7 +162,7 @@ class LimitedStepsTestEnv(gym.Env):
self._steps_called = 0
self._terminated = False
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[int, Dict]:
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[int, dict]:
super().reset(seed=seed)
self._steps_called = 0
@@ -170,7 +170,7 @@ class LimitedStepsTestEnv(gym.Env):
return 0, {}
def step(self, action: np.ndarray) -> Tuple[int, float, bool, bool, Dict[str, Any]]:
def step(self, action: np.ndarray) -> tuple[int, float, bool, bool, dict[str, Any]]:
self._steps_called += 1
assert not self._terminated

View file

@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Optional
import gymnasium as gym
import numpy as np
@@ -23,7 +23,7 @@ class CustomEnv(gym.Env):
def seed(self, seed):
self.observation_space.seed(seed)
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
if seed is not None:
self.observation_space.seed(seed)
self.n_steps = 0
@@ -53,7 +53,7 @@ class InfiniteHorizonEnv(gym.Env):
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.current_state = 0
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
if seed is not None:
super().reset(seed=seed)

View file

@@ -2,8 +2,8 @@ import importlib.util
import os
import sys
import time
from collections.abc import Sequence
from io import TextIOBase
from typing import Sequence
from unittest import mock
import gymnasium as gym
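
`typing.Sequence` has been a deprecated alias of `collections.abc.Sequence` since Python 3.9, hence the relocated import. Both are subscriptable at runtime on 3.9+; a quick sketch:

from collections.abc import Sequence  # preferred on Python 3.9+
# from typing import Sequence        # deprecated alias of the above

def total(values: Sequence[float]) -> float:
    return sum(values)

print(total([1.0, 2.0, 3.0]))  # -> 6.0, works for any sequence (list, tuple, ...)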

View file

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Optional
import gymnasium as gym
import numpy as np
@@ -24,7 +24,7 @@ class DummyEnv(gym.Env):
def step(self, action):
return self.observation_space.sample(), 0.0, False, False, {}
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
if seed is not None:
super().reset(seed=seed)
return self.observation_space.sample(), {}

View file

@@ -1,5 +1,5 @@
import os
from typing import Dict, Union
from typing import Union
import pytest
@@ -24,7 +24,7 @@ class HParamCallback(BaseCallback):
"""
def _on_training_start(self) -> None:
hparam_dict: Dict[str, Union[str, float]] = {
hparam_dict: dict[str, Union[str, float]] = {
"algorithm": self.model.__class__.__name__,
# Ignore type checking for gamma, see https://github.com/DLR-RM/stable-baselines3/pull/1194/files#r1035006458
"gamma": self.model.gamma, # type: ignore[attr-defined]
@@ -33,7 +33,7 @@ class HParamCallback(BaseCallback):
hparam_dict["learning rate"] = self.model.learning_rate
# define the metrics that will appear in the `HPARAMS` TensorBoard tab by referencing their tag
# TensorBoard will find & display metrics from the `SCALARS` tab
metric_dict: Dict[str, float] = {
metric_dict: dict[str, float] = {
"rollout/ep_len_mean": 0,
}
self.logger.record(
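
The hunk is truncated at the record call. For context, a callback like this typically records an HParam object and excludes the writers that cannot render it, as in the SB3 TensorBoard docs; a sketch under that assumption:

from typing import Union

from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam

class HParamSketchCallback(BaseCallback):
    """Sketch: send hyperparameters and tracked metrics to TensorBoard's HPARAMS tab."""

    def _on_training_start(self) -> None:
        hparam_dict: dict[str, Union[str, float]] = {
            "algorithm": self.model.__class__.__name__,
            "gamma": self.model.gamma,  # type: ignore[attr-defined]
        }
        metric_dict: dict[str, float] = {"rollout/ep_len_mean": 0}
        # exclude the writers that cannot handle HParam objects
        self.logger.record(
            "hparams",
            HParam(hparam_dict, metric_dict),
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        return True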

View file

@@ -4,7 +4,7 @@ import itertools
import multiprocessing
import os
import warnings
from typing import Dict, Optional
from typing import Optional
import gymnasium as gym
import numpy as np
@@ -30,9 +30,9 @@ class CustomGymEnv(gym.Env):
self.current_step = 0
self.ep_length = 4
self.render_mode = render_mode
self.current_options: Optional[Dict] = None
self.current_options: Optional[dict] = None
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
if seed is not None:
self.seed(seed)
self.current_step = 0
@@ -193,7 +193,7 @@ class StepEnv(gym.Env):
self.max_steps = max_steps
self.current_step = 0
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
self.current_step = 0
return np.array([self.current_step], dtype="int"), {}

View file

@@ -1,5 +1,5 @@
import operator
from typing import Any, Dict, Optional
from typing import Any, Optional
import gymnasium as gym
import numpy as np
@@ -22,7 +22,7 @@ ENV_ID = "Pendulum-v1"
class DummyRewardEnv(gym.Env):
metadata: Dict[str, Any] = {}
metadata: dict[str, Any] = {}
def __init__(self, return_reward_idx=0):
self.action_space = spaces.Discrete(2)
@@ -39,7 +39,7 @@ class DummyRewardEnv(gym.Env):
truncated = self.t == len(self.returned_rewards)
return np.array([returned_value]), returned_value, terminated, truncated, {}
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
if seed is not None:
super().reset(seed=seed)
self.t = 0
@@ -62,7 +62,7 @@ class DummyDictEnv(gym.Env):
)
self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
if seed is not None:
super().reset(seed=seed)
return self.observation_space.sample(), {}
@@ -94,7 +94,7 @@ class DummyMixedDictEnv(gym.Env):
)
self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
if seed is not None:
super().reset(seed=seed)
return self.observation_space.sample(), {}
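
These fixtures all exercise VecNormalize, the wrapper under test. A minimal usage sketch with the standard SB3 API (Pendulum chosen arbitrarily):

import gymnasium as gym
import numpy as np

from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

venv = DummyVecEnv([lambda: gym.make("Pendulum-v1")])
venv = VecNormalize(venv, norm_obs=True, norm_reward=True, clip_reward=10.0)

obs = venv.reset()
actions = np.stack([venv.action_space.sample()])  # one action per (single) env
obs, rewards, dones, infos = venv.step(actions)
print(obs.shape, rewards.shape)  # normalized observations and rewards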