Fix VecEnv type hints (#1736)

* Fix VecNormalize type hints

* Fix VecEnv utils type annotations

* Apply suggestions from code review

Co-authored-by: M. Ernestus <maximilian@ernestus.de>

* Remove PyType

---------

Co-authored-by: M. Ernestus <maximilian@ernestus.de>
This commit is contained in:
Antonin RAFFIN 2023-11-08 09:46:40 +01:00 committed by GitHub
parent d671402c93
commit b413f4c285
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 63 additions and 67 deletions

View file

@ -55,8 +55,6 @@ jobs:
- name: Type check
run: |
make type
# skip PyType, doesn't support 3.11 yet
if: "!(matrix.python-version == '3.11')"
- name: Test with pytest
run: |
make pytest

View file

@ -79,7 +79,7 @@ To run tests with `pytest`:
make pytest
```
Type checking with `pytype` and `mypy`:
Type checking with `mypy`:
```
make type

View file

@ -4,9 +4,6 @@ LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py
pytest:
./scripts/run_tests.sh
pytype:
pytype -j auto
mypy:
mypy ${LINT_PATHS}
@ -16,7 +13,7 @@ missing-annotations:
# missing docstrings
# pylint -d R,C,W,E -e C0116 stable_baselines3 -j 4
type: pytype mypy
type: mypy
lint:
# stop the build if there are Python syntax errors or undefined names

View file

@ -61,8 +61,11 @@ Others:
- Update dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``)
- Fixed ``stable_baselines3/common/off_policy_algorithm.py`` type hints
- Fixed ``stable_baselines3/common/distributions.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/vec_normalize.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/__init__.py`` type hints
- Switched to PyTorch 2.1.0 in the CI (fixes type annotations)
- Fixed ``stable_baselines3/common/policies.py`` type hints
- Switched to ``mypy`` only for checking types
Documentation:
^^^^^^^^^^^^^^

View file

@ -24,31 +24,12 @@ max-complexity = 15
[tool.black]
line-length = 127
[tool.pytype]
inputs = ["stable_baselines3"]
disable = ["pyi-error"]
# Checked with mypy
exclude = [
"stable_baselines3/common/buffers.py",
"stable_baselines3/common/base_class.py",
"stable_baselines3/common/callbacks.py",
"stable_baselines3/common/on_policy_algorithm.py",
"stable_baselines3/common/vec_env/stacked_observations.py",
"stable_baselines3/common/vec_env/subproc_vec_env.py",
"stable_baselines3/common/vec_env/patch_gym.py",
"stable_baselines3/common/off_policy_algorithm.py",
"stable_baselines3/common/distributions.py",
"stable_baselines3/common/policies.py",
]
[tool.mypy]
ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/vec_normalize.py$
| tests/test_logger.py$
tests/test_logger.py$
| tests/test_train_eval_mode.py$
)"""

View file

@ -118,7 +118,6 @@ setup(
"pytest-env",
"pytest-xdist",
# Type check
"pytype",
"mypy",
# Lint code and sort imports (flake8 and isort replacement)
"ruff>=0.0.288",

View file

@ -7,7 +7,7 @@ from gymnasium import spaces
from stable_baselines3.common.type_aliases import AtariResetReturn, AtariStepReturn
try:
import cv2 # pytype:disable=import-error
import cv2
cv2.ocl.setUseOpenCL(False)
except ImportError:

View file

@ -193,9 +193,7 @@ class ResultsWriter:
mode = "w" if override_existing else "a"
# Prevent newline issue on Windows, see GH issue #692
self.file_handler = open(filename, f"{mode}t", newline="\n")
self.logger = csv.DictWriter(
self.file_handler, fieldnames=("r", "l", "t", *extra_keys)
) # pytype: disable=wrong-arg-types
self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t", *extra_keys))
if override_existing:
self.file_handler.write(f"#{json.dumps(header)}\n")
self.logger.writeheader()

View file

@ -193,14 +193,14 @@ class OffPolicyAlgorithm(BaseAlgorithm):
device=self.device,
n_envs=self.n_envs,
optimize_memory_usage=self.optimize_memory_usage,
**replay_buffer_kwargs, # pytype:disable=wrong-keyword-args
**replay_buffer_kwargs,
)
self.policy = self.policy_class( # pytype:disable=not-instantiable
self.policy = self.policy_class(
self.observation_space,
self.action_space,
self.lr_schedule,
**self.policy_kwargs, # pytype:disable=not-instantiable
**self.policy_kwargs,
)
self.policy = self.policy.to(self.device)

View file

@ -176,7 +176,7 @@ class BaseModel(nn.Module):
saved_variables = th.load(path, map_location=device)
# Create policy object
model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable
model = cls(**saved_variables["data"])
# Load weights
model.load_state_dict(saved_variables["state_dict"])
model.to(device)

View file

@ -536,7 +536,7 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
"Gymnasium": gym.__version__,
}
try:
import gym as openai_gym # pytype: disable=import-error
import gym as openai_gym
env_info.update({"OpenAI Gym": openai_gym.__version__})
except ImportError:

View file

@ -1,6 +1,5 @@
import typing
from copy import deepcopy
from typing import Optional, Type, Union
from typing import Optional, Type, TypeVar
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
@ -14,18 +13,16 @@ from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage
from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
# Avoid circular import
if typing.TYPE_CHECKING:
from stable_baselines3.common.type_aliases import GymEnv
VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper)
def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]:
def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]:
"""
Retrieve a ``VecEnvWrapper`` object by recursively searching.
:param env:
:param vec_wrapper_class:
:return:
:param env: The ``VecEnv`` that is going to be unwrapped
:param vec_wrapper_class: The desired ``VecEnvWrapper`` class.
:return: The ``VecEnvWrapper`` object if the ``VecEnv`` is wrapped with the desired wrapper, None otherwise
"""
env_tmp = env
while isinstance(env_tmp, VecEnvWrapper):
@ -35,36 +32,50 @@ def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[Vec
return None
def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]:
def unwrap_vec_normalize(env: VecEnv) -> Optional[VecNormalize]:
"""
:param env:
:return:
Retrieve a ``VecNormalize`` object by recursively searching.
:param env: The VecEnv that is going to be unwrapped
:return: The ``VecNormalize`` object if the ``VecEnv`` is wrapped with ``VecNormalize``, None otherwise
"""
return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type
return unwrap_vec_wrapper(env, VecNormalize)
def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
"""
Check if an environment is already wrapped by a given ``VecEnvWrapper``.
Check if an environment is already wrapped in a given ``VecEnvWrapper``.
:param env:
:param vec_wrapper_class:
:return:
:param env: The VecEnv that is going to be checked
:param vec_wrapper_class: The desired ``VecEnvWrapper`` class.
:return: True if the ``VecEnv`` is wrapped with the desired wrapper, False otherwise
"""
return unwrap_vec_wrapper(env, vec_wrapper_class) is not None
# Define here to avoid circular import
def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
def sync_envs_normalization(env: VecEnv, eval_env: VecEnv) -> None:
"""
Sync eval env and train env when using VecNormalize
Synchronize the normalization statistics of an eval environment and train environment
when they are both wrapped in a ``VecNormalize`` wrapper.
:param env:
:param eval_env:
:param env: Training env
:param eval_env: Environment used for evaluation.
"""
env_tmp, eval_env_tmp = env, eval_env
while isinstance(env_tmp, VecEnvWrapper):
assert isinstance(eval_env_tmp, VecEnvWrapper), (
"Error while synchronizing normalization stats: expected the eval env to be "
f"a VecEnvWrapper but got {eval_env_tmp} instead. "
"This is probably due to the training env not being wrapped the same way as the evaluation env. "
f"Training env type: {env_tmp}."
)
if isinstance(env_tmp, VecNormalize):
assert isinstance(eval_env_tmp, VecNormalize), (
"Error while synchronizing normalization stats: expected the eval env to be "
f"a VecNormalize but got {eval_env_tmp} instead. "
"This is probably due to the training env not being wrapped the same way as the evaluation env. "
f"Training env type: {env_tmp}."
)
# Only synchronize if observation normalization exists
if hasattr(env_tmp, "obs_rms"):
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)

View file

@ -258,7 +258,7 @@ class VecEnv(ABC):
if mode == "human":
# Display it using OpenCV
import cv2 # pytype:disable=import-error
import cv2
cv2.imshow("vecenv", bigimg[:, :, ::-1])
cv2.waitKey(1)

View file

@ -38,6 +38,6 @@ class VecFrameStack(VecEnvWrapper):
"""
Reset all environments
"""
observation = self.venv.reset() # pytype:disable=annotation-type-mismatch
observation = self.venv.reset()
observation = self.stacked_obs.reset(observation) # type: ignore[arg-type]
return observation

View file

@ -29,6 +29,9 @@ class VecNormalize(VecEnvWrapper):
If not specified, all keys will be normalized.
"""
obs_spaces: Dict[str, spaces.Space]
old_obs: Union[np.ndarray, Dict[str, np.ndarray]]
def __init__(
self,
venv: VecEnv,
@ -47,11 +50,12 @@ class VecNormalize(VecEnvWrapper):
self.norm_obs_keys = norm_obs_keys
# Check observation spaces
if self.norm_obs:
# Note: mypy doesn't take into account the sanity checks, which lead to several type: ignore...
self._sanity_checks()
if isinstance(self.observation_space, spaces.Dict):
self.obs_spaces = self.observation_space.spaces
self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys}
self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys} # type: ignore[arg-type, union-attr]
# Update observation space when using image
# See explanation below and GH #1214
for key in self.obs_rms.keys():
@ -64,8 +68,7 @@ class VecNormalize(VecEnvWrapper):
)
else:
self.obs_spaces = None
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) # type: ignore[assignment, arg-type]
# Update observation space when using image
# See GH #1214
# This is to raise proper error when
@ -92,7 +95,6 @@ class VecNormalize(VecEnvWrapper):
self.training = training
self.norm_obs = norm_obs
self.norm_reward = norm_reward
self.old_obs = np.array([])
self.old_reward = np.array([])
def _sanity_checks(self) -> None:
@ -148,7 +150,7 @@ class VecNormalize(VecEnvWrapper):
state["norm_obs_keys"] = list(state["observation_space"].spaces.keys())
self.__dict__.update(state)
assert "venv" not in state
self.venv = None
self.venv = None # type: ignore[assignment]
def set_venv(self, venv: VecEnv) -> None:
"""
@ -177,6 +179,7 @@ class VecNormalize(VecEnvWrapper):
where ``dones`` is a boolean vector indicating whether each element is new.
"""
obs, rewards, dones, infos = self.venv.step_wait()
assert isinstance(obs, (np.ndarray, dict)) # for mypy
self.old_obs = obs
self.old_reward = rewards
@ -235,10 +238,12 @@ class VecNormalize(VecEnvWrapper):
obs_ = deepcopy(obs)
if self.norm_obs:
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
assert self.norm_obs_keys is not None
# Only normalize the specified keys
for key in self.norm_obs_keys:
obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32)
else:
assert isinstance(self.obs_rms, RunningMeanStd)
obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32)
return obs_
@ -256,9 +261,11 @@ class VecNormalize(VecEnvWrapper):
obs_ = deepcopy(obs)
if self.norm_obs:
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
assert self.norm_obs_keys is not None
for key in self.norm_obs_keys:
obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key])
else:
assert isinstance(self.obs_rms, RunningMeanStd)
obs_ = self._unnormalize_obs(obs, self.obs_rms)
return obs_
@ -286,6 +293,7 @@ class VecNormalize(VecEnvWrapper):
:return: first observation of the episode
"""
obs = self.venv.reset()
assert isinstance(obs, (np.ndarray, dict))
self.old_obs = obs
self.returns = np.zeros(self.num_envs)
if self.training and self.norm_obs:
@ -293,6 +301,7 @@ class VecNormalize(VecEnvWrapper):
for key in self.obs_rms.keys():
self.obs_rms[key].update(obs[key])
else:
assert isinstance(self.obs_rms, RunningMeanStd)
self.obs_rms.update(obs)
return self.normalize_obs(obs)