Mirror of https://github.com/saymrwulf/stable-baselines3.git (synced 2026-05-14 20:58:03 +00:00)
Drop python 3.8 and add python 3.12 support (#2041)
* Drop python 3.8 support, add python 3.12 support
* Upgrade to python 3.9 syntax
* Fixes for Numpy v2
* Fix doc warning
This commit is contained in:
parent 020ee42f4d
commit daaebd0a52
66 changed files with 530 additions and 483 deletions
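Most of the Python churn below is the "Upgrade to python 3.9 syntax" item: once 3.8 is dropped, PEP 585 lets annotations use the builtin containers directly (`dict[...]`, `list[...]`, `tuple[...]`, `type[...]`), and the abstract collection types come from `collections.abc` instead of their deprecated `typing` aliases. A minimal before/after sketch with illustrative names (not code from this diff):

```python
from collections.abc import Iterable  # pre-3.9 this had to be typing.Iterable
from typing import Optional, Union


# Python 3.8 required typing.Dict / typing.List / typing.Tuple here;
# from 3.9 on, the builtins themselves are subscriptable (PEP 585).
def count_keys(infos: list[dict[str, float]]) -> dict[str, int]:
    """Count how often each key appears across a list of info dicts."""
    counts: dict[str, int] = {}
    for info in infos:
        for key in info:
            counts[key] = counts.get(key, 0) + 1
    return counts


def first_as_float(values: Iterable[Union[int, float]]) -> Optional[float]:
    """Return the first value as a float, or None for an empty iterable."""
    for value in values:
        return float(value)
    return None
```

`Optional` and `Union` stay in `typing` because they have no builtin equivalent until the `X | Y` syntax of Python 3.10, which a codebase supporting 3.9 cannot use yet.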
7 .github/workflows/ci.yml vendored
@@ -20,7 +20,7 @@ jobs:
 runs-on: ubuntu-latest
 strategy:
 matrix:
-python-version: ["3.8", "3.9", "3.10", "3.11"]
+python-version: ["3.9", "3.10", "3.11", "3.12"]
 include:
 # Default version
 - gymnasium-version: "1.0.0"
@@ -48,7 +48,8 @@ jobs:
 - name: Install specific version of gym
 run: |
 uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
-# Only run for python 3.10, downgrade gym to 0.29.1
+uv pip install --system "numpy<2"
+# Only run for python 3.10, downgrade gym to 0.29.1, numpy<2
 if: matrix.gymnasium-version != '1.0.0'
 - name: Lint with ruff
 run: |
@@ -62,8 +63,6 @@ jobs:
 - name: Type check
 run: |
 make type
-# Do not run for python 3.8 (mypy internal error)
-if: matrix.python-version != '3.8'
 - name: Test with pytest
 run: |
 make pytest
@@ -100,10 +100,10 @@ It provides a minimal number of features compared to SB3 but can be much faster

 ## Installation

-**Note:** Stable-Baselines3 supports PyTorch >= 1.13
+**Note:** Stable-Baselines3 supports PyTorch >= 2.3

 ### Prerequisites

-Stable Baselines3 requires Python 3.8+.
+Stable Baselines3 requires Python 3.9+.

 #### Windows
@@ -12,7 +12,7 @@ dependencies:
 - cloudpickle
 - opencv-python-headless
 - pandas
-- numpy>=1.20,<2.0
+- numpy>=1.20,<3.0
 - matplotlib
 - sphinx>=5,<9
 - sphinx_rtd_theme>=1.3.0
@@ -14,7 +14,6 @@
 import datetime
 import os
 import sys
-from typing import Dict

 # We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
 # PyEnchant.

@@ -151,7 +150,7 @@ htmlhelp_basename = "StableBaselines3doc"

 # -- Options for LaTeX output ------------------------------------------------

-latex_elements: Dict[str, str] = {
+latex_elements: dict[str, str] = {
 # The paper size ('letterpaper' or 'a4paper').
 #
 # 'papersize': 'letterpaper',
@@ -7,7 +7,7 @@ Installation
 Prerequisites
 -------------

-Stable-Baselines3 requires python 3.8+ and PyTorch >= 1.13
+Stable-Baselines3 requires python 3.9+ and PyTorch >= 2.3

 Windows
 ~~~~~~~
@@ -20,6 +20,8 @@ RL Baselines3 Zoo provides a collection of pre-trained agents, scripts for train

 SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib

+SBX (SB3 + Jax): https://github.com/araffin/sbx
+
 Main Features
 --------------
@@ -3,6 +3,41 @@
 Changelog
 ==========

+Release 2.5.0a0 (WIP)
+--------------------------
+
+Breaking Changes:
+^^^^^^^^^^^^^^^^^
+- Increased minimum required version of PyTorch to 2.3.0
+- Removed support for Python 3.8
+
+New Features:
+^^^^^^^^^^^^^
+- Added support for NumPy v2.0: ``VecNormalize`` now casts normalized rewards to float32, and the bit flipping env was updated to avoid overflow issues
+- Added official support for Python 3.12
+
+Bug Fixes:
+^^^^^^^^^^
+
+`SB3-Contrib`_
+^^^^^^^^^^^^^^
+
+`RL Zoo`_
+^^^^^^^^^
+
+`SBX`_ (SB3 + Jax)
+^^^^^^^^^^^^^^^^^^
+
+Deprecations:
+^^^^^^^^^^^^^
+
+Others:
+^^^^^^^
+
+Documentation:
+^^^^^^^^^^^^^^
+
+
 Release 2.4.0 (2024-11-18)
 --------------------------
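The NumPy v2.0 bullet is the main behavioral change in this release. Under NEP 50 (adopted in NumPy 2), mixing a float32 array with float64 scalars promotes to float64 instead of being value-cast back, so normalized rewards are now pinned to float32 explicitly. A small sketch of the idea, not the actual ``VecNormalize`` code:

```python
import numpy as np


def normalize_reward(reward: np.ndarray, var: float, epsilon: float = 1e-8, clip: float = 10.0) -> np.ndarray:
    """Normalize rewards by the running std and pin the dtype to float32."""
    # np.sqrt(var + epsilon) is a float64 scalar; under NumPy 2 the division
    # would promote the float32 rewards to float64 without the explicit cast.
    normalized = np.clip(reward / np.sqrt(var + epsilon), -clip, clip)
    return normalized.astype(np.float32)


rewards = np.array([1.0, -2.0, 0.5], dtype=np.float32)
print(normalize_reward(rewards, var=4.0).dtype)  # float32
```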
@@ -1,8 +1,8 @@
 [tool.ruff]
 # Same as Black.
 line-length = 127
-# Assume Python 3.8
-target-version = "py38"
+# Assume Python 3.9
+target-version = "py39"

 [tool.ruff.lint]
 # See https://beta.ruff.rs/docs/rules/
8 setup.py
@@ -77,8 +77,8 @@ setup(
 package_data={"stable_baselines3": ["py.typed", "version.txt"]},
 install_requires=[
 "gymnasium>=0.29.1,<1.1.0",
-"numpy>=1.20,<2.0",  # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
-"torch>=1.13",
+"numpy>=1.20,<3.0",
+"torch>=2.3,<3.0",
 # For saving models
 "cloudpickle",
 # For reading logs

@@ -135,7 +135,7 @@ setup(
 long_description=long_description,
 long_description_content_type="text/markdown",
 version=__version__,
-python_requires=">=3.8",
+python_requires=">=3.9",
 # PyPI package information.
 project_urls={
 "Code": "https://github.com/DLR-RM/stable-baselines3",

@@ -147,10 +147,10 @@ setup(
 },
 classifiers=[
 "Programming Language :: Python :: 3",
-"Programming Language :: Python :: 3.8",
 "Programming Language :: Python :: 3.9",
 "Programming Language :: Python :: 3.10",
 "Programming Language :: Python :: 3.11",
+"Programming Language :: Python :: 3.12",
 ],
 )
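Before upgrading against these pins, a throwaway check of the installed versions can save a broken environment; a standard-library sketch (the bounds are copied from the lines above, the script itself is not part of the package):

```python
import sys
from importlib.metadata import version

print("python:", sys.version.split()[0])  # needs >= 3.9 now
print("numpy:", version("numpy"))         # needs >= 1.20, < 3.0
print("torch:", version("torch"))         # needs >= 2.3, < 3.0
```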
@@ -1,4 +1,4 @@
-from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
+from typing import Any, ClassVar, Optional, TypeVar, Union

 import torch as th
 from gymnasium import spaces

@@ -57,7 +57,7 @@ class A2C(OnPolicyAlgorithm):
 :param _init_setup_model: Whether or not to build the network at the creation of the instance
 """

-policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
+policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
 "MlpPolicy": ActorCriticPolicy,
 "CnnPolicy": ActorCriticCnnPolicy,
 "MultiInputPolicy": MultiInputActorCriticPolicy,

@@ -65,7 +65,7 @@ class A2C(OnPolicyAlgorithm):

 def __init__(
 self,
-policy: Union[str, Type[ActorCriticPolicy]],
+policy: Union[str, type[ActorCriticPolicy]],
 env: Union[GymEnv, str],
 learning_rate: Union[float, Schedule] = 7e-4,
 n_steps: int = 5,

@@ -78,12 +78,12 @@ class A2C(OnPolicyAlgorithm):
 use_rms_prop: bool = True,
 use_sde: bool = False,
 sde_sample_freq: int = -1,
-rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
-rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
+rollout_buffer_class: Optional[type[RolloutBuffer]] = None,
+rollout_buffer_kwargs: Optional[dict[str, Any]] = None,
 normalize_advantage: bool = False,
 stats_window_size: int = 100,
 tensorboard_log: Optional[str] = None,
-policy_kwargs: Optional[Dict[str, Any]] = None,
+policy_kwargs: Optional[dict[str, Any]] = None,
 verbose: int = 0,
 seed: Optional[int] = None,
 device: Union[th.device, str] = "auto",
@@ -1,4 +1,4 @@
-from typing import Dict, SupportsFloat
+from typing import SupportsFloat

 import gymnasium as gym
 import numpy as np

@@ -64,7 +64,7 @@ class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
 noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
 assert noops > 0
 obs = np.zeros(0)
-info: Dict = {}
+info: dict = {}
 for _ in range(noops):
 obs, _, terminated, truncated, info = self.env.step(self.noop_action)
 if terminated or truncated:
@@ -6,7 +6,8 @@ import time
 import warnings
 from abc import ABC, abstractmethod
 from collections import deque
-from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
+from collections.abc import Iterable
+from typing import Any, ClassVar, Optional, TypeVar, Union

 import gymnasium as gym
 import numpy as np

@@ -94,7 +95,7 @@ class BaseAlgorithm(ABC):
 """

 # Policy aliases (see _get_policy_from_name())
-policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {}
+policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {}
 policy: BasePolicy
 observation_space: spaces.Space
 action_space: spaces.Space

@@ -104,10 +105,10 @@ class BaseAlgorithm(ABC):

 def __init__(
 self,
-policy: Union[str, Type[BasePolicy]],
+policy: Union[str, type[BasePolicy]],
 env: Union[GymEnv, str, None],
 learning_rate: Union[float, Schedule],
-policy_kwargs: Optional[Dict[str, Any]] = None,
+policy_kwargs: Optional[dict[str, Any]] = None,
 stats_window_size: int = 100,
 tensorboard_log: Optional[str] = None,
 verbose: int = 0,

@@ -117,7 +118,7 @@ class BaseAlgorithm(ABC):
 seed: Optional[int] = None,
 use_sde: bool = False,
 sde_sample_freq: int = -1,
-supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
+supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
 ) -> None:
 if isinstance(policy, str):
 self.policy_class = self._get_policy_from_name(policy)

@@ -141,10 +142,10 @@ class BaseAlgorithm(ABC):
 self.start_time = 0.0
 self.learning_rate = learning_rate
 self.tensorboard_log = tensorboard_log
-self._last_obs = None  # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
+self._last_obs = None  # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]]
 self._last_episode_starts = None  # type: Optional[np.ndarray]
 # When using VecNormalize:
-self._last_original_obs = None  # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
+self._last_original_obs = None  # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]]
 self._episode_num = 0
 # Used for gSDE only
 self.use_sde = use_sde

@@ -283,7 +284,7 @@ class BaseAlgorithm(ABC):
 """
 self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)

-def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
+def _update_learning_rate(self, optimizers: Union[list[th.optim.Optimizer], th.optim.Optimizer]) -> None:
 """
 Update the optimizers learning rate using the current learning rate schedule
 and the current progress remaining (from 1 to 0).

@@ -299,7 +300,7 @@ class BaseAlgorithm(ABC):
 for optimizer in optimizers:
 update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining))

-def _excluded_save_params(self) -> List[str]:
+def _excluded_save_params(self) -> list[str]:
 """
 Returns the names of the parameters that should be excluded from being
 saved by pickling. E.g. replay buffers are skipped by default

@@ -320,7 +321,7 @@ class BaseAlgorithm(ABC):
 "_custom_logger",
 ]

-def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
+def _get_policy_from_name(self, policy_name: str) -> type[BasePolicy]:
 """
 Get a policy class from its name representation.

@@ -337,7 +338,7 @@ class BaseAlgorithm(ABC):
 else:
 raise ValueError(f"Policy {policy_name} unknown")

-def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
+def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
 """
 Get the name of the torch variables that will be saved with
 PyTorch ``th.save``, ``th.load`` and ``state_dicts`` instead of the default

@@ -387,7 +388,7 @@ class BaseAlgorithm(ABC):
 reset_num_timesteps: bool = True,
 tb_log_name: str = "run",
 progress_bar: bool = False,
-) -> Tuple[int, BaseCallback]:
+) -> tuple[int, BaseCallback]:
 """
 Initialize different variables needed for training.

@@ -435,7 +436,7 @@ class BaseAlgorithm(ABC):

 return total_timesteps, callback

-def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
+def _update_info_buffer(self, infos: list[dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
 """
 Retrieve reward, episode length, episode success and update the buffer
 if using Monitor wrapper or a GoalEnv.

@@ -535,11 +536,11 @@ class BaseAlgorithm(ABC):

 def predict(
 self,
-observation: Union[np.ndarray, Dict[str, np.ndarray]],
-state: Optional[Tuple[np.ndarray, ...]] = None,
+observation: Union[np.ndarray, dict[str, np.ndarray]],
+state: Optional[tuple[np.ndarray, ...]] = None,
 episode_start: Optional[np.ndarray] = None,
 deterministic: bool = False,
-) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
 """
 Get the policy action from an observation (and optional hidden state).
 Includes sugar-coating to handle different observations (e.g. normalizing images).

@@ -640,11 +641,11 @@ class BaseAlgorithm(ABC):

 @classmethod
 def load(  # noqa: C901
-cls: Type[SelfBaseAlgorithm],
+cls: type[SelfBaseAlgorithm],
 path: Union[str, pathlib.Path, io.BufferedIOBase],
 env: Optional[GymEnv] = None,
 device: Union[th.device, str] = "auto",
-custom_objects: Optional[Dict[str, Any]] = None,
+custom_objects: Optional[dict[str, Any]] = None,
 print_system_info: bool = False,
 force_reset: bool = True,
 **kwargs,

@@ -800,7 +801,7 @@ class BaseAlgorithm(ABC):
 model.policy.reset_noise()  # type: ignore[operator]
 return model

-def get_parameters(self) -> Dict[str, Dict]:
+def get_parameters(self) -> dict[str, dict]:
 """
 Return the parameters of the agent. This includes parameters from different networks, e.g.
 critics (value functions) and policies (pi functions).
@@ -1,6 +1,7 @@
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+from collections.abc import Generator
+from typing import Any, Optional, Union

 import numpy as np
 import torch as th

@@ -36,7 +37,7 @@ class BaseBuffer(ABC):
 """

 observation_space: spaces.Space
-obs_shape: Tuple[int, ...]
+obs_shape: tuple[int, ...]

 def __init__(
 self,

@@ -140,9 +141,9 @@ class BaseBuffer(ABC):

 @staticmethod
 def _normalize_obs(
-obs: Union[np.ndarray, Dict[str, np.ndarray]],
+obs: Union[np.ndarray, dict[str, np.ndarray]],
 env: Optional[VecNormalize] = None,
-) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+) -> Union[np.ndarray, dict[str, np.ndarray]]:
 if env is not None:
 return env.normalize_obs(obs)
 return obs

@@ -250,7 +251,7 @@ class ReplayBuffer(BaseBuffer):
 action: np.ndarray,
 reward: np.ndarray,
 done: np.ndarray,
-infos: List[Dict[str, Any]],
+infos: list[dict[str, Any]],
 ) -> None:
 # Reshape needed when using multiple envs with discrete observations
 # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)

@@ -538,9 +539,9 @@ class DictReplayBuffer(ReplayBuffer):
 """

 observation_space: spaces.Dict
-obs_shape: Dict[str, Tuple[int, ...]]  # type: ignore[assignment]
-observations: Dict[str, np.ndarray]  # type: ignore[assignment]
-next_observations: Dict[str, np.ndarray]  # type: ignore[assignment]
+obs_shape: dict[str, tuple[int, ...]]  # type: ignore[assignment]
+observations: dict[str, np.ndarray]  # type: ignore[assignment]
+next_observations: dict[str, np.ndarray]  # type: ignore[assignment]

 def __init__(
 self,

@@ -609,12 +610,12 @@ class DictReplayBuffer(ReplayBuffer):

 def add(  # type: ignore[override]
 self,
-obs: Dict[str, np.ndarray],
-next_obs: Dict[str, np.ndarray],
+obs: dict[str, np.ndarray],
+next_obs: dict[str, np.ndarray],
 action: np.ndarray,
 reward: np.ndarray,
 done: np.ndarray,
-infos: List[Dict[str, Any]],
+infos: list[dict[str, Any]],
 ) -> None:
 # Copy to avoid modification by reference
 for key in self.observations.keys():

@@ -718,8 +719,8 @@ class DictRolloutBuffer(RolloutBuffer):
 """

 observation_space: spaces.Dict
-obs_shape: Dict[str, Tuple[int, ...]]  # type: ignore[assignment]
-observations: Dict[str, np.ndarray]  # type: ignore[assignment]
+obs_shape: dict[str, tuple[int, ...]]  # type: ignore[assignment]
+observations: dict[str, np.ndarray]  # type: ignore[assignment]

 def __init__(
 self,

@@ -757,7 +758,7 @@ class DictRolloutBuffer(RolloutBuffer):

 def add(  # type: ignore[override]
 self,
-obs: Dict[str, np.ndarray],
+obs: dict[str, np.ndarray],
 action: np.ndarray,
 reward: np.ndarray,
 episode_start: np.ndarray,
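The import reshuffle above follows the rule applied throughout the commit: generic ABCs whose `typing` aliases are deprecated since 3.9 (`Generator`, `Iterable`, `Mapping`, `Sequence`) move to `collections.abc`, while `typing` keeps only the names with no runtime home (`Any`, `Optional`, `Union`, `ClassVar`, `TypeVar`). A small sketch of a `Generator` annotation in the new style (illustrative, not the buffer code):

```python
from collections.abc import Generator

import numpy as np


def minibatches(data: np.ndarray, batch_size: int) -> Generator[np.ndarray, None, None]:
    """Yield shuffled minibatches, the typical use of Generator in buffer code."""
    indices = np.random.permutation(len(data))
    for start in range(0, len(data), batch_size):
        yield data[indices[start : start + batch_size]]


for batch in minibatches(np.arange(10.0), batch_size=4):
    print(batch.shape)  # (4,), (4,), (2,)
```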
@@ -1,7 +1,7 @@
 import os
 import warnings
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 import gymnasium as gym
 import numpy as np

@@ -45,8 +45,8 @@ class BaseCallback(ABC):
 # n_envs * n times env.step() was called
 self.num_timesteps = 0  # type: int
 self.verbose = verbose
-self.locals: Dict[str, Any] = {}
-self.globals: Dict[str, Any] = {}
+self.locals: dict[str, Any] = {}
+self.globals: dict[str, Any] = {}
 # Sometimes, for event callback, it is useful
 # to have access to the parent object
 self.parent = None  # type: Optional[BaseCallback]

@@ -75,7 +75,7 @@ class BaseCallback(ABC):
 def _init_callback(self) -> None:
 pass

-def on_training_start(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
+def on_training_start(self, locals_: dict[str, Any], globals_: dict[str, Any]) -> None:
 # Those are reference and will be updated automatically
 self.locals = locals_
 self.globals = globals_

@@ -125,7 +125,7 @@ class BaseCallback(ABC):
 def _on_rollout_end(self) -> None:
 pass

-def update_locals(self, locals_: Dict[str, Any]) -> None:
+def update_locals(self, locals_: dict[str, Any]) -> None:
 """
 Update the references to the local variables.

@@ -134,7 +134,7 @@ class BaseCallback(ABC):
 self.locals.update(locals_)
 self.update_child_locals(locals_)

-def update_child_locals(self, locals_: Dict[str, Any]) -> None:
+def update_child_locals(self, locals_: dict[str, Any]) -> None:
 """
 Update the references to the local variables on sub callbacks.

@@ -177,7 +177,7 @@ class EventCallback(BaseCallback):
 def _on_step(self) -> bool:
 return True

-def update_child_locals(self, locals_: Dict[str, Any]) -> None:
+def update_child_locals(self, locals_: dict[str, Any]) -> None:
 """
 Update the references to the local variables.

@@ -195,7 +195,7 @@ class CallbackList(BaseCallback):
 sequentially.
 """

-def __init__(self, callbacks: List[BaseCallback]):
+def __init__(self, callbacks: list[BaseCallback]):
 super().__init__()
 assert isinstance(callbacks, list)
 self.callbacks = callbacks

@@ -231,7 +231,7 @@ class CallbackList(BaseCallback):
 for callback in self.callbacks:
 callback.on_training_end()

-def update_child_locals(self, locals_: Dict[str, Any]) -> None:
+def update_child_locals(self, locals_: dict[str, Any]) -> None:
 """
 Update the references to the local variables.

@@ -328,7 +328,7 @@ class ConvertCallback(BaseCallback):
 :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
 """

-def __init__(self, callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], bool]], verbose: int = 0):
+def __init__(self, callback: Optional[Callable[[dict[str, Any], dict[str, Any]], bool]], verbose: int = 0):
 super().__init__(verbose)
 self.callback = callback

@@ -405,12 +405,12 @@ class EvalCallback(EventCallback):
 if log_path is not None:
 log_path = os.path.join(log_path, "evaluations")
 self.log_path = log_path
-self.evaluations_results: List[List[float]] = []
-self.evaluations_timesteps: List[int] = []
-self.evaluations_length: List[List[int]] = []
+self.evaluations_results: list[list[float]] = []
+self.evaluations_timesteps: list[int] = []
+self.evaluations_length: list[list[int]] = []
 # For computing success rate
-self._is_success_buffer: List[bool] = []
-self.evaluations_successes: List[List[bool]] = []
+self._is_success_buffer: list[bool] = []
+self.evaluations_successes: list[list[bool]] = []

 def _init_callback(self) -> None:
 # Does not work in some corner cases, where the wrapper is not the same

@@ -427,7 +427,7 @@ class EvalCallback(EventCallback):
 if self.callback_on_new_best is not None:
 self.callback_on_new_best.init_callback(self.model)

-def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
+def _log_success_callback(self, locals_: dict[str, Any], globals_: dict[str, Any]) -> None:
 """
 Callback passed to the ``evaluate_policy`` function
 in order to log the success rate (when applicable),

@@ -530,7 +530,7 @@ class EvalCallback(EventCallback):

 return continue_training

-def update_child_locals(self, locals_: Dict[str, Any]) -> None:
+def update_child_locals(self, locals_: dict[str, Any]) -> None:
 """
 Update the references to the local variables.
@@ -1,7 +1,7 @@
 """Probability distributions."""

 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Optional, TypeVar, Union

 import numpy as np
 import torch as th

@@ -30,7 +30,7 @@ class Distribution(ABC):
 self.distribution = None

 @abstractmethod
-def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, Tuple[nn.Module, nn.Parameter]]:
+def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, tuple[nn.Module, nn.Parameter]]:
 """Create the layers and parameters that represent the distribution.

 Subclasses must define this, but the arguments and return type vary between

@@ -98,7 +98,7 @@ class Distribution(ABC):
 """

 @abstractmethod
-def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]:
+def log_prob_from_params(self, *args, **kwargs) -> tuple[th.Tensor, th.Tensor]:
 """
 Returns samples and the associated log probabilities
 from the probability distribution given its parameters.

@@ -135,7 +135,7 @@ class DiagGaussianDistribution(Distribution):
 self.mean_actions = None
 self.log_std = None

-def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]:
+def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> tuple[nn.Module, nn.Parameter]:
 """
 Create the layers and parameter that represent the distribution:
 one output will be the mean of the Gaussian, the other parameter will be the

@@ -190,7 +190,7 @@ class DiagGaussianDistribution(Distribution):
 self.proba_distribution(mean_actions, log_std)
 return self.get_actions(deterministic=deterministic)

-def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
 """
 Compute the log probability of taking an action
 given the distribution parameters.

@@ -254,7 +254,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
 # Squash the output
 return th.tanh(self.gaussian_actions)

-def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
 action = self.actions_from_params(mean_actions, log_std)
 log_prob = self.log_prob(action, self.gaussian_actions)
 return action, log_prob

@@ -305,7 +305,7 @@ class CategoricalDistribution(Distribution):
 self.proba_distribution(action_logits)
 return self.get_actions(deterministic=deterministic)

-def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
 actions = self.actions_from_params(action_logits)
 log_prob = self.log_prob(actions)
 return actions, log_prob

@@ -318,7 +318,7 @@ class MultiCategoricalDistribution(Distribution):
 :param action_dims: List of sizes of discrete action spaces
 """

-def __init__(self, action_dims: List[int]):
+def __init__(self, action_dims: list[int]):
 super().__init__()
 self.action_dims = action_dims

@@ -362,7 +362,7 @@ class MultiCategoricalDistribution(Distribution):
 self.proba_distribution(action_logits)
 return self.get_actions(deterministic=deterministic)

-def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
 actions = self.actions_from_params(action_logits)
 log_prob = self.log_prob(actions)
 return actions, log_prob

@@ -412,7 +412,7 @@ class BernoulliDistribution(Distribution):
 self.proba_distribution(action_logits)
 return self.get_actions(deterministic=deterministic)

-def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
 actions = self.actions_from_params(action_logits)
 log_prob = self.log_prob(actions)
 return actions, log_prob

@@ -513,7 +513,7 @@ class StateDependentNoiseDistribution(Distribution):

 def proba_distribution_net(
 self, latent_dim: int, log_std_init: float = -2.0, latent_sde_dim: Optional[int] = None
-) -> Tuple[nn.Module, nn.Parameter]:
+) -> tuple[nn.Module, nn.Parameter]:
 """
 Create the layers and parameter that represent the distribution:
 one output will be the deterministic action, the other parameter will be the

@@ -611,7 +611,7 @@ class StateDependentNoiseDistribution(Distribution):

 def log_prob_from_params(
 self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
-) -> Tuple[th.Tensor, th.Tensor]:
+) -> tuple[th.Tensor, th.Tensor]:
 actions = self.actions_from_params(mean_actions, log_std, latent_sde)
 log_prob = self.log_prob(actions)
 return actions, log_prob

@@ -661,7 +661,7 @@ class TanhBijector:

 def make_proba_distribution(
-action_space: spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
+action_space: spaces.Space, use_sde: bool = False, dist_kwargs: Optional[dict[str, Any]] = None
 ) -> Distribution:
 """
 Return an instance of Distribution for the correct type of action space
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Dict, Union
+from typing import Any, Union

 import gymnasium as gym
 import numpy as np

@@ -172,10 +172,10 @@ def _check_goal_env_obs(obs: dict, observation_space: spaces.Dict, method_name:

 def _check_goal_env_compute_reward(
-obs: Dict[str, Union[np.ndarray, int]],
+obs: dict[str, Union[np.ndarray, int]],
 env: gym.Env,
 reward: float,
-info: Dict[str, Any],
+info: dict[str, Any],
 ) -> None:
 """
 Check that reward is computed with `compute_reward`
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Callable, Dict, Optional, Type, Union
+from typing import Any, Callable, Optional, Union

 import gymnasium as gym

@@ -9,7 +9,7 @@ from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
 from stable_baselines3.common.vec_env.patch_gym import _patch_env

-def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]:
+def unwrap_wrapper(env: gym.Env, wrapper_class: type[gym.Wrapper]) -> Optional[gym.Wrapper]:
 """
 Retrieve a ``VecEnvWrapper`` object by recursively searching.

@@ -25,7 +25,7 @@ def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[g
 return None

-def is_wrapped(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> bool:
+def is_wrapped(env: gym.Env, wrapper_class: type[gym.Wrapper]) -> bool:
 """
 Check if a given environment has been wrapped with a given wrapper.

@@ -43,11 +43,11 @@ def make_vec_env(
 start_index: int = 0,
 monitor_dir: Optional[str] = None,
 wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
-env_kwargs: Optional[Dict[str, Any]] = None,
-vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
-vec_env_kwargs: Optional[Dict[str, Any]] = None,
-monitor_kwargs: Optional[Dict[str, Any]] = None,
-wrapper_kwargs: Optional[Dict[str, Any]] = None,
+env_kwargs: Optional[dict[str, Any]] = None,
+vec_env_cls: Optional[type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
+vec_env_kwargs: Optional[dict[str, Any]] = None,
+monitor_kwargs: Optional[dict[str, Any]] = None,
+wrapper_kwargs: Optional[dict[str, Any]] = None,
 ) -> VecEnv:
 """
 Create a wrapped, monitored ``VecEnv``.

@@ -134,11 +134,11 @@ def make_atari_env(
 seed: Optional[int] = None,
 start_index: int = 0,
 monitor_dir: Optional[str] = None,
-wrapper_kwargs: Optional[Dict[str, Any]] = None,
-env_kwargs: Optional[Dict[str, Any]] = None,
-vec_env_cls: Optional[Union[Type[DummyVecEnv], Type[SubprocVecEnv]]] = None,
-vec_env_kwargs: Optional[Dict[str, Any]] = None,
-monitor_kwargs: Optional[Dict[str, Any]] = None,
+wrapper_kwargs: Optional[dict[str, Any]] = None,
+env_kwargs: Optional[dict[str, Any]] = None,
+vec_env_cls: Optional[Union[type[DummyVecEnv], type[SubprocVecEnv]]] = None,
+vec_env_kwargs: Optional[dict[str, Any]] = None,
+monitor_kwargs: Optional[dict[str, Any]] = None,
 ) -> VecEnv:
 """
 Create a wrapped, monitored VecEnv for Atari.
@@ -1,5 +1,5 @@
 from collections import OrderedDict
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Optional, Union

 import numpy as np
 from gymnasium import Env, spaces

@@ -75,14 +75,17 @@
 :param state:
 :return:
 """
 if self.discrete_obs_space:
+# Convert from int8 to int32 for NumPy 2.0
+state = state.astype(np.int32)
 # The internal state is the binary representation of the
 # observed one
 return int(sum(state[i] * 2**i for i in range(len(state))))

 if self.image_obs_space:
 size = np.prod(self.image_shape)
-image = np.concatenate((state * 255, np.zeros(size - len(state), dtype=np.uint8)))
+image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8)))
 return image.reshape(self.image_shape).astype(np.uint8)
 return state

@@ -163,7 +166,7 @@
 }
 )

-def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]:
+def _get_obs(self) -> dict[str, Union[int, np.ndarray]]:
 """
 Helper to create the observation.

@@ -178,8 +181,8 @@
 )

 def reset(
-self, *, seed: Optional[int] = None, options: Optional[Dict] = None
-) -> Tuple[Dict[str, Union[int, np.ndarray]], Dict]:
+self, *, seed: Optional[int] = None, options: Optional[dict] = None
+) -> tuple[dict[str, Union[int, np.ndarray]], dict]:
 if seed is not None:
 self._obs_space.seed(seed)
 self.current_step = 0

@@ -207,7 +210,7 @@
 return obs, reward, terminated, truncated, info

 def compute_reward(
-self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[Dict[str, Any]]
+self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[dict[str, Any]]
 ) -> np.float32:
 # As we are using a vectorized version, we need to keep track of the `batch_size`
 if isinstance(achieved_goal, int):
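The two `astype` calls in the hunk above guard against small-integer overflow, which NumPy 2 no longer papers over: under NEP 50, a Python int that does not fit the array's dtype raises instead of silently upcasting the result. A standalone reproduction of the failure mode, with assumed dtypes rather than the exact env state:

```python
import numpy as np

state = np.array([1, 0, 1, 1, 0, 1, 0, 1], dtype=np.int8)

# Under NumPy 1.x, state[7] * 2**7 quietly upcast; under NumPy 2 the
# Python int 128 is out of range for int8 and the operation can raise.
# Widening first makes the intent explicit and the result correct:
state32 = state.astype(np.int32)
value = int(sum(state32[i] * 2**i for i in range(len(state32))))
print(value)  # 173 == 0b10101101
```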
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union
+from typing import Any, Generic, Optional, TypeVar, Union

 import gymnasium as gym
 import numpy as np

@@ -34,7 +34,7 @@ class IdentityEnv(gym.Env, Generic[T]):
 self.num_resets = -1  # Becomes 0 after __init__ exits.
 self.reset()

-def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[T, Dict]:
+def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[T, dict]:
 if seed is not None:
 super().reset(seed=seed)
 self.current_step = 0

@@ -42,7 +42,7 @@ class IdentityEnv(gym.Env, Generic[T]):
 self._choose_next_state()
 return self.state, {}

-def step(self, action: T) -> Tuple[T, float, bool, bool, Dict[str, Any]]:
+def step(self, action: T) -> tuple[T, float, bool, bool, dict[str, Any]]:
 reward = self._get_reward(action)
 self._choose_next_state()
 self.current_step += 1

@@ -74,7 +74,7 @@ class IdentityEnvBox(IdentityEnv[np.ndarray]):
 super().__init__(ep_length=ep_length, space=space)
 self.eps = eps

-def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
+def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
 reward = self._get_reward(action)
 self._choose_next_state()
 self.current_step += 1

@@ -142,7 +142,7 @@ class FakeImageEnv(gym.Env):
 self.ep_length = 10
 self.current_step = 0

-def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]:
+def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[np.ndarray, dict]:
 if seed is not None:
 super().reset(seed=seed)
 self.current_step = 0
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Optional, Union

 import gymnasium as gym
 import numpy as np

@@ -73,7 +73,7 @@ class SimpleMultiObsEnv(gym.Env):
 self.init_possible_transitions()

 self.num_col = num_col
-self.state_mapping: List[Dict[str, np.ndarray]] = []
+self.state_mapping: list[dict[str, np.ndarray]] = []
 self.init_state_mapping(num_col, num_row)

 self.max_state = len(self.state_mapping) - 1

@@ -94,7 +94,7 @@ class SimpleMultiObsEnv(gym.Env):
 for j in range(num_row):
 self.state_mapping.append({"vec": col_vecs[i], "img": row_imgs[j].reshape(self.img_size)})

-def get_state_mapping(self) -> Dict[str, np.ndarray]:
+def get_state_mapping(self) -> dict[str, np.ndarray]:
 """
 Uses the state to get the observation mapping.

@@ -166,7 +166,7 @@ class SimpleMultiObsEnv(gym.Env):
 """
 print(self.log)

-def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[Dict[str, np.ndarray], Dict]:
+def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[dict[str, np.ndarray], dict]:
 """
 Resets the environment state and step count and returns reset observation.
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union

 import gymnasium as gym
 import numpy as np

@@ -14,11 +14,11 @@ def evaluate_policy(
 n_eval_episodes: int = 10,
 deterministic: bool = True,
 render: bool = False,
-callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
+callback: Optional[Callable[[dict[str, Any], dict[str, Any]], None]] = None,
 reward_threshold: Optional[float] = None,
 return_episode_rewards: bool = False,
 warn: bool = True,
-) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
+) -> Union[tuple[float, float], tuple[list[float], list[int]]]:
 """
 Runs policy for ``n_eval_episodes`` episodes and returns average reward.
 If a vector env is passed in, this divides the episodes to evaluate onto the
@@ -5,8 +5,9 @@ import sys
 import tempfile
 import warnings
 from collections import defaultdict
+from collections.abc import Mapping, Sequence
 from io import TextIOBase
-from typing import Any, Dict, List, Mapping, Optional, Sequence, TextIO, Tuple, Union
+from typing import Any, Optional, TextIO, Union

 import matplotlib.figure
 import numpy as np

@@ -114,7 +115,7 @@ class KVWriter:
 Key Value writer
 """

-def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
+def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
 """
 Write a dictionary to file

@@ -136,7 +137,7 @@ class SeqWriter:
 sequence writer
 """

-def write_sequence(self, sequence: List[str]) -> None:
+def write_sequence(self, sequence: list[str]) -> None:
 """
 write_sequence an array to file

@@ -172,7 +173,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
 else:
 raise ValueError(f"Expected file or str, got {filename_or_file}")

-def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
+def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
 # Create strings for printing
 key2str = {}
 tag = ""

@@ -244,7 +245,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
 string = string[: self.max_length - 3] + "..."
 return string

-def write_sequence(self, sequence: List[str]) -> None:
+def write_sequence(self, sequence: list[str]) -> None:
 for i, elem in enumerate(sequence):
 self.file.write(elem)
 if i < len(sequence) - 1:  # add space unless this is the last one

@@ -260,7 +261,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
 self.file.close()

-def filter_excluded_keys(key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], _format: str) -> Dict[str, Any]:
+def filter_excluded_keys(key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], _format: str) -> dict[str, Any]:
 """
 Filters the keys specified by ``key_exclude`` for the specified format

@@ -286,7 +287,7 @@ class JSONOutputFormat(KVWriter):
 def __init__(self, filename: str):
 self.file = open(filename, "w")

-def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
+def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
 def cast_to_json_serializable(value: Any):
 if isinstance(value, Video):
 raise FormatUnsupportedError(["json"], "video")

@@ -328,12 +329,12 @@ class CSVOutputFormat(KVWriter):
 """

 def __init__(self, filename: str):
-self.file = open(filename, "w+t")
-self.keys: List[str] = []
+self.file = open(filename, "w+")
+self.keys: list[str] = []
 self.separator = ","
 self.quotechar = '"'

-def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
+def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
 # Add our current row to the history
 key_values = filter_excluded_keys(key_values, key_excluded, "csv")
 extra_keys = key_values.keys() - self.keys

@@ -399,7 +400,7 @@ class TensorBoardOutputFormat(KVWriter):
 self.writer = SummaryWriter(log_dir=folder)
 self._is_closed = False

-def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
+def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
 assert not self._is_closed, "The SummaryWriter was closed, please re-create one."
 for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
 if excluded is not None and "tensorboard" in excluded:

@@ -481,16 +482,16 @@ class Logger:
 :param output_formats: the list of output formats
 """

-def __init__(self, folder: Optional[str], output_formats: List[KVWriter]):
-self.name_to_value: Dict[str, float] = defaultdict(float)  # values this iteration
-self.name_to_count: Dict[str, int] = defaultdict(int)
-self.name_to_excluded: Dict[str, Tuple[str, ...]] = {}
+def __init__(self, folder: Optional[str], output_formats: list[KVWriter]):
+self.name_to_value: dict[str, float] = defaultdict(float)  # values this iteration
+self.name_to_count: dict[str, int] = defaultdict(int)
+self.name_to_excluded: dict[str, tuple[str, ...]] = {}
 self.level = INFO
 self.dir = folder
 self.output_formats = output_formats

 @staticmethod
-def to_tuple(string_or_tuple: Optional[Union[str, Tuple[str, ...]]]) -> Tuple[str, ...]:
+def to_tuple(string_or_tuple: Optional[Union[str, tuple[str, ...]]]) -> tuple[str, ...]:
 """
 Helper function to convert str to tuple of str.
 """

@@ -500,7 +501,7 @@ class Logger:
 return string_or_tuple
 return (string_or_tuple,)

-def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
+def record(self, key: str, value: Any, exclude: Optional[Union[str, tuple[str, ...]]] = None) -> None:
 """
 Log a value of some diagnostic
 Call this once for each diagnostic quantity, each iteration

@@ -513,7 +514,7 @@ class Logger:
 self.name_to_value[key] = value
 self.name_to_excluded[key] = self.to_tuple(exclude)

-def record_mean(self, key: str, value: Optional[float], exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
+def record_mean(self, key: str, value: Optional[float], exclude: Optional[Union[str, tuple[str, ...]]] = None) -> None:
 """
 The same as record(), but if called many times, values averaged.

@@ -624,7 +625,7 @@ class Logger:

 # Misc
 # ----------------------------------------
-def _do_log(self, args: Tuple[Any, ...]) -> None:
+def _do_log(self, args: tuple[Any, ...]) -> None:
 """
 log to the requested format outputs

@@ -635,7 +636,7 @@ class Logger:
 _format.write_sequence(list(map(str, args)))

-def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] = None) -> Logger:
+def configure(folder: Optional[str] = None, format_strings: Optional[list[str]] = None) -> Logger:
 """
 Configure the current logger.
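One non-typing change hides in `CSVOutputFormat`: the file mode loses its explicit `t`. That is purely cosmetic; text mode is the default for Python's built-in `open`, so `"w+t"` and `"w+"` behave identically:

```python
# "t" (text) is the default mode flag, so these two handles are equivalent.
with open("example.csv", "w+") as f:  # same as open("example.csv", "w+t")
    f.write("r,l,t\n")
    f.seek(0)
    print(f.read())
```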
@@ -5,7 +5,7 @@ import json
 import os
 import time
 from glob import glob
-from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, Union
+from typing import Any, Optional, SupportsFloat, Union

 import gymnasium as gym
 import pandas

@@ -33,8 +33,8 @@ class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
 env: gym.Env,
 filename: Optional[str] = None,
 allow_early_resets: bool = True,
-reset_keywords: Tuple[str, ...] = (),
-info_keywords: Tuple[str, ...] = (),
+reset_keywords: tuple[str, ...] = (),
+info_keywords: tuple[str, ...] = (),
 override_existing: bool = True,
 ):
 super().__init__(env=env)

@@ -52,16 +52,16 @@ class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
 self.reset_keywords = reset_keywords
 self.info_keywords = info_keywords
 self.allow_early_resets = allow_early_resets
-self.rewards: List[float] = []
+self.rewards: list[float] = []
 self.needs_reset = True
-self.episode_returns: List[float] = []
-self.episode_lengths: List[int] = []
-self.episode_times: List[float] = []
+self.episode_returns: list[float] = []
+self.episode_lengths: list[int] = []
+self.episode_times: list[float] = []
 self.total_steps = 0
 # extra info about the current episode, that was passed in during reset()
-self.current_reset_info: Dict[str, Any] = {}
+self.current_reset_info: dict[str, Any] = {}

-def reset(self, **kwargs) -> Tuple[ObsType, Dict[str, Any]]:
+def reset(self, **kwargs) -> tuple[ObsType, dict[str, Any]]:
 """
 Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True

@@ -82,7 +82,7 @@ class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
 self.current_reset_info[key] = value
 return self.env.reset(**kwargs)

-def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
+def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
 """
 Step the environment with the given action

@@ -126,7 +126,7 @@ class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
 """
 return self.total_steps

-def get_episode_rewards(self) -> List[float]:
+def get_episode_rewards(self) -> list[float]:
 """
 Returns the rewards of all the episodes

@@ -134,7 +134,7 @@ class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
 """
 return self.episode_returns

-def get_episode_lengths(self) -> List[int]:
+def get_episode_lengths(self) -> list[int]:
 """
 Returns the number of timesteps of all the episodes

@@ -142,7 +142,7 @@ class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
 """
 return self.episode_lengths

-def get_episode_times(self) -> List[float]:
+def get_episode_times(self) -> list[float]:
 """
 Returns the runtime in seconds of all the episodes

@@ -175,8 +175,8 @@ class ResultsWriter:
 def __init__(
 self,
 filename: str = "",
-header: Optional[Dict[str, Union[float, str]]] = None,
-extra_keys: Tuple[str, ...] = (),
+header: Optional[dict[str, Union[float, str]]] = None,
+extra_keys: tuple[str, ...] = (),
 override_existing: bool = True,
 ):
 if header is None:

@@ -200,7 +200,7 @@ class ResultsWriter:

 self.file_handler.flush()

-def write_row(self, epinfo: Dict[str, float]) -> None:
+def write_row(self, epinfo: dict[str, float]) -> None:
 """
 Write row of monitor data to csv log file.

@@ -217,7 +217,7 @@ class ResultsWriter:
 self.file_handler.close()

-def get_monitor_files(path: str) -> List[str]:
+def get_monitor_files(path: str) -> list[str]:
 """
 get all the monitor files in the given path
@@ -1,6 +1,7 @@
 import copy
 from abc import ABC, abstractmethod
-from typing import Iterable, List, Optional
+from collections.abc import Iterable
+from typing import Optional

 import numpy as np
 from numpy.typing import DTypeLike

@@ -153,11 +154,11 @@ class VectorizedActionNoise(ActionNoise):
 self._base_noise = base_noise

 @property
-def noises(self) -> List[ActionNoise]:
+def noises(self) -> list[ActionNoise]:
 return self._noises

 @noises.setter
-def noises(self, noises: List[ActionNoise]) -> None:
+def noises(self, noises: list[ActionNoise]) -> None:
 noises = list(noises)  # raises TypeError if not iterable
 assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}."
@@ -4,7 +4,7 @@ import sys
 import time
 import warnings
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Optional, TypeVar, Union

 import numpy as np
 import torch as th

@@ -79,7 +79,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):

 def __init__(
 self,
-policy: Union[str, Type[BasePolicy]],
+policy: Union[str, type[BasePolicy]],
 env: Union[GymEnv, str],
 learning_rate: Union[float, Schedule],
 buffer_size: int = 1_000_000,  # 1e6

@@ -87,13 +87,13 @@ class OffPolicyAlgorithm(BaseAlgorithm):
 batch_size: int = 256,
 tau: float = 0.005,
 gamma: float = 0.99,
-train_freq: Union[int, Tuple[int, str]] = (1, "step"),
+train_freq: Union[int, tuple[int, str]] = (1, "step"),
 gradient_steps: int = 1,
 action_noise: Optional[ActionNoise] = None,
-replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
-replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
+replay_buffer_class: Optional[type[ReplayBuffer]] = None,
+replay_buffer_kwargs: Optional[dict[str, Any]] = None,
 optimize_memory_usage: bool = False,
-policy_kwargs: Optional[Dict[str, Any]] = None,
+policy_kwargs: Optional[dict[str, Any]] = None,
 stats_window_size: int = 100,
 tensorboard_log: Optional[str] = None,
 verbose: int = 0,

@@ -105,7 +105,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
 sde_sample_freq: int = -1,
 use_sde_at_warmup: bool = False,
 sde_support: bool = True,
-supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
+supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
 ):
 super().__init__(
 policy=policy,

@@ -256,7 +256,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
 reset_num_timesteps: bool = True,
 tb_log_name: str = "run",
 progress_bar: bool = False,
-) -> Tuple[int, BaseCallback]:
+) -> tuple[int, BaseCallback]:
 """
 cf `BaseAlgorithm`.
 """

@@ -362,7 +362,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
 learning_starts: int,
 action_noise: Optional[ActionNoise] = None,
 n_envs: int = 1,
-) -> Tuple[np.ndarray, np.ndarray]:
+) -> tuple[np.ndarray, np.ndarray]:
 """
 Sample an action according to the exploration policy.
 This is either done by sampling the probability distribution of the policy,

@@ -442,10 +442,10 @@ class OffPolicyAlgorithm(BaseAlgorithm):
 self,
 replay_buffer: ReplayBuffer,
 buffer_action: np.ndarray,
-new_obs: Union[np.ndarray, Dict[str, np.ndarray]],
+new_obs: Union[np.ndarray, dict[str, np.ndarray]],
 reward: np.ndarray,
 dones: np.ndarray,
-infos: List[Dict[str, Any]],
+infos: list[dict[str, Any]],
 ) -> None:
 """
 Store transition in the replay buffer.
@@ -1,7 +1,7 @@
 import sys
 import time
 import warnings
-from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Optional, TypeVar, Union

 import numpy as np
 import torch as th

@@ -60,7 +60,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):

 def __init__(
 self,
-policy: Union[str, Type[ActorCriticPolicy]],
+policy: Union[str, type[ActorCriticPolicy]],
 env: Union[GymEnv, str],
 learning_rate: Union[float, Schedule],
 n_steps: int,

@@ -71,17 +71,17 @@ class OnPolicyAlgorithm(BaseAlgorithm):
 max_grad_norm: float,
 use_sde: bool,
 sde_sample_freq: int,
-rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
-rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
+rollout_buffer_class: Optional[type[RolloutBuffer]] = None,
+rollout_buffer_kwargs: Optional[dict[str, Any]] = None,
 stats_window_size: int = 100,
 tensorboard_log: Optional[str] = None,
 monitor_wrapper: bool = True,
-policy_kwargs: Optional[Dict[str, Any]] = None,
+policy_kwargs: Optional[dict[str, Any]] = None,
 verbose: int = 0,
 seed: Optional[int] = None,
 device: Union[th.device, str] = "auto",
 _init_setup_model: bool = True,
-supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
+supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
 ):
 super().__init__(
 policy=policy,

@@ -339,7 +339,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):

 return self

-def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
+def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
 state_dicts = ["policy", "policy.optimizer"]

 return state_dicts, []
|||
|
|
stable_baselines3/common/policies.py
@@ -5,7 +5,7 @@ import copy
 import warnings
 from abc import ABC, abstractmethod
 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Optional, TypeVar, Union
 
 import numpy as np
 import torch as th
@@ -64,12 +64,12 @@ class BaseModel(nn.Module):
         self,
         observation_space: spaces.Space,
         action_space: spaces.Space,
-        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         features_extractor: Optional[BaseFeaturesExtractor] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
     ):
         super().__init__()
 
@@ -95,9 +95,9 @@ class BaseModel(nn.Module):
 
     def _update_features_extractor(
         self,
-        net_kwargs: Dict[str, Any],
+        net_kwargs: dict[str, Any],
         features_extractor: Optional[BaseFeaturesExtractor] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         Update the network keyword arguments and create a new features extractor object if needed.
         If a ``features_extractor`` object is passed, then it will be shared.
@@ -130,7 +130,7 @@ class BaseModel(nn.Module):
         preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
         return features_extractor(preprocessed_obs)
 
-    def _get_constructor_parameters(self) -> Dict[str, Any]:
+    def _get_constructor_parameters(self) -> dict[str, Any]:
         """
         Get data that need to be saved in order to re-create the model when loading it from disk.
 
@@ -164,7 +164,7 @@ class BaseModel(nn.Module):
         th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)
 
     @classmethod
-    def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel:
+    def load(cls: type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel:
         """
         Load model from path.
 
@@ -210,7 +210,7 @@ class BaseModel(nn.Module):
         """
         self.train(mode)
 
-    def is_vectorized_observation(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> bool:
+    def is_vectorized_observation(self, observation: Union[np.ndarray, dict[str, np.ndarray]]) -> bool:
         """
         Check whether or not the observation is vectorized,
         apply transposition to image (so that they are channel-first) if needed.
@@ -233,7 +233,7 @@ class BaseModel(nn.Module):
         )
         return vectorized_env
 
-    def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[PyTorchObs, bool]:
+    def obs_to_tensor(self, observation: Union[np.ndarray, dict[str, np.ndarray]]) -> tuple[PyTorchObs, bool]:
         """
         Convert an input observation to a PyTorch tensor that can be fed to a model.
         Includes sugar-coating to handle different observations (e.g. normalizing images).
@@ -330,11 +330,11 @@ class BasePolicy(BaseModel, ABC):
 
     def predict(
         self,
-        observation: Union[np.ndarray, Dict[str, np.ndarray]],
-        state: Optional[Tuple[np.ndarray, ...]] = None,
+        observation: Union[np.ndarray, dict[str, np.ndarray]],
+        state: Optional[tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
-    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+    ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
         """
         Get the policy action from an observation (and optional hidden state).
         Includes sugar-coating to handle different observations (e.g. normalizing images).
@@ -450,20 +450,20 @@ class ActorCriticPolicy(BasePolicy):
         observation_space: spaces.Space,
         action_space: spaces.Space,
         lr_schedule: Schedule,
-        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
-        activation_fn: Type[nn.Module] = nn.Tanh,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.Tanh,
         ortho_init: bool = True,
         use_sde: bool = False,
         log_std_init: float = 0.0,
         full_std: bool = True,
         use_expln: bool = False,
         squash_output: bool = False,
-        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         share_features_extractor: bool = True,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
     ):
         if optimizer_kwargs is None:
             optimizer_kwargs = {}
@@ -534,7 +534,7 @@ class ActorCriticPolicy(BasePolicy):
 
         self._build(lr_schedule)
 
-    def _get_constructor_parameters(self) -> Dict[str, Any]:
+    def _get_constructor_parameters(self) -> dict[str, Any]:
         data = super()._get_constructor_parameters()
 
         default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None)  # type: ignore[arg-type, return-value]
@@ -633,7 +633,7 @@ class ActorCriticPolicy(BasePolicy):
         # Setup optimizer with initial learning rate
         self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)  # type: ignore[call-arg]
 
-    def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
+    def forward(self, obs: th.Tensor, deterministic: bool = False) -> tuple[th.Tensor, th.Tensor, th.Tensor]:
         """
         Forward pass in all the networks (actor and critic)
 
@@ -659,7 +659,7 @@ class ActorCriticPolicy(BasePolicy):
 
     def extract_features(  # type: ignore[override]
         self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None
-    ) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
+    ) -> Union[th.Tensor, tuple[th.Tensor, th.Tensor]]:
         """
         Preprocess the observation if needed and extract features.
 
@@ -716,7 +716,7 @@ class ActorCriticPolicy(BasePolicy):
         """
         return self.get_distribution(observation).get_actions(deterministic=deterministic)
 
-    def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
+    def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
         """
         Evaluate actions according to the current policy,
         given the observations.
@@ -800,20 +800,20 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
         observation_space: spaces.Space,
         action_space: spaces.Space,
         lr_schedule: Schedule,
-        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
-        activation_fn: Type[nn.Module] = nn.Tanh,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.Tanh,
         ortho_init: bool = True,
         use_sde: bool = False,
         log_std_init: float = 0.0,
         full_std: bool = True,
         use_expln: bool = False,
         squash_output: bool = False,
-        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         share_features_extractor: bool = True,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
     ):
         super().__init__(
             observation_space,
@@ -873,20 +873,20 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
         observation_space: spaces.Dict,
         action_space: spaces.Space,
         lr_schedule: Schedule,
-        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
-        activation_fn: Type[nn.Module] = nn.Tanh,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.Tanh,
         ortho_init: bool = True,
         use_sde: bool = False,
         log_std_init: float = 0.0,
         full_std: bool = True,
         use_expln: bool = False,
         squash_output: bool = False,
-        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         share_features_extractor: bool = True,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
     ):
         super().__init__(
             observation_space,
@@ -942,10 +942,10 @@ class ContinuousCritic(BaseModel):
         self,
         observation_space: spaces.Space,
         action_space: spaces.Box,
-        net_arch: List[int],
+        net_arch: list[int],
         features_extractor: BaseFeaturesExtractor,
         features_dim: int,
-        activation_fn: Type[nn.Module] = nn.ReLU,
+        activation_fn: type[nn.Module] = nn.ReLU,
         normalize_images: bool = True,
         n_critics: int = 2,
         share_features_extractor: bool = True,
@@ -961,14 +961,14 @@ class ContinuousCritic(BaseModel):
 
         self.share_features_extractor = share_features_extractor
         self.n_critics = n_critics
-        self.q_networks: List[nn.Module] = []
+        self.q_networks: list[nn.Module] = []
         for idx in range(n_critics):
             q_net_list = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
             q_net = nn.Sequential(*q_net_list)
             self.add_module(f"qf{idx}", q_net)
             self.q_networks.append(q_net)
 
-    def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
+    def forward(self, obs: th.Tensor, actions: th.Tensor) -> tuple[th.Tensor, ...]:
         # Learn the features extractor using the policy loss only
         # when the features_extractor is shared with the actor
         with th.set_grad_enabled(not self.share_features_extractor):
stable_baselines3/common/preprocessing.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Dict, Tuple, Union
+from typing import Union
 
 import numpy as np
 import torch as th
@@ -90,10 +90,10 @@ def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) ->
 
 
 def preprocess_obs(
-    obs: Union[th.Tensor, Dict[str, th.Tensor]],
+    obs: Union[th.Tensor, dict[str, th.Tensor]],
     observation_space: spaces.Space,
     normalize_images: bool = True,
-) -> Union[th.Tensor, Dict[str, th.Tensor]]:
+) -> Union[th.Tensor, dict[str, th.Tensor]]:
     """
     Preprocess observation to be to a neural network.
     For images, it normalizes the values by dividing them by 255 (to have values in [0, 1])
@@ -107,7 +107,7 @@ def preprocess_obs(
     """
     if isinstance(observation_space, spaces.Dict):
         # Do not modify by reference the original observation
-        assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
+        assert isinstance(obs, dict), f"Expected dict, got {type(obs)}"
         preprocessed_obs = {}
         for key, _obs in obs.items():
             preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
@@ -142,7 +142,7 @@ def preprocess_obs(
 
 def get_obs_shape(
     observation_space: spaces.Space,
-) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]:
+) -> Union[tuple[int, ...], dict[str, tuple[int, ...]]]:
     """
     Get the shape of the observation (useful for the buffers).
 
stable_baselines3/common/results_plotter.py
@@ -1,4 +1,4 @@
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Optional
 
 import numpy as np
 import pandas as pd
@@ -29,7 +29,7 @@ def rolling_window(array: np.ndarray, window: int) -> np.ndarray:
     return np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides)
 
 
-def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> Tuple[np.ndarray, np.ndarray]:
+def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> tuple[np.ndarray, np.ndarray]:
     """
     Apply a function to the rolling window of 2 arrays
 
@@ -44,7 +44,7 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callabl
     return var_1[window - 1 :], function_on_var2
 
 
-def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
+def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> tuple[np.ndarray, np.ndarray]:
     """
     Decompose a data frame variable to x and ys
 
@@ -69,7 +69,7 @@ def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray
 
 
 def plot_curves(
-    xy_list: List[Tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: Tuple[int, int] = (8, 2)
+    xy_list: list[tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: tuple[int, int] = (8, 2)
 ) -> None:
     """
     plot the curves
@@ -99,7 +99,7 @@ def plot_curves(
 
 
 def plot_results(
-    dirs: List[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2)
+    dirs: list[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: tuple[int, int] = (8, 2)
 ) -> None:
     """
     Plot the results using csv files from ``Monitor`` wrapper.
stable_baselines3/common/running_mean_std.py
@@ -1,10 +1,8 @@
-from typing import Tuple
-
 import numpy as np
 
 
 class RunningMeanStd:
-    def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
+    def __init__(self, epsilon: float = 1e-4, shape: tuple[int, ...] = ()):
         """
         Calculates the running mean and std of a data stream
         https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
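For reference, the update that `RunningMeanStd` implements is the parallel moments-combination algorithm from the Wikipedia article cited in its docstring. A minimal standalone sketch (hypothetical function name, not the class API):

```python
import numpy as np


def update_moments(mean, var, count, batch: np.ndarray):
    # Combine running moments with the moments of a new batch (Chan et al.).
    batch_mean, batch_var, batch_count = batch.mean(axis=0), batch.var(axis=0), batch.shape[0]
    delta = batch_mean - mean
    total = count + batch_count
    new_mean = mean + delta * batch_count / total
    m2 = var * count + batch_var * batch_count + delta**2 * count * batch_count / total
    return new_mean, m2 / total, total


mean, var, count = update_moments(0.0, 1.0, 1e-4, np.random.randn(32))
```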
stable_baselines3/common/save_util.py
@@ -12,7 +12,7 @@ import pathlib
 import pickle
 import warnings
 import zipfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 import cloudpickle
 import torch as th
@@ -73,7 +73,7 @@ def is_json_serializable(item: Any) -> bool:
     return json_serializable
 
 
-def data_to_json(data: Dict[str, Any]) -> str:
+def data_to_json(data: dict[str, Any]) -> str:
     """
     Turn data (class parameters) into a JSON string for storing
 
@@ -128,7 +128,7 @@ def data_to_json(data: Dict[str, Any]) -> str:
     return json_string
 
 
-def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+def json_to_data(json_string: str, custom_objects: Optional[dict[str, Any]] = None) -> dict[str, Any]:
     """
     Turn JSON serialization of class-parameters back into dictionary.
 
@@ -293,9 +293,9 @@ def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: O
 
 def save_to_zip_file(
     save_path: Union[str, pathlib.Path, io.BufferedIOBase],
-    data: Optional[Dict[str, Any]] = None,
-    params: Optional[Dict[str, Any]] = None,
-    pytorch_variables: Optional[Dict[str, Any]] = None,
+    data: Optional[dict[str, Any]] = None,
+    params: Optional[dict[str, Any]] = None,
+    pytorch_variables: Optional[dict[str, Any]] = None,
     verbose: int = 0,
 ) -> None:
     """
@@ -376,11 +376,11 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: in
 def load_from_zip_file(
     load_path: Union[str, pathlib.Path, io.BufferedIOBase],
     load_data: bool = True,
-    custom_objects: Optional[Dict[str, Any]] = None,
+    custom_objects: Optional[dict[str, Any]] = None,
     device: Union[th.device, str] = "auto",
     verbose: int = 0,
     print_system_info: bool = False,
-) -> Tuple[Optional[Dict[str, Any]], TensorDict, Optional[TensorDict]]:
+) -> tuple[Optional[dict[str, Any]], TensorDict, Optional[TensorDict]]:
     """
     Load model data from a .zip archive
 
stable_baselines3/common/sb2_compat/rmsprop_tf_like.py
@@ -1,4 +1,5 @@
-from typing import Any, Callable, Dict, Iterable, Optional
+from collections.abc import Iterable
+from typing import Any, Callable, Optional
 
 import torch
 from torch.optim import Optimizer
@@ -67,7 +68,7 @@ class RMSpropTFLike(Optimizer):
         defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
         super().__init__(params, defaults)
 
-    def __setstate__(self, state: Dict[str, Any]) -> None:
+    def __setstate__(self, state: dict[str, Any]) -> None:
         super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault("momentum", 0)
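The other half of the typing modernization, visible in this file and several below: the container ABCs (`Iterable`, `Sequence`, `Mapping`) are deprecated aliases in `typing` and are imported from `collections.abc` instead, where they are equally subscriptable on 3.9+. For illustration:

```python
from collections.abc import Iterable


def names_with(prefix: str, names: Iterable[str]) -> list[str]:
    # collections.abc generics are subscriptable on Python 3.9+.
    return [name for name in names if name.startswith(prefix)]


print(names_with("q", ["qf0", "qf1", "pi"]))  # ['qf0', 'qf1']
```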
stable_baselines3/common/torch_layers.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Tuple, Type, Union
+from typing import Optional, Union
 
 import gymnasium as gym
 import torch as th
@@ -110,13 +110,13 @@ class NatureCNN(BaseFeaturesExtractor):
 def create_mlp(
     input_dim: int,
     output_dim: int,
-    net_arch: List[int],
-    activation_fn: Type[nn.Module] = nn.ReLU,
+    net_arch: list[int],
+    activation_fn: type[nn.Module] = nn.ReLU,
     squash_output: bool = False,
     with_bias: bool = True,
-    pre_linear_modules: Optional[List[Type[nn.Module]]] = None,
-    post_linear_modules: Optional[List[Type[nn.Module]]] = None,
-) -> List[nn.Module]:
+    pre_linear_modules: Optional[list[type[nn.Module]]] = None,
+    post_linear_modules: Optional[list[type[nn.Module]]] = None,
+) -> list[nn.Module]:
     """
     Create a multi layer perceptron (MLP), which is
     a collection of fully-connected layers each followed by an activation function.
@@ -211,14 +211,14 @@ class MlpExtractor(nn.Module):
     def __init__(
         self,
         feature_dim: int,
-        net_arch: Union[List[int], Dict[str, List[int]]],
-        activation_fn: Type[nn.Module],
+        net_arch: Union[list[int], dict[str, list[int]]],
+        activation_fn: type[nn.Module],
         device: Union[th.device, str] = "auto",
     ) -> None:
         super().__init__()
         device = get_device(device)
-        policy_net: List[nn.Module] = []
-        value_net: List[nn.Module] = []
+        policy_net: list[nn.Module] = []
+        value_net: list[nn.Module] = []
         last_layer_dim_pi = feature_dim
         last_layer_dim_vf = feature_dim
 
@@ -249,7 +249,7 @@ class MlpExtractor(nn.Module):
         self.policy_net = nn.Sequential(*policy_net).to(device)
         self.value_net = nn.Sequential(*value_net).to(device)
 
-    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+    def forward(self, features: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
         """
         :return: latent_policy, latent_value of the specified network.
         If all layers are shared, then ``latent_policy == latent_value``
@@ -288,7 +288,7 @@ class CombinedExtractor(BaseFeaturesExtractor):
         # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
         super().__init__(observation_space, features_dim=1)
 
-        extractors: Dict[str, nn.Module] = {}
+        extractors: dict[str, nn.Module] = {}
 
         total_concat_size = 0
         for key, subspace in observation_space.spaces.items():
@@ -313,7 +313,7 @@ class CombinedExtractor(BaseFeaturesExtractor):
         return th.cat(encoded_tensor_list, dim=1)
 
 
-def get_actor_critic_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> Tuple[List[int], List[int]]:
+def get_actor_critic_arch(net_arch: Union[list[int], dict[str, list[int]]]) -> tuple[list[int], list[int]]:
     """
     Get the actor and critic network architectures for off-policy actor-critic algorithms (SAC, TD3, DDPG).
 
stable_baselines3/common/type_aliases.py
@@ -1,7 +1,7 @@
 """Common aliases for type hints"""
 
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Protocol, SupportsFloat, Union
 
 import gymnasium as gym
 import numpy as np
@@ -13,14 +13,14 @@ if TYPE_CHECKING:
     from stable_baselines3.common.vec_env import VecEnv
 
 GymEnv = Union[gym.Env, "VecEnv"]
-GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
-GymResetReturn = Tuple[GymObs, Dict]
-AtariResetReturn = Tuple[np.ndarray, Dict[str, Any]]
-GymStepReturn = Tuple[GymObs, float, bool, bool, Dict]
-AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]]
-TensorDict = Dict[str, th.Tensor]
-OptimizerStateDict = Dict[str, Any]
-MaybeCallback = Union[None, Callable, List["BaseCallback"], "BaseCallback"]
+GymObs = Union[tuple, dict[str, Any], np.ndarray, int]
+GymResetReturn = tuple[GymObs, dict]
+AtariResetReturn = tuple[np.ndarray, dict[str, Any]]
+GymStepReturn = tuple[GymObs, float, bool, bool, dict]
+AtariStepReturn = tuple[np.ndarray, SupportsFloat, bool, bool, dict[str, Any]]
+TensorDict = dict[str, th.Tensor]
+OptimizerStateDict = dict[str, Any]
+MaybeCallback = Union[None, Callable, list["BaseCallback"], "BaseCallback"]
 PyTorchObs = Union[th.Tensor, TensorDict]
 
 # A schedule takes the remaining progress as input
@@ -81,11 +81,11 @@ class TrainFreq(NamedTuple):
 class PolicyPredictor(Protocol):
     def predict(
         self,
-        observation: Union[np.ndarray, Dict[str, np.ndarray]],
-        state: Optional[Tuple[np.ndarray, ...]] = None,
+        observation: Union[np.ndarray, dict[str, np.ndarray]],
+        state: Optional[tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
-    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+    ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
         """
         Get the policy action from an observation (and optional hidden state).
         Includes sugar-coating to handle different observations (e.g. normalizing images).
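Since the rewritten aliases are plain builtin generics, runtime behavior is unchanged: a `TensorDict` is still an ordinary `dict`. A quick check under the definition above (requires PyTorch):

```python
import torch as th

TensorDict = dict[str, th.Tensor]  # as defined in type_aliases.py after this change

obs: TensorDict = {"image": th.zeros(1, 3, 84, 84)}
assert isinstance(obs, dict)  # the alias is just the builtin dict
```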
stable_baselines3/common/utils.py
@@ -4,8 +4,9 @@ import platform
 import random
 import re
 from collections import deque
+from collections.abc import Iterable
 from itertools import zip_longest
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Optional, Union
 
 import cloudpickle
 import gymnasium as gym
@@ -415,7 +416,7 @@ def safe_mean(arr: Union[np.ndarray, list, deque]) -> float:
     return np.nan if len(arr) == 0 else float(np.mean(arr))  # type: ignore[arg-type]
 
 
-def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> List[th.Tensor]:
+def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> list[th.Tensor]:
     """
     Extract parameters from the state dict of ``model``
     if the name contains one of the strings in ``included_names``.
@@ -473,7 +474,7 @@ def polyak_update(
             th.add(target_param.data, param.data, alpha=tau, out=target_param.data)
 
 
-def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.device) -> Union[th.Tensor, TensorDict]:
+def obs_as_tensor(obs: Union[np.ndarray, dict[str, np.ndarray]], device: th.device) -> Union[th.Tensor, TensorDict]:
     """
     Moves the observation to the given device.
 
@@ -517,7 +518,7 @@ def should_collect_more_steps(
     )
 
 
-def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
+def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
     """
     Retrieve system and python env info for the current system.
 
stable_baselines3/common/vec_env/__init__.py
@@ -1,5 +1,5 @@
 from copy import deepcopy
-from typing import Optional, Type, TypeVar
+from typing import Optional, TypeVar
 
 from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
 from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
@@ -16,7 +16,7 @@ from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
 VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper)
 
 
-def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]:
+def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]:
     """
     Retrieve a ``VecEnvWrapper`` object by recursively searching.
 
@@ -42,7 +42,7 @@ def unwrap_vec_normalize(env: VecEnv) -> Optional[VecNormalize]:
     return unwrap_vec_wrapper(env, VecNormalize)
 
 
-def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
+def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: type[VecEnvWrapper]) -> bool:
     """
     Check if an environment is already wrapped in a given ``VecEnvWrapper``.
 
stable_baselines3/common/vec_env/base_vec_env.py
@@ -1,8 +1,9 @@
 import inspect
 import warnings
 from abc import ABC, abstractmethod
+from collections.abc import Iterable, Sequence
 from copy import deepcopy
-from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Optional, Union
 
 import cloudpickle
 import gymnasium as gym
@@ -14,10 +15,10 @@ from gymnasium import spaces
 VecEnvIndices = Union[None, int, Iterable[int]]
 # VecEnvObs is what is returned by the reset() method
 # it contains the observation for each env
-VecEnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]]
+VecEnvObs = Union[np.ndarray, dict[str, np.ndarray], tuple[np.ndarray, ...]]
 # VecEnvStepReturn is what is returned by the step() method
 # it contains the observation, reward, done, info for each env
-VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]]
+VecEnvStepReturn = tuple[VecEnvObs, np.ndarray, np.ndarray, list[dict]]
 
 
 def tile_images(images_nhwc: Sequence[np.ndarray]) -> np.ndarray:  # pragma: no cover
@@ -65,11 +66,11 @@ class VecEnv(ABC):
         self.observation_space = observation_space
         self.action_space = action_space
         # store info returned by the reset method
-        self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
+        self.reset_infos: list[dict[str, Any]] = [{} for _ in range(num_envs)]
         # seeds to be used in the next call to env.reset()
-        self._seeds: List[Optional[int]] = [None for _ in range(num_envs)]
+        self._seeds: list[Optional[int]] = [None for _ in range(num_envs)]
         # options to be used in the next call to env.reset()
-        self._options: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
+        self._options: list[dict[str, Any]] = [{} for _ in range(num_envs)]
 
         try:
             render_modes = self.get_attr("render_mode")
@@ -147,7 +148,7 @@ class VecEnv(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
+    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
         """
         Return attribute from vectorized environment.
 
@@ -170,7 +171,7 @@ class VecEnv(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
+    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
         """
         Call instance methods of vectorized environments.
 
@@ -183,7 +184,7 @@ class VecEnv(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
+    def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
         """
         Check if environments are wrapped with a given wrapper.
 
@@ -292,7 +293,7 @@ class VecEnv(ABC):
         self._seeds = [seed + idx for idx in range(self.num_envs)]
         return self._seeds
 
-    def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None:
+    def set_options(self, options: Optional[Union[list[dict], dict]] = None) -> None:
         """
         Set environment options for all environments.
         If a dict is passed instead of a list, the same options will be used for all environments.
@@ -379,7 +380,7 @@ class VecEnvWrapper(VecEnv):
     def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
         return self.venv.seed(seed)
 
-    def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None:
+    def set_options(self, options: Optional[Union[list[dict], dict]] = None) -> None:
         return self.venv.set_options(options)
 
     def close(self) -> None:
@@ -391,16 +392,16 @@ class VecEnvWrapper(VecEnv):
     def get_images(self) -> Sequence[Optional[np.ndarray]]:
         return self.venv.get_images()
 
-    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
+    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
         return self.venv.get_attr(attr_name, indices)
 
     def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
         return self.venv.set_attr(attr_name, value, indices)
 
-    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
+    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
         return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)
 
-    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
+    def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
         return self.venv.env_is_wrapped(wrapper_class, indices=indices)
 
     def __getattr__(self, name: str) -> Any:
@@ -419,7 +420,7 @@ class VecEnvWrapper(VecEnv):
 
         return self.getattr_recursive(name)
 
-    def _get_all_attributes(self) -> Dict[str, Any]:
+    def _get_all_attributes(self) -> dict[str, Any]:
         """Get all (inherited) instance and class attributes
 
         :return: all_attributes
stable_baselines3/common/vec_env/dummy_vec_env.py
@@ -1,7 +1,8 @@
 import warnings
 from collections import OrderedDict
+from collections.abc import Sequence
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Sequence, Type
+from typing import Any, Callable, Optional
 
 import gymnasium as gym
 import numpy as np
@@ -26,7 +27,7 @@ class DummyVecEnv(VecEnv):
 
     actions: np.ndarray
 
-    def __init__(self, env_fns: List[Callable[[], gym.Env]]):
+    def __init__(self, env_fns: list[Callable[[], gym.Env]]):
         self.envs = [_patch_env(fn()) for fn in env_fns]
         if len(set([id(env.unwrapped) for env in self.envs])) != len(self.envs):
             raise ValueError(
@@ -46,7 +47,7 @@ class DummyVecEnv(VecEnv):
         self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])) for k in self.keys])
         self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
         self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
-        self.buf_infos: List[Dict[str, Any]] = [{} for _ in range(self.num_envs)]
+        self.buf_infos: list[dict[str, Any]] = [{} for _ in range(self.num_envs)]
         self.metadata = env.metadata
 
     def step_async(self, actions: np.ndarray) -> None:
@@ -112,7 +113,7 @@ class DummyVecEnv(VecEnv):
     def _obs_from_buf(self) -> VecEnvObs:
         return dict_to_obs(self.observation_space, deepcopy(self.buf_obs))
 
-    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
+    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
         """Return attribute from vectorized environment (see base class)."""
         target_envs = self._get_target_envs(indices)
         return [env_i.get_wrapper_attr(attr_name) for env_i in target_envs]
@@ -123,12 +124,12 @@ class DummyVecEnv(VecEnv):
         for env_i in target_envs:
             setattr(env_i, attr_name, value)
 
-    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
+    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
         """Call instance methods of vectorized environments."""
         target_envs = self._get_target_envs(indices)
         return [env_i.get_wrapper_attr(method_name)(*method_args, **method_kwargs) for env_i in target_envs]
 
-    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
+    def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
         """Check if worker environments are wrapped with a given wrapper"""
         target_envs = self._get_target_envs(indices)
         # Import here to avoid a circular import
@@ -136,6 +137,6 @@ class DummyVecEnv(VecEnv):
 
         return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs]
 
-    def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
+    def _get_target_envs(self, indices: VecEnvIndices) -> list[gym.Env]:
         indices = self._get_indices(indices)
         return [self.envs[i] for i in indices]
stable_baselines3/common/vec_env/stacked_observations.py
@@ -1,12 +1,13 @@
 import warnings
-from typing import Any, Dict, Generic, List, Mapping, Optional, Tuple, TypeVar, Union
+from collections.abc import Mapping
+from typing import Any, Generic, Optional, TypeVar, Union
 
 import numpy as np
 from gymnasium import spaces
 
 from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
 
-TObs = TypeVar("TObs", np.ndarray, Dict[str, np.ndarray])
+TObs = TypeVar("TObs", np.ndarray, dict[str, np.ndarray])
 
 
 class StackedObservations(Generic[TObs]):
@@ -66,7 +67,7 @@ class StackedObservations(Generic[TObs]):
     @staticmethod
     def compute_stacking(
         n_stack: int, observation_space: spaces.Box, channels_order: Optional[str] = None
-    ) -> Tuple[bool, int, Tuple[int, ...], int]:
+    ) -> tuple[bool, int, tuple[int, ...], int]:
         """
         Calculates the parameters in order to stack observations
 
@@ -119,8 +120,8 @@ class StackedObservations(Generic[TObs]):
         self,
         observations: TObs,
         dones: np.ndarray,
-        infos: List[Dict[str, Any]],
-    ) -> Tuple[TObs, List[Dict[str, Any]]]:
+        infos: list[dict[str, Any]],
+    ) -> tuple[TObs, list[dict[str, Any]]]:
         """
         Add the observations to the stack and use the dones to update the infos.
 
stable_baselines3/common/vec_env/subproc_vec_env.py
@@ -1,6 +1,7 @@
 import multiprocessing as mp
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
+from collections.abc import Sequence
+from typing import Any, Callable, Optional, Union
 
 import gymnasium as gym
 import numpy as np
@@ -26,7 +27,7 @@ def _worker(
 
     parent_remote.close()
     env = _patch_env(env_fn_wrapper.var())
-    reset_info: Optional[Dict[str, Any]] = {}
+    reset_info: Optional[dict[str, Any]] = {}
     while True:
         try:
             cmd, data = remote.recv()
@@ -91,7 +92,7 @@ class SubprocVecEnv(VecEnv):
         Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
     """
 
-    def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[str] = None):
+    def __init__(self, env_fns: list[Callable[[], gym.Env]], start_method: Optional[str] = None):
         self.waiting = False
         self.closed = False
         n_envs = len(env_fns)
@@ -164,7 +165,7 @@ class SubprocVecEnv(VecEnv):
         outputs = [pipe.recv() for pipe in self.remotes]
         return outputs
 
-    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
+    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
         """Return attribute from vectorized environment (see base class)."""
         target_remotes = self._get_target_remotes(indices)
         for remote in target_remotes:
@@ -179,21 +180,21 @@ class SubprocVecEnv(VecEnv):
         for remote in target_remotes:
             remote.recv()
 
-    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
+    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
         """Call instance methods of vectorized environments."""
         target_remotes = self._get_target_remotes(indices)
         for remote in target_remotes:
             remote.send(("env_method", (method_name, method_args, method_kwargs)))
         return [remote.recv() for remote in target_remotes]
 
-    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
+    def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
         """Check if worker environments are wrapped with a given wrapper"""
         target_remotes = self._get_target_remotes(indices)
         for remote in target_remotes:
            remote.send(("is_wrapped", wrapper_class))
         return [remote.recv() for remote in target_remotes]
 
-    def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]:
+    def _get_target_remotes(self, indices: VecEnvIndices) -> list[Any]:
         """
         Get the connection object needed to communicate with the wanted
         envs that are in subprocesses.
@@ -205,7 +206,7 @@ class SubprocVecEnv(VecEnv):
         return [self.remotes[i] for i in indices]
 
 
-def _stack_obs(obs_list: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
+def _stack_obs(obs_list: Union[list[VecEnvObs], tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
     """
     Stack observations (convert from a list of single env obs to a stack of obs),
     depending on the observation space.
stable_baselines3/common/vec_env/util.py
@@ -2,7 +2,7 @@
 Helpers for dealing with vectorized environments.
 """
 
-from typing import Any, Dict, List, Tuple
+from typing import Any
 
 import numpy as np
 from gymnasium import spaces
@@ -11,7 +11,7 @@ from stable_baselines3.common.preprocessing import check_for_nested_spaces
 from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
 
 
-def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
+def dict_to_obs(obs_space: spaces.Space, obs_dict: dict[Any, np.ndarray]) -> VecEnvObs:
     """
     Convert an internal representation raw_obs into the appropriate type
     specified by space.
@@ -32,7 +32,7 @@ def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> Vec
     return obs_dict[None]
 
 
-def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]:
+def obs_space_info(obs_space: spaces.Space) -> tuple[list[str], dict[Any, tuple[int, ...]], dict[Any, np.dtype]]:
     """
     Get dict-structured information about a gym.Space.
 
stable_baselines3/common/vec_env/vec_check_nan.py
@@ -1,5 +1,4 @@
 import warnings
-from typing import List, Tuple
 
 import numpy as np
 from gymnasium import spaces
@@ -48,7 +47,7 @@ class VecCheckNan(VecEnvWrapper):
         self._observations = observations
         return observations
 
-    def check_array_value(self, name: str, value: np.ndarray) -> List[Tuple[str, str]]:
+    def check_array_value(self, name: str, value: np.ndarray) -> list[tuple[str, str]]:
         """
         Check for inf and NaN for a single numpy array.
 
|
@ -1,4 +1,5 @@
|
|||
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
|
@ -29,17 +30,17 @@ class VecFrameStack(VecEnvWrapper):
|
|||
|
||||
def step_wait(
|
||||
self,
|
||||
) -> Tuple[
|
||||
Union[np.ndarray, Dict[str, np.ndarray]],
|
||||
) -> tuple[
|
||||
Union[np.ndarray, dict[str, np.ndarray]],
|
||||
np.ndarray,
|
||||
np.ndarray,
|
||||
List[Dict[str, Any]],
|
||||
list[dict[str, Any]],
|
||||
]:
|
||||
observations, rewards, dones, infos = self.venv.step_wait()
|
||||
observations, infos = self.stacked_obs.update(observations, dones, infos) # type: ignore[arg-type]
|
||||
return observations, rewards, dones, infos
|
||||
|
||||
def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
|
||||
def reset(self) -> Union[np.ndarray, dict[str, np.ndarray]]:
|
||||
"""
|
||||
Reset all environments
|
||||
"""
|
||||
|
|
|
|||
|
|
stable_baselines3/common/vec_env/vec_monitor.py
@@ -1,6 +1,6 @@
 import time
 import warnings
-from typing import Optional, Tuple
+from typing import Optional
 
 import numpy as np
 
@@ -27,7 +27,7 @@ class VecMonitor(VecEnvWrapper):
         self,
         venv: VecEnv,
         filename: Optional[str] = None,
-        info_keywords: Tuple[str, ...] = (),
+        info_keywords: tuple[str, ...] = (),
     ):
         # Avoid circular import
         from stable_baselines3.common.monitor import Monitor, ResultsWriter
stable_baselines3/common/vec_env/vec_normalize.py
@@ -1,7 +1,7 @@
 import inspect
 import pickle
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
 
 import numpy as np
 from gymnasium import spaces
@@ -29,8 +29,8 @@ class VecNormalize(VecEnvWrapper):
         If not specified, all keys will be normalized.
     """
 
-    obs_spaces: Dict[str, spaces.Space]
-    old_obs: Union[np.ndarray, Dict[str, np.ndarray]]
+    obs_spaces: dict[str, spaces.Space]
+    old_obs: Union[np.ndarray, dict[str, np.ndarray]]
 
     def __init__(
         self,
@@ -42,7 +42,7 @@ class VecNormalize(VecEnvWrapper):
         clip_reward: float = 10.0,
         gamma: float = 0.99,
         epsilon: float = 1e-8,
-        norm_obs_keys: Optional[List[str]] = None,
+        norm_obs_keys: Optional[list[str]] = None,
     ):
         VecEnvWrapper.__init__(self, venv)
 
@@ -125,7 +125,7 @@ class VecNormalize(VecEnvWrapper):
                 f"not {self.observation_space}"
             )
 
-    def __getstate__(self) -> Dict[str, Any]:
+    def __getstate__(self) -> dict[str, Any]:
         """
         Gets state for pickling.
 
@@ -138,7 +138,7 @@ class VecNormalize(VecEnvWrapper):
         del state["returns"]
         return state
 
-    def __setstate__(self, state: Dict[str, Any]) -> None:
+    def __setstate__(self, state: dict[str, Any]) -> None:
         """
         Restores pickled state.
 
@@ -229,7 +229,7 @@ class VecNormalize(VecEnvWrapper):
         """
         return (obs * np.sqrt(obs_rms.var + self.epsilon)) + obs_rms.mean
 
-    def normalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    def normalize_obs(self, obs: Union[np.ndarray, dict[str, np.ndarray]]) -> Union[np.ndarray, dict[str, np.ndarray]]:
         """
         Normalize observations using this VecNormalize's observations statistics.
         Calling this method does not update statistics.
@@ -254,9 +254,11 @@ class VecNormalize(VecEnvWrapper):
         """
         if self.norm_reward:
             reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
-        return reward
+        # Note: we cast to float32 as it correspond to Python default float type
+        # This cast is needed because `RunningMeanStd` keeps stats in float64
+        return reward.astype(np.float32)
 
-    def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    def unnormalize_obs(self, obs: Union[np.ndarray, dict[str, np.ndarray]]) -> Union[np.ndarray, dict[str, np.ndarray]]:
         # Avoid modifying by reference the original object
         obs_ = deepcopy(obs)
         if self.norm_obs:
@@ -274,7 +276,7 @@ class VecNormalize(VecEnvWrapper):
             return reward * np.sqrt(self.ret_rms.var + self.epsilon)
         return reward
 
-    def get_original_obs(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    def get_original_obs(self) -> Union[np.ndarray, dict[str, np.ndarray]]:
         """
         Returns an unnormalized version of the observations from the most recent
         step or reset.
@@ -287,7 +289,7 @@ class VecNormalize(VecEnvWrapper):
         """
         return self.old_reward.copy()
 
-    def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    def reset(self) -> Union[np.ndarray, dict[str, np.ndarray]]:
         """
         Reset all environments
         :return: first observation of the episode
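The `normalize_reward` change above is one of the NumPy v2 fixes in this commit: under the NEP 50 promotion rules, dividing a float32 reward by the float64 statistics kept by `RunningMeanStd` promotes the result to float64, hence the explicit cast back. A minimal sketch of that promotion, assuming NumPy 2:

```python
import numpy as np

reward = np.ones(3, dtype=np.float32)
var = np.zeros(())  # float64 running variance, as RunningMeanStd stores it

normalized = reward / np.sqrt(var + 1e-8)
print(normalized.dtype)  # float64 under NumPy 2 (NEP 50 promotion)
print(normalized.astype(np.float32).dtype)  # float32, matching the new return value
```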
stable_baselines3/common/vec_env/vec_transpose.py
@@ -1,5 +1,5 @@
 from copy import deepcopy
-from typing import Dict, Union
+from typing import Union
 
 import numpy as np
 from gymnasium import spaces
@@ -73,7 +73,7 @@ class VecTransposeImage(VecEnvWrapper):
             return np.transpose(image, (2, 0, 1))
         return np.transpose(image, (0, 3, 1, 2))
 
-    def transpose_observations(self, observations: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]:
+    def transpose_observations(self, observations: Union[np.ndarray, dict]) -> Union[np.ndarray, dict]:
         """
         Transpose (if needed) and return new observations.
 
@@ -106,7 +106,7 @@ class VecTransposeImage(VecEnvWrapper):
         assert isinstance(observations, (np.ndarray, dict))
         return self.transpose_observations(observations), rewards, dones, infos
 
-    def reset(self) -> Union[np.ndarray, Dict]:
+    def reset(self) -> Union[np.ndarray, dict]:
         """
         Reset all environments
         """
stable_baselines3/common/vec_env/vec_video_recorder.py
@@ -1,6 +1,6 @@
 import os
 import os.path
-from typing import Callable, List
+from typing import Callable
 
 import numpy as np
 from gymnasium import error, logger
@@ -109,7 +109,7 @@ class VecVideoRecorder(VecEnvWrapper):
         assert self.recording, "Cannot capture a frame, recording wasn't started."
 
         frame = self.env.render()
-        if isinstance(frame, List):
+        if isinstance(frame, list):
             frame = frame[-1]
 
         if isinstance(frame, np.ndarray):
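The `VecVideoRecorder` hunk is a runtime-check cleanup riding along with the import changes: `isinstance(frame, List)` only worked because `typing.List` aliases `list`; the builtin is the idiomatic form for runtime checks. For illustration:

```python
import numpy as np

frame = [np.zeros((64, 64, 3), dtype=np.uint8)]  # render() may return a list of frames
# Runtime type checks use the builtin class, not a typing alias.
if isinstance(frame, list):
    frame = frame[-1]
assert isinstance(frame, np.ndarray)
```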
stable_baselines3/ddpg/ddpg.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Optional, TypeVar, Union
 
 import torch as th
 
@@ -55,7 +55,7 @@ class DDPG(TD3):
 
     def __init__(
         self,
-        policy: Union[str, Type[TD3Policy]],
+        policy: Union[str, type[TD3Policy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 1e-3,
         buffer_size: int = 1_000_000,  # 1e6
@@ -63,14 +63,14 @@ class DDPG(TD3):
         batch_size: int = 256,
         tau: float = 0.005,
         gamma: float = 0.99,
-        train_freq: Union[int, Tuple[int, str]] = 1,
+        train_freq: Union[int, tuple[int, str]] = 1,
         gradient_steps: int = 1,
         action_noise: Optional[ActionNoise] = None,
-        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
-        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
         tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
stable_baselines3/dqn/dqn.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, ClassVar, Optional, TypeVar, Union
 
 import numpy as np
 import torch as th
@@ -62,7 +62,7 @@ class DQN(OffPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
-    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
+    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
         "MlpPolicy": MlpPolicy,
         "CnnPolicy": CnnPolicy,
         "MultiInputPolicy": MultiInputPolicy,
@@ -75,7 +75,7 @@ class DQN(OffPolicyAlgorithm):
 
     def __init__(
         self,
-        policy: Union[str, Type[DQNPolicy]],
+        policy: Union[str, type[DQNPolicy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 1e-4,
         buffer_size: int = 1_000_000,  # 1e6
@@ -83,10 +83,10 @@ class DQN(OffPolicyAlgorithm):
         batch_size: int = 32,
         tau: float = 1.0,
         gamma: float = 0.99,
-        train_freq: Union[int, Tuple[int, str]] = 4,
+        train_freq: Union[int, tuple[int, str]] = 4,
         gradient_steps: int = 1,
-        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
-        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
         target_update_interval: int = 10000,
         exploration_fraction: float = 0.1,
@@ -95,7 +95,7 @@ class DQN(OffPolicyAlgorithm):
         max_grad_norm: float = 10,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
@@ -227,11 +227,11 @@ class DQN(OffPolicyAlgorithm):
 
     def predict(
         self,
-        observation: Union[np.ndarray, Dict[str, np.ndarray]],
-        state: Optional[Tuple[np.ndarray, ...]] = None,
+        observation: Union[np.ndarray, dict[str, np.ndarray]],
+        state: Optional[tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
-    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+    ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
         """
         Overrides the base_class predict function to include epsilon-greedy exploration.
 
@@ -273,10 +273,10 @@ class DQN(OffPolicyAlgorithm):
             progress_bar=progress_bar,
         )
 
-    def _excluded_save_params(self) -> List[str]:
+    def _excluded_save_params(self) -> list[str]:
         return [*super()._excluded_save_params(), "q_net", "q_net_target"]
 
-    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
+    def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
         state_dicts = ["policy", "policy.optimizer"]
 
         return state_dicts, []
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Optional
 
 import torch as th
 from gymnasium import spaces
@@ -35,8 +35,8 @@ class QNetwork(BasePolicy):
         action_space: spaces.Discrete,
         features_extractor: BaseFeaturesExtractor,
         features_dim: int,
-        net_arch: Optional[List[int]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
+        net_arch: Optional[list[int]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
         normalize_images: bool = True,
     ) -> None:
         super().__init__(
@@ -71,7 +71,7 @@ class QNetwork(BasePolicy):
         action = q_values.argmax(dim=1).reshape(-1)
         return action
 
-    def _get_constructor_parameters(self) -> Dict[str, Any]:
+    def _get_constructor_parameters(self) -> dict[str, Any]:
         data = super()._get_constructor_parameters()
 
         data.update(
@@ -113,13 +113,13 @@ class DQNPolicy(BasePolicy):
         observation_space: spaces.Space,
         action_space: spaces.Discrete,
         lr_schedule: Schedule,
-        net_arch: Optional[List[int]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
-        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        net_arch: Optional[list[int]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
+        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         super().__init__(
             observation_space,
@@ -183,7 +183,7 @@ class DQNPolicy(BasePolicy):
     def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
         return self.q_net._predict(obs, deterministic=deterministic)
 
-    def _get_constructor_parameters(self) -> Dict[str, Any]:
+    def _get_constructor_parameters(self) -> dict[str, Any]:
         data = super()._get_constructor_parameters()
 
         data.update(
@@ -237,13 +237,13 @@ class CnnPolicy(DQNPolicy):
         observation_space: spaces.Space,
         action_space: spaces.Discrete,
         lr_schedule: Schedule,
-        net_arch: Optional[List[int]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
-        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        net_arch: Optional[list[int]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
+        features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         super().__init__(
             observation_space,
@@ -282,13 +282,13 @@ class MultiInputPolicy(DQNPolicy):
         observation_space: spaces.Dict,
         action_space: spaces.Discrete,
         lr_schedule: Schedule,
-        net_arch: Optional[List[int]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
-        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        net_arch: Optional[list[int]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
+        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         super().__init__(
             observation_space,
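
Every hunk in this file is the same mechanical rewrite: once Python 3.8 support is dropped, the deprecated `typing.Dict`, `typing.List`, `typing.Tuple` and `typing.Type` aliases can be replaced by the builtin generics standardized in PEP 585 (Python 3.9+). A minimal before/after sketch (hypothetical function, runnable on 3.9):

from typing import Optional

# Python 3.9+ (PEP 585): builtins are directly subscriptable in annotations,
# so Dict[str, X] becomes dict[str, X], Type[X] becomes type[X], and so on.
def make_net_arch(net_arch: Optional[list[int]] = None) -> dict[str, list[int]]:
    net_arch = [64, 64] if net_arch is None else net_arch
    return {"pi": net_arch, "qf": net_arch}

print(make_net_arch())  # {'pi': [64, 64], 'qf': [64, 64]}
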
@@ -1,6 +1,6 @@
 import copy
 import warnings
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
 
 import numpy as np
 import torch as th
@@ -98,7 +98,7 @@ class HerReplayBuffer(DictReplayBuffer):
         self.ep_length = np.zeros((self.buffer_size, self.n_envs), dtype=np.int64)
         self._current_ep_start = np.zeros(self.n_envs, dtype=np.int64)
 
-    def __getstate__(self) -> Dict[str, Any]:
+    def __getstate__(self) -> dict[str, Any]:
         """
         Gets state for pickling.
 
@@ -109,7 +109,7 @@ class HerReplayBuffer(DictReplayBuffer):
         del state["env"]
         return state
 
-    def __setstate__(self, state: Dict[str, Any]) -> None:
+    def __setstate__(self, state: dict[str, Any]) -> None:
         """
         Restores pickled state.
 
@@ -134,12 +134,12 @@ class HerReplayBuffer(DictReplayBuffer):
 
     def add(  # type: ignore[override]
         self,
-        obs: Dict[str, np.ndarray],
-        next_obs: Dict[str, np.ndarray],
+        obs: dict[str, np.ndarray],
+        next_obs: dict[str, np.ndarray],
         action: np.ndarray,
         reward: np.ndarray,
         done: np.ndarray,
-        infos: List[Dict[str, Any]],
+        infos: list[dict[str, Any]],
     ) -> None:
         # When the buffer is full, we rewrite on old episodes. When we start to
         # rewrite on an old episodes, we want the whole old episode to be deleted
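
The `__getstate__`/`__setstate__` hunks above belong to the buffer's pickling support: the live environment handle is not picklable, so it is dropped on save (`del state["env"]`) and must be re-attached after load. A generic, self-contained sketch of that pattern (class and attribute names are hypothetical):

import pickle
from typing import Any, Optional

class BufferWithEnv:
    def __init__(self, env: Optional[Any] = None) -> None:
        self.env = env  # unpicklable handle in the real use case
        self.data: list[float] = []

    def __getstate__(self) -> dict[str, Any]:
        # Drop the environment reference before pickling.
        state = self.__dict__.copy()
        del state["env"]
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        # Restore everything else; the env must be set again by the caller.
        self.__dict__.update(state)
        self.env = None

restored = pickle.loads(pickle.dumps(BufferWithEnv(env=object())))
assert restored.env is None
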
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
+from typing import Any, ClassVar, Optional, TypeVar, Union
 
 import numpy as np
 import torch as th
@@ -71,7 +71,7 @@ class PPO(OnPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
-    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
+    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
         "MlpPolicy": ActorCriticPolicy,
         "CnnPolicy": ActorCriticCnnPolicy,
         "MultiInputPolicy": MultiInputActorCriticPolicy,
@@ -79,7 +79,7 @@ class PPO(OnPolicyAlgorithm):
 
     def __init__(
         self,
-        policy: Union[str, Type[ActorCriticPolicy]],
+        policy: Union[str, type[ActorCriticPolicy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 3e-4,
         n_steps: int = 2048,
@@ -95,12 +95,12 @@ class PPO(OnPolicyAlgorithm):
         max_grad_norm: float = 0.5,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
-        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
-        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        rollout_buffer_class: Optional[type[RolloutBuffer]] = None,
+        rollout_buffer_kwargs: Optional[dict[str, Any]] = None,
         target_kl: Optional[float] = None,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
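
Note that `Optional`, `Union`, `Any`, `ClassVar` and `TypeVar` deliberately stay on the `typing` import: PEP 604's `X | Y` union syntax only works in runtime annotations from Python 3.10 on, so a codebase that still supports 3.9 keeps the `typing` spellings. A quick illustration (hypothetical function):

from typing import Optional, Union

# Valid on Python 3.9: PEP 585 builtin generics combined with typing unions.
def resolve_policy(policy: Union[str, type], kwargs: Optional[dict] = None) -> str:
    return policy if isinstance(policy, str) else policy.__name__

# Writing `str | type` here would fail at import time on Python 3.9
# (PEP 604 is 3.10+), unless `from __future__ import annotations` is used.
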
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Optional, Union
 
 import torch as th
 from gymnasium import spaces
@@ -51,10 +51,10 @@ class Actor(BasePolicy):
         self,
         observation_space: spaces.Space,
         action_space: spaces.Box,
-        net_arch: List[int],
+        net_arch: list[int],
         features_extractor: nn.Module,
         features_dim: int,
-        activation_fn: Type[nn.Module] = nn.ReLU,
+        activation_fn: type[nn.Module] = nn.ReLU,
         use_sde: bool = False,
         log_std_init: float = -3,
         full_std: bool = True,
@@ -102,7 +102,7 @@ class Actor(BasePolicy):
             self.mu = nn.Linear(last_layer_dim, action_dim)
             self.log_std = nn.Linear(last_layer_dim, action_dim)  # type: ignore[assignment]
 
-    def _get_constructor_parameters(self) -> Dict[str, Any]:
+    def _get_constructor_parameters(self) -> dict[str, Any]:
         data = super()._get_constructor_parameters()
 
         data.update(
@@ -144,7 +144,7 @@ class Actor(BasePolicy):
         assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
         self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
 
-    def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
+    def get_action_dist_params(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tensor, dict[str, th.Tensor]]:
         """
         Get the parameters for the action distribution.
 
@@ -169,7 +169,7 @@ class Actor(BasePolicy):
         # Note: the action is squashed
         return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
 
-    def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]:
+    def action_log_prob(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tensor]:
         mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
         # return action and associated log prob
         return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
@@ -216,17 +216,17 @@ class SACPolicy(BasePolicy):
         observation_space: spaces.Space,
         action_space: spaces.Box,
         lr_schedule: Schedule,
-        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
         use_sde: bool = False,
         log_std_init: float = -3,
         use_expln: bool = False,
         clip_mean: float = 2.0,
-        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = False,
     ):
@@ -309,7 +309,7 @@ class SACPolicy(BasePolicy):
         # Target networks should always be in eval mode
         self.critic_target.set_training_mode(False)
 
-    def _get_constructor_parameters(self) -> Dict[str, Any]:
+    def _get_constructor_parameters(self) -> dict[str, Any]:
         data = super()._get_constructor_parameters()
 
         data.update(
@@ -400,17 +400,17 @@ class CnnPolicy(SACPolicy):
         observation_space: spaces.Space,
         action_space: spaces.Box,
         lr_schedule: Schedule,
-        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
         use_sde: bool = False,
         log_std_init: float = -3,
         use_expln: bool = False,
         clip_mean: float = 2.0,
-        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = False,
     ):
@@ -466,17 +466,17 @@ class MultiInputPolicy(SACPolicy):
         observation_space: spaces.Space,
         action_space: spaces.Box,
         lr_schedule: Schedule,
-        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
         use_sde: bool = False,
         log_std_init: float = -3,
         use_expln: bool = False,
         clip_mean: float = 2.0,
-        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = False,
     ):
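
Context for the `action_log_prob` hunk above: SAC's actor samples from a tanh-squashed Gaussian, so the log-probability needs the change-of-variables correction log pi(a|s) = log N(u | mu, sigma) - sum_i log(1 - tanh(u_i)^2). A rough PyTorch sketch of that standard recipe (not the repository's distribution class):

import torch as th

def squashed_gaussian_sample(mean_actions: th.Tensor, log_std: th.Tensor):
    # Sample u ~ N(mean, std) with the reparameterization trick,
    # squash with tanh, and correct the log-prob for the transform.
    dist = th.distributions.Normal(mean_actions, log_std.exp())
    gaussian_actions = dist.rsample()
    actions = th.tanh(gaussian_actions)
    log_prob = dist.log_prob(gaussian_actions).sum(dim=-1)
    log_prob -= th.log(1.0 - actions**2 + 1e-6).sum(dim=-1)  # tanh correction
    return actions, log_prob

acts, logp = squashed_gaussian_sample(th.zeros(4, 2), th.full((4, 2), -1.0))
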
@@ -1,4 +1,4 @@
-from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, ClassVar, Optional, TypeVar, Union
 
 import numpy as np
 import torch as th
@@ -77,7 +77,7 @@ class SAC(OffPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
-    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
+    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
         "MlpPolicy": MlpPolicy,
         "CnnPolicy": CnnPolicy,
         "MultiInputPolicy": MultiInputPolicy,
@@ -89,7 +89,7 @@ class SAC(OffPolicyAlgorithm):
 
     def __init__(
        self,
-        policy: Union[str, Type[SACPolicy]],
+        policy: Union[str, type[SACPolicy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 3e-4,
         buffer_size: int = 1_000_000,  # 1e6
@@ -97,11 +97,11 @@ class SAC(OffPolicyAlgorithm):
         batch_size: int = 256,
         tau: float = 0.005,
         gamma: float = 0.99,
-        train_freq: Union[int, Tuple[int, str]] = 1,
+        train_freq: Union[int, tuple[int, str]] = 1,
         gradient_steps: int = 1,
         action_noise: Optional[ActionNoise] = None,
-        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
-        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
         ent_coef: Union[str, float] = "auto",
         target_update_interval: int = 1,
@@ -111,7 +111,7 @@ class SAC(OffPolicyAlgorithm):
         use_sde_at_warmup: bool = False,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
@@ -313,10 +313,10 @@ class SAC(OffPolicyAlgorithm):
             progress_bar=progress_bar,
         )
 
-    def _excluded_save_params(self) -> List[str]:
+    def _excluded_save_params(self) -> list[str]:
         return super()._excluded_save_params() + ["actor", "critic", "critic_target"]  # noqa: RUF005
 
-    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
+    def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
         state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
         if self.ent_coef_optimizer is not None:
             saved_pytorch_variables = ["log_ent_coef"]
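
The `ent_coef="auto"` default and the `log_ent_coef` entry in the torch save params above correspond to SAC's automatic entropy-coefficient tuning: alpha is optimized in log space against a target entropy, commonly -dim(A). A compact sketch of one update step (standard SAC recipe; the values are illustrative, not from the repository):

import torch as th

target_entropy = -6.0  # heuristic: -dim(action_space), illustrative
log_ent_coef = th.zeros(1, requires_grad=True)  # optimize log(alpha) so alpha > 0
optimizer = th.optim.Adam([log_ent_coef], lr=3e-4)

log_prob = th.randn(256)  # stand-in for log pi(a|s) from the actor
ent_coef_loss = -(log_ent_coef * (log_prob + target_entropy).detach()).mean()
optimizer.zero_grad()
ent_coef_loss.backward()
optimizer.step()

ent_coef = log_ent_coef.detach().exp()  # alpha used in the actor/critic losses
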
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Optional, Union
 
 import torch as th
 from gymnasium import spaces
@@ -36,10 +36,10 @@ class Actor(BasePolicy):
         self,
         observation_space: spaces.Space,
         action_space: spaces.Box,
-        net_arch: List[int],
+        net_arch: list[int],
         features_extractor: nn.Module,
         features_dim: int,
-        activation_fn: Type[nn.Module] = nn.ReLU,
+        activation_fn: type[nn.Module] = nn.ReLU,
         normalize_images: bool = True,
     ):
         super().__init__(
@@ -59,7 +59,7 @@ class Actor(BasePolicy):
         # Deterministic action
         self.mu = nn.Sequential(*actor_net)
 
-    def _get_constructor_parameters(self) -> Dict[str, Any]:
+    def _get_constructor_parameters(self) -> dict[str, Any]:
         data = super()._get_constructor_parameters()
 
         data.update(
@@ -116,13 +116,13 @@ class TD3Policy(BasePolicy):
         observation_space: spaces.Space,
         action_space: spaces.Box,
         lr_schedule: Schedule,
-        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
-        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
+        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = False,
     ):
@@ -207,7 +207,7 @@ class TD3Policy(BasePolicy):
         self.actor_target.set_training_mode(False)
         self.critic_target.set_training_mode(False)
 
-    def _get_constructor_parameters(self) -> Dict[str, Any]:
+    def _get_constructor_parameters(self) -> dict[str, Any]:
         data = super()._get_constructor_parameters()
 
         data.update(
@@ -285,13 +285,13 @@ class CnnPolicy(TD3Policy):
         observation_space: spaces.Space,
         action_space: spaces.Box,
         lr_schedule: Schedule,
-        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
-        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
+        features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = False,
     ):
@@ -339,13 +339,13 @@ class MultiInputPolicy(TD3Policy):
         observation_space: spaces.Dict,
         action_space: spaces.Box,
         lr_schedule: Schedule,
-        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
-        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
+        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = False,
     ):
@@ -1,4 +1,4 @@
-from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, ClassVar, Optional, TypeVar, Union
 
 import numpy as np
 import torch as th
@@ -65,7 +65,7 @@ class TD3(OffPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
-    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
+    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
         "MlpPolicy": MlpPolicy,
         "CnnPolicy": CnnPolicy,
         "MultiInputPolicy": MultiInputPolicy,
@@ -78,7 +78,7 @@ class TD3(OffPolicyAlgorithm):
 
     def __init__(
         self,
-        policy: Union[str, Type[TD3Policy]],
+        policy: Union[str, type[TD3Policy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 1e-3,
         buffer_size: int = 1_000_000,  # 1e6
@@ -86,18 +86,18 @@ class TD3(OffPolicyAlgorithm):
         batch_size: int = 256,
         tau: float = 0.005,
         gamma: float = 0.99,
-        train_freq: Union[int, Tuple[int, str]] = 1,
+        train_freq: Union[int, tuple[int, str]] = 1,
         gradient_steps: int = 1,
         action_noise: Optional[ActionNoise] = None,
-        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
-        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
         policy_delay: int = 2,
         target_policy_noise: float = 0.2,
         target_noise_clip: float = 0.5,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
@@ -228,9 +228,9 @@ class TD3(OffPolicyAlgorithm):
             progress_bar=progress_bar,
         )
 
-    def _excluded_save_params(self) -> List[str]:
+    def _excluded_save_params(self) -> list[str]:
         return super()._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]  # noqa: RUF005
 
-    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
+    def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
         state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
         return state_dicts, []
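
The constructor parameters in the hunks above map directly onto TD3's two signature tricks: `tau` drives the soft (Polyak) target update theta' <- tau * theta + (1 - tau) * theta', and `target_policy_noise`/`target_noise_clip` implement target policy smoothing. Two small sketches of the standard recipe (illustrative helpers, not the repository's own):

import torch as th

def polyak_update(params, target_params, tau: float = 0.005) -> None:
    # Soft target update: theta' <- tau * theta + (1 - tau) * theta'.
    with th.no_grad():
        for param, target_param in zip(params, target_params):
            target_param.mul_(1.0 - tau)
            target_param.add_(tau * param)

def smoothed_target_action(actor_target, next_obs, noise_std=0.2, noise_clip=0.5):
    # Target policy smoothing: clipped Gaussian noise on the target action.
    actions = actor_target(next_obs)
    noise = (th.randn_like(actions) * noise_std).clamp(-noise_clip, noise_clip)
    return (actions + noise).clamp(-1.0, 1.0)  # assumes actions scaled to [-1, 1]
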
@@ -1 +1 @@
-2.4.0
+2.5.0a0
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Optional
 
 import gymnasium as gym
 import numpy as np
@@ -72,7 +72,7 @@ class DummyDictEnv(gym.Env):
         terminated = truncated = False
         return self.observation_space.sample(), reward, terminated, truncated, {}
 
-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             self.observation_space.seed(seed)
         return self.observation_space.sample(), {}
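
The recurring `reset` signature in these test envs is the Gymnasium API: `seed` and `options` are keyword-only, `reset` returns an `(observation, info)` pair, and `step` returns a 5-tuple. A minimal conforming env for reference (illustrative, not one of the test fixtures):

from typing import Optional

import gymnasium as gym
import numpy as np
from gymnasium import spaces

class MinimalEnv(gym.Env):
    observation_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
    action_space = spaces.Discrete(2)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        super().reset(seed=seed)  # seeds self.np_random
        return self.observation_space.sample(), {}  # (obs, info)

    def step(self, action):
        # (obs, reward, terminated, truncated, info)
        return self.observation_space.sample(), 0.0, False, False, {}
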
@ -1,5 +1,4 @@
|
|||
from copy import deepcopy
|
||||
from typing import Tuple
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
|
@ -55,7 +54,7 @@ def test_squashed_gaussian(model_class):
|
|||
|
||||
|
||||
@pytest.fixture()
|
||||
def dummy_model_distribution_obs_and_actions() -> Tuple[A2C, np.ndarray, np.ndarray]:
|
||||
def dummy_model_distribution_obs_and_actions() -> tuple[A2C, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Fixture creating a Pendulum-v1 gym env, an A2C model and sampling 10 random observations and actions from the env
|
||||
:return: A2C model, random observations, random actions
|
||||
|
|
|
|||
|
|
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Optional
 
 import gymnasium as gym
 import numpy as np
@@ -135,7 +135,7 @@ def test_check_env_detailed_error(obs_tuple, method):
     class TestEnv(gym.Env):
         action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
 
-        def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+        def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
             return wrong_obs if method == "reset" else good_obs, {}
 
         def step(self, action):
@@ -162,7 +162,7 @@ class LimitedStepsTestEnv(gym.Env):
         self._steps_called = 0
         self._terminated = False
 
-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[int, Dict]:
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[int, dict]:
         super().reset(seed=seed)
 
         self._steps_called = 0
@@ -170,7 +170,7 @@ class LimitedStepsTestEnv(gym.Env):
 
         return 0, {}
 
-    def step(self, action: np.ndarray) -> Tuple[int, float, bool, bool, Dict[str, Any]]:
+    def step(self, action: np.ndarray) -> tuple[int, float, bool, bool, dict[str, Any]]:
         self._steps_called += 1
 
         assert not self._terminated
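
These hunks come from the env-checker tests; the checker itself is public API and is the quickest way to catch a malformed `reset`/`step` like the ones exercised here:

import gymnasium as gym

from stable_baselines3.common.env_checker import check_env

# Warns (or raises) when an env deviates from the Gymnasium API,
# e.g. a reset() that ignores `seed` or returns the wrong tuple.
check_env(gym.make("Pendulum-v1"), warn=True)
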
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Optional
 
 import gymnasium as gym
 import numpy as np
@@ -23,7 +23,7 @@ class CustomEnv(gym.Env):
     def seed(self, seed):
         self.observation_space.seed(seed)
 
-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             self.observation_space.seed(seed)
         self.n_steps = 0
@@ -53,7 +53,7 @@ class InfiniteHorizonEnv(gym.Env):
         self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
         self.current_state = 0
 
-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             super().reset(seed=seed)
 
@@ -2,8 +2,8 @@ import importlib.util
 import os
 import sys
 import time
+from collections.abc import Sequence
 from io import TextIOBase
-from typing import Sequence
 from unittest import mock
 
 import gymnasium as gym
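
This import swap is the other half of the PEP 585 cleanup: the `typing` aliases for ABCs such as `Sequence` are deprecated since Python 3.9, and `collections.abc.Sequence` is now subscriptable directly. Both forms annotate identically:

from collections.abc import Sequence

# collections.abc.Sequence is subscriptable on Python 3.9+ and also
# works with isinstance(); typing.Sequence is a deprecated alias.
def total(values: Sequence[float]) -> float:
    assert isinstance(values, Sequence)
    return sum(values)

print(total([1.0, 2.0, 3.0]))  # 6.0
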
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Optional
 
 import gymnasium as gym
 import numpy as np
@@ -24,7 +24,7 @@ class DummyEnv(gym.Env):
     def step(self, action):
         return self.observation_space.sample(), 0.0, False, False, {}
 
-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             super().reset(seed=seed)
         return self.observation_space.sample(), {}
@@ -1,5 +1,5 @@
 import os
-from typing import Dict, Union
+from typing import Union
 
 import pytest
 
@@ -24,7 +24,7 @@ class HParamCallback(BaseCallback):
     """
 
    def _on_training_start(self) -> None:
-        hparam_dict: Dict[str, Union[str, float]] = {
+        hparam_dict: dict[str, Union[str, float]] = {
             "algorithm": self.model.__class__.__name__,
             # Ignore type checking for gamma, see https://github.com/DLR-RM/stable-baselines3/pull/1194/files#r1035006458
             "gamma": self.model.gamma,  # type: ignore[attr-defined]
@@ -33,7 +33,7 @@ class HParamCallback(BaseCallback):
             hparam_dict["learning rate"] = self.model.learning_rate
         # define the metrics that will appear in the `HPARAMS` Tensorboard tab by referencing their tag
         # Tensorbaord will find & display metrics from the `SCALARS` tab
-        metric_dict: Dict[str, float] = {
+        metric_dict: dict[str, float] = {
             "rollout/ep_len_mean": 0,
         }
         self.logger.record(
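
For context, the `hparam_dict`/`metric_dict` pair built here feeds TensorBoard's HPARAMS tab through SB3's `HParam` output format; the recording call truncated by the hunk typically looks like the sketch below (adapted from the pattern in the SB3 docs; treat the exclude tuple as indicative):

from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam

class MinimalHParamCallback(BaseCallback):
    def _on_training_start(self) -> None:
        hparam = HParam(
            hparam_dict={"algorithm": self.model.__class__.__name__, "gamma": 0.99},
            metric_dict={"rollout/ep_len_mean": 0.0},
        )
        # Writers that cannot render HParam (stdout, csv, ...) are excluded.
        self.logger.record("hparams", hparam, exclude=("stdout", "log", "json", "csv"))

    def _on_step(self) -> bool:
        return True
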
@@ -4,7 +4,7 @@ import itertools
 import multiprocessing
 import os
 import warnings
-from typing import Dict, Optional
+from typing import Optional
 
 import gymnasium as gym
 import numpy as np
@@ -30,9 +30,9 @@ class CustomGymEnv(gym.Env):
         self.current_step = 0
         self.ep_length = 4
         self.render_mode = render_mode
-        self.current_options: Optional[Dict] = None
+        self.current_options: Optional[dict] = None
 
-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             self.seed(seed)
         self.current_step = 0
@@ -193,7 +193,7 @@ class StepEnv(gym.Env):
         self.max_steps = max_steps
         self.current_step = 0
 
-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         self.current_step = 0
         return np.array([self.current_step], dtype="int"), {}
 
@@ -1,5 +1,5 @@
 import operator
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import gymnasium as gym
 import numpy as np
@@ -22,7 +22,7 @@ ENV_ID = "Pendulum-v1"
 
 
 class DummyRewardEnv(gym.Env):
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
 
     def __init__(self, return_reward_idx=0):
         self.action_space = spaces.Discrete(2)
@@ -39,7 +39,7 @@ class DummyRewardEnv(gym.Env):
         truncated = self.t == len(self.returned_rewards)
         return np.array([returned_value]), returned_value, terminated, truncated, {}
 
-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             super().reset(seed=seed)
         self.t = 0
@@ -62,7 +62,7 @@ class DummyDictEnv(gym.Env):
         )
         self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)
 
-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             super().reset(seed=seed)
         return self.observation_space.sample(), {}
@@ -94,7 +94,7 @@ class DummyMixedDictEnv(gym.Env):
         )
         self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)
 
-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             super().reset(seed=seed)
         return self.observation_space.sample(), {}
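
The dummy envs in this last file exercise `VecNormalize`, which keeps running statistics over a vectorized env to normalize observations and rewards. A typical end-to-end use of the wrapper (standard SB3 API; defaults sketched from memory, so verify against the docs):

import gymnasium as gym

from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

venv = DummyVecEnv([lambda: gym.make("Pendulum-v1")])
venv = VecNormalize(venv, norm_obs=True, norm_reward=True, clip_obs=10.0)

obs = venv.reset()
actions = venv.action_space.sample()[None]  # add the n_envs batch dimension
obs, rewards, dones, infos = venv.step(actions)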