mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
Fix VecEnv type hints (#1736)
* Fix VecNormalize type hints * Fix VecEnv utils type annotations * Apply suggestions from code review Co-authored-by: M. Ernestus <maximilian@ernestus.de> * Remove PyType --------- Co-authored-by: M. Ernestus <maximilian@ernestus.de>
This commit is contained in:
parent
d671402c93
commit
b413f4c285
15 changed files with 63 additions and 67 deletions
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
|
|
@ -55,8 +55,6 @@ jobs:
|
|||
- name: Type check
|
||||
run: |
|
||||
make type
|
||||
# skip PyType, doesn't support 3.11 yet
|
||||
if: "!(matrix.python-version == '3.11')"
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make pytest
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ To run tests with `pytest`:
|
|||
make pytest
|
||||
```
|
||||
|
||||
Type checking with `pytype` and `mypy`:
|
||||
Type checking with `mypy`:
|
||||
|
||||
```
|
||||
make type
|
||||
|
|
|
|||
5
Makefile
5
Makefile
|
|
@ -4,9 +4,6 @@ LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py
|
|||
pytest:
|
||||
./scripts/run_tests.sh
|
||||
|
||||
pytype:
|
||||
pytype -j auto
|
||||
|
||||
mypy:
|
||||
mypy ${LINT_PATHS}
|
||||
|
||||
|
|
@ -16,7 +13,7 @@ missing-annotations:
|
|||
# missing docstrings
|
||||
# pylint -d R,C,W,E -e C0116 stable_baselines3 -j 4
|
||||
|
||||
type: pytype mypy
|
||||
type: mypy
|
||||
|
||||
lint:
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
|
|
|
|||
|
|
@ -61,8 +61,11 @@ Others:
|
|||
- Update dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``)
|
||||
- Fixed ``stable_baselines3/common/off_policy_algorithm.py`` type hints
|
||||
- Fixed ``stable_baselines3/common/distributions.py`` type hints
|
||||
- Fixed ``stable_baselines3/common/vec_env/vec_normalize.py`` type hints
|
||||
- Fixed ``stable_baselines3/common/vec_env/__init__.py`` type hints
|
||||
- Switched to PyTorch 2.1.0 in the CI (fixes type annotations)
|
||||
- Fixed ``stable_baselines3/common/policies.py`` type hints
|
||||
- Switched to ``mypy`` only for checking types
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -24,31 +24,12 @@ max-complexity = 15
|
|||
[tool.black]
|
||||
line-length = 127
|
||||
|
||||
[tool.pytype]
|
||||
inputs = ["stable_baselines3"]
|
||||
disable = ["pyi-error"]
|
||||
# Checked with mypy
|
||||
exclude = [
|
||||
"stable_baselines3/common/buffers.py",
|
||||
"stable_baselines3/common/base_class.py",
|
||||
"stable_baselines3/common/callbacks.py",
|
||||
"stable_baselines3/common/on_policy_algorithm.py",
|
||||
"stable_baselines3/common/vec_env/stacked_observations.py",
|
||||
"stable_baselines3/common/vec_env/subproc_vec_env.py",
|
||||
"stable_baselines3/common/vec_env/patch_gym.py",
|
||||
"stable_baselines3/common/off_policy_algorithm.py",
|
||||
"stable_baselines3/common/distributions.py",
|
||||
"stable_baselines3/common/policies.py",
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
ignore_missing_imports = true
|
||||
follow_imports = "silent"
|
||||
show_error_codes = true
|
||||
exclude = """(?x)(
|
||||
stable_baselines3/common/vec_env/__init__.py$
|
||||
| stable_baselines3/common/vec_env/vec_normalize.py$
|
||||
| tests/test_logger.py$
|
||||
tests/test_logger.py$
|
||||
| tests/test_train_eval_mode.py$
|
||||
)"""
|
||||
|
||||
|
|
|
|||
1
setup.py
1
setup.py
|
|
@ -118,7 +118,6 @@ setup(
|
|||
"pytest-env",
|
||||
"pytest-xdist",
|
||||
# Type check
|
||||
"pytype",
|
||||
"mypy",
|
||||
# Lint code and sort imports (flake8 and isort replacement)
|
||||
"ruff>=0.0.288",
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from gymnasium import spaces
|
|||
from stable_baselines3.common.type_aliases import AtariResetReturn, AtariStepReturn
|
||||
|
||||
try:
|
||||
import cv2 # pytype:disable=import-error
|
||||
import cv2
|
||||
|
||||
cv2.ocl.setUseOpenCL(False)
|
||||
except ImportError:
|
||||
|
|
|
|||
|
|
@ -193,9 +193,7 @@ class ResultsWriter:
|
|||
mode = "w" if override_existing else "a"
|
||||
# Prevent newline issue on Windows, see GH issue #692
|
||||
self.file_handler = open(filename, f"{mode}t", newline="\n")
|
||||
self.logger = csv.DictWriter(
|
||||
self.file_handler, fieldnames=("r", "l", "t", *extra_keys)
|
||||
) # pytype: disable=wrong-arg-types
|
||||
self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t", *extra_keys))
|
||||
if override_existing:
|
||||
self.file_handler.write(f"#{json.dumps(header)}\n")
|
||||
self.logger.writeheader()
|
||||
|
|
|
|||
|
|
@ -193,14 +193,14 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
device=self.device,
|
||||
n_envs=self.n_envs,
|
||||
optimize_memory_usage=self.optimize_memory_usage,
|
||||
**replay_buffer_kwargs, # pytype:disable=wrong-keyword-args
|
||||
**replay_buffer_kwargs,
|
||||
)
|
||||
|
||||
self.policy = self.policy_class( # pytype:disable=not-instantiable
|
||||
self.policy = self.policy_class(
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
self.lr_schedule,
|
||||
**self.policy_kwargs, # pytype:disable=not-instantiable
|
||||
**self.policy_kwargs,
|
||||
)
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ class BaseModel(nn.Module):
|
|||
saved_variables = th.load(path, map_location=device)
|
||||
|
||||
# Create policy object
|
||||
model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable
|
||||
model = cls(**saved_variables["data"])
|
||||
# Load weights
|
||||
model.load_state_dict(saved_variables["state_dict"])
|
||||
model.to(device)
|
||||
|
|
|
|||
|
|
@ -536,7 +536,7 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
|
|||
"Gymnasium": gym.__version__,
|
||||
}
|
||||
try:
|
||||
import gym as openai_gym # pytype: disable=import-error
|
||||
import gym as openai_gym
|
||||
|
||||
env_info.update({"OpenAI Gym": openai_gym.__version__})
|
||||
except ImportError:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import typing
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Type, Union
|
||||
from typing import Optional, Type, TypeVar
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
|
||||
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
|
|
@ -14,18 +13,16 @@ from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
|
|||
from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage
|
||||
from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
|
||||
|
||||
# Avoid circular import
|
||||
if typing.TYPE_CHECKING:
|
||||
from stable_baselines3.common.type_aliases import GymEnv
|
||||
VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper)
|
||||
|
||||
|
||||
def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]:
|
||||
def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]:
|
||||
"""
|
||||
Retrieve a ``VecEnvWrapper`` object by recursively searching.
|
||||
|
||||
:param env:
|
||||
:param vec_wrapper_class:
|
||||
:return:
|
||||
:param env: The ``VecEnv`` that is going to be unwrapped
|
||||
:param vec_wrapper_class: The desired ``VecEnvWrapper`` class.
|
||||
:return: The ``VecEnvWrapper`` object if the ``VecEnv`` is wrapped with the desired wrapper, None otherwise
|
||||
"""
|
||||
env_tmp = env
|
||||
while isinstance(env_tmp, VecEnvWrapper):
|
||||
|
|
@ -35,36 +32,50 @@ def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[Vec
|
|||
return None
|
||||
|
||||
|
||||
def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]:
|
||||
def unwrap_vec_normalize(env: VecEnv) -> Optional[VecNormalize]:
|
||||
"""
|
||||
:param env:
|
||||
:return:
|
||||
Retrieve a ``VecNormalize`` object by recursively searching.
|
||||
|
||||
:param env: The VecEnv that is going to be unwrapped
|
||||
:return: The ``VecNormalize`` object if the ``VecEnv`` is wrapped with ``VecNormalize``, None otherwise
|
||||
"""
|
||||
return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type
|
||||
return unwrap_vec_wrapper(env, VecNormalize)
|
||||
|
||||
|
||||
def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
|
||||
def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
|
||||
"""
|
||||
Check if an environment is already wrapped by a given ``VecEnvWrapper``.
|
||||
Check if an environment is already wrapped in a given ``VecEnvWrapper``.
|
||||
|
||||
:param env:
|
||||
:param vec_wrapper_class:
|
||||
:return:
|
||||
:param env: The VecEnv that is going to be checked
|
||||
:param vec_wrapper_class: The desired ``VecEnvWrapper`` class.
|
||||
:return: True if the ``VecEnv`` is wrapped with the desired wrapper, False otherwise
|
||||
"""
|
||||
return unwrap_vec_wrapper(env, vec_wrapper_class) is not None
|
||||
|
||||
|
||||
# Define here to avoid circular import
|
||||
def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
|
||||
def sync_envs_normalization(env: VecEnv, eval_env: VecEnv) -> None:
|
||||
"""
|
||||
Sync eval env and train env when using VecNormalize
|
||||
Synchronize the normalization statistics of an eval environment and train environment
|
||||
when they are both wrapped in a ``VecNormalize`` wrapper.
|
||||
|
||||
:param env:
|
||||
:param eval_env:
|
||||
:param env: Training env
|
||||
:param eval_env: Environment used for evaluation.
|
||||
"""
|
||||
env_tmp, eval_env_tmp = env, eval_env
|
||||
while isinstance(env_tmp, VecEnvWrapper):
|
||||
assert isinstance(eval_env_tmp, VecEnvWrapper), (
|
||||
"Error while synchronizing normalization stats: expected the eval env to be "
|
||||
f"a VecEnvWrapper but got {eval_env_tmp} instead. "
|
||||
"This is probably due to the training env not being wrapped the same way as the evaluation env. "
|
||||
f"Training env type: {env_tmp}."
|
||||
)
|
||||
if isinstance(env_tmp, VecNormalize):
|
||||
assert isinstance(eval_env_tmp, VecNormalize), (
|
||||
"Error while synchronizing normalization stats: expected the eval env to be "
|
||||
f"a VecNormalize but got {eval_env_tmp} instead. "
|
||||
"This is probably due to the training env not being wrapped the same way as the evaluation env. "
|
||||
f"Training env type: {env_tmp}."
|
||||
)
|
||||
# Only synchronize if observation normalization exists
|
||||
if hasattr(env_tmp, "obs_rms"):
|
||||
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
|
||||
|
|
|
|||
|
|
@ -258,7 +258,7 @@ class VecEnv(ABC):
|
|||
|
||||
if mode == "human":
|
||||
# Display it using OpenCV
|
||||
import cv2 # pytype:disable=import-error
|
||||
import cv2
|
||||
|
||||
cv2.imshow("vecenv", bigimg[:, :, ::-1])
|
||||
cv2.waitKey(1)
|
||||
|
|
|
|||
|
|
@ -38,6 +38,6 @@ class VecFrameStack(VecEnvWrapper):
|
|||
"""
|
||||
Reset all environments
|
||||
"""
|
||||
observation = self.venv.reset() # pytype:disable=annotation-type-mismatch
|
||||
observation = self.venv.reset()
|
||||
observation = self.stacked_obs.reset(observation) # type: ignore[arg-type]
|
||||
return observation
|
||||
|
|
|
|||
|
|
@ -29,6 +29,9 @@ class VecNormalize(VecEnvWrapper):
|
|||
If not specified, all keys will be normalized.
|
||||
"""
|
||||
|
||||
obs_spaces: Dict[str, spaces.Space]
|
||||
old_obs: Union[np.ndarray, Dict[str, np.ndarray]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
venv: VecEnv,
|
||||
|
|
@ -47,11 +50,12 @@ class VecNormalize(VecEnvWrapper):
|
|||
self.norm_obs_keys = norm_obs_keys
|
||||
# Check observation spaces
|
||||
if self.norm_obs:
|
||||
# Note: mypy doesn't take into account the sanity checks, which lead to several type: ignore...
|
||||
self._sanity_checks()
|
||||
|
||||
if isinstance(self.observation_space, spaces.Dict):
|
||||
self.obs_spaces = self.observation_space.spaces
|
||||
self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys}
|
||||
self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys} # type: ignore[arg-type, union-attr]
|
||||
# Update observation space when using image
|
||||
# See explanation below and GH #1214
|
||||
for key in self.obs_rms.keys():
|
||||
|
|
@ -64,8 +68,7 @@ class VecNormalize(VecEnvWrapper):
|
|||
)
|
||||
|
||||
else:
|
||||
self.obs_spaces = None
|
||||
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
|
||||
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) # type: ignore[assignment, arg-type]
|
||||
# Update observation space when using image
|
||||
# See GH #1214
|
||||
# This is to raise proper error when
|
||||
|
|
@ -92,7 +95,6 @@ class VecNormalize(VecEnvWrapper):
|
|||
self.training = training
|
||||
self.norm_obs = norm_obs
|
||||
self.norm_reward = norm_reward
|
||||
self.old_obs = np.array([])
|
||||
self.old_reward = np.array([])
|
||||
|
||||
def _sanity_checks(self) -> None:
|
||||
|
|
@ -148,7 +150,7 @@ class VecNormalize(VecEnvWrapper):
|
|||
state["norm_obs_keys"] = list(state["observation_space"].spaces.keys())
|
||||
self.__dict__.update(state)
|
||||
assert "venv" not in state
|
||||
self.venv = None
|
||||
self.venv = None # type: ignore[assignment]
|
||||
|
||||
def set_venv(self, venv: VecEnv) -> None:
|
||||
"""
|
||||
|
|
@ -177,6 +179,7 @@ class VecNormalize(VecEnvWrapper):
|
|||
where ``dones`` is a boolean vector indicating whether each element is new.
|
||||
"""
|
||||
obs, rewards, dones, infos = self.venv.step_wait()
|
||||
assert isinstance(obs, (np.ndarray, dict)) # for mypy
|
||||
self.old_obs = obs
|
||||
self.old_reward = rewards
|
||||
|
||||
|
|
@ -235,10 +238,12 @@ class VecNormalize(VecEnvWrapper):
|
|||
obs_ = deepcopy(obs)
|
||||
if self.norm_obs:
|
||||
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
|
||||
assert self.norm_obs_keys is not None
|
||||
# Only normalize the specified keys
|
||||
for key in self.norm_obs_keys:
|
||||
obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32)
|
||||
else:
|
||||
assert isinstance(self.obs_rms, RunningMeanStd)
|
||||
obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32)
|
||||
return obs_
|
||||
|
||||
|
|
@ -256,9 +261,11 @@ class VecNormalize(VecEnvWrapper):
|
|||
obs_ = deepcopy(obs)
|
||||
if self.norm_obs:
|
||||
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
|
||||
assert self.norm_obs_keys is not None
|
||||
for key in self.norm_obs_keys:
|
||||
obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key])
|
||||
else:
|
||||
assert isinstance(self.obs_rms, RunningMeanStd)
|
||||
obs_ = self._unnormalize_obs(obs, self.obs_rms)
|
||||
return obs_
|
||||
|
||||
|
|
@ -286,6 +293,7 @@ class VecNormalize(VecEnvWrapper):
|
|||
:return: first observation of the episode
|
||||
"""
|
||||
obs = self.venv.reset()
|
||||
assert isinstance(obs, (np.ndarray, dict))
|
||||
self.old_obs = obs
|
||||
self.returns = np.zeros(self.num_envs)
|
||||
if self.training and self.norm_obs:
|
||||
|
|
@ -293,6 +301,7 @@ class VecNormalize(VecEnvWrapper):
|
|||
for key in self.obs_rms.keys():
|
||||
self.obs_rms[key].update(obs[key])
|
||||
else:
|
||||
assert isinstance(self.obs_rms, RunningMeanStd)
|
||||
self.obs_rms.update(obs)
|
||||
return self.normalize_obs(obs)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue