Mypy type checking (#1143)

* Install and configure mypy

* Test if github CI uses setup.cfg for mypy

* force color output

* tab to space

* Try to fix regex

* follow_imports silent

* use space as indentation

* fix indentation setup.cfg

* Show error code

* Update doc

* Udate changelog

* Ignore mypy cache files from commit

* Update gitlab CI

* Add pytype and mypy entry in Makefile

* Make mypy happy

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Quentin Gallouédec 2022-11-16 13:22:57 +01:00 committed by GitHub
parent 8641b05b09
commit abffa16198
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 73 additions and 9 deletions

1
.gitignore vendored
View file

@ -4,6 +4,7 @@
*.py~
*.bak
.pytest_cache
.mypy_cache
.DS_Store
.idea
.vscode

View file

@ -2,7 +2,7 @@ image: stablebaselines/stable-baselines3-cpu:1.4.1a0
type-check:
script:
- pip install pytype --upgrade
- pip install pytype mypy --upgrade
- make type
pytest:

View file

@ -78,7 +78,7 @@ To run tests with `pytest`:
make pytest
```
Type checking with `pytype`:
Type checking with `pytype` and `mypy`:
```
make type
@ -91,7 +91,7 @@ make check-codestyle
make lint
```
To run `pytype`, `format` and `lint` in one command:
To run `type`, `format` and `lint` in one command:
```
make commit-checks
```

View file

@ -4,9 +4,14 @@ LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py
pytest:
./scripts/run_tests.sh
type:
pytype:
pytype -j auto
mypy:
MYPY_FORCE_COLOR=1 mypy ${LINT_PATHS}
type: pytype mypy
lint:
# stop the build if there are Python syntax errors or undefined names
# see https://lintlyci.github.io/Flake8Rules/

View file

@ -198,9 +198,9 @@ pip install pytest pytest-cov
make pytest
```
You can also do a static type check using `pytype`:
You can also do a static type check using `pytype` and `mypy`:
```
pip install pytype
pip install pytype mypy
make type
```

View file

@ -13,6 +13,7 @@
#
import os
import sys
from typing import Dict, List
from unittest.mock import MagicMock
# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
@ -37,7 +38,7 @@ sys.path.insert(0, os.path.abspath(".."))
class Mock(MagicMock):
__subclasses__ = []
__subclasses__ = [] # type: ignore
@classmethod
def __getattr__(cls, name):
@ -48,7 +49,7 @@ class Mock(MagicMock):
# Note: because of that we cannot test examples using CI
# 'torch', 'torch.nn', 'torch.nn.functional',
# DO not mock modules for now, we will need to do that for read the docs later
MOCK_MODULES = []
MOCK_MODULES: List[str] = []
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
# Read version from file
@ -171,7 +172,7 @@ htmlhelp_basename = "StableBaselines3doc"
# -- Options for LaTeX output ------------------------------------------------
latex_elements = {
latex_elements: Dict[str, str] = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',

View file

@ -16,6 +16,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Introduced mypy type checking
SB3-Contrib
^^^^^^^^^^^

View file

@ -25,6 +25,60 @@ markers =
inputs = stable_baselines3
disable = pyi-error
[mypy]
ignore_missing_imports = True
follow_imports = silent
show_error_codes = True
exclude = (?x)(
stable_baselines3/a2c/a2c.py$
| stable_baselines3/common/atari_wrappers.py$
| stable_baselines3/common/base_class.py$
| stable_baselines3/common/buffers.py$
| stable_baselines3/common/callbacks.py$
| stable_baselines3/common/distributions.py$
| stable_baselines3/common/env_util.py$
| stable_baselines3/common/envs/bit_flipping_env.py$
| stable_baselines3/common/envs/identity_env.py$
| stable_baselines3/common/envs/multi_input_envs.py$
| stable_baselines3/common/logger.py$
| stable_baselines3/common/monitor.py$
| stable_baselines3/common/off_policy_algorithm.py$
| stable_baselines3/common/on_policy_algorithm.py$
| stable_baselines3/common/policies.py$
| stable_baselines3/common/preprocessing.py$
| stable_baselines3/common/save_util.py$
| stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$
| stable_baselines3/common/torch_layers.py$
| stable_baselines3/common/type_aliases.py$
| stable_baselines3/common/utils.py$
| stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/base_vec_env.py$
| stable_baselines3/common/vec_env/dummy_vec_env.py$
| stable_baselines3/common/vec_env/stacked_observations.py$
| stable_baselines3/common/vec_env/subproc_vec_env.py$
| stable_baselines3/common/vec_env/util.py$
| stable_baselines3/common/vec_env/vec_check_nan.py$
| stable_baselines3/common/vec_env/vec_extract_dict_obs.py$
| stable_baselines3/common/vec_env/vec_frame_stack.py$
| stable_baselines3/common/vec_env/vec_monitor.py$
| stable_baselines3/common/vec_env/vec_normalize.py$
| stable_baselines3/common/vec_env/vec_transpose.py$
| stable_baselines3/common/vec_env/vec_video_recorder.py$
| stable_baselines3/dqn/dqn.py$
| stable_baselines3/dqn/policies.py$
| stable_baselines3/her/her_replay_buffer.py$
| stable_baselines3/ppo/ppo.py$
| stable_baselines3/sac/policies.py$
| stable_baselines3/sac/sac.py$
| stable_baselines3/td3/policies.py$
| stable_baselines3/td3/td3.py$
| tests/test_distributions.py$
| tests/test_logger.py$
| tests/test_tensorboard.py$
| tests/test_train_eval_mode.py$
| tests/test_vec_normalize.py$
)
[flake8]
ignore = W503,W504,E203,E231 # line breaks before and after binary operators
# Ignore import not used when aliases are defined

View file

@ -95,6 +95,7 @@ setup(
"pytest-xdist",
# Type check
"pytype",
"mypy",
# Lint code
"flake8>=3.8",
# Find likely bugs

View file

@ -149,6 +149,7 @@ def test_evaluate_policy(direct_policy: bool):
def dummy_callback(locals_, _globals):
locals_["model"].n_callback_calls += 1
assert model.policy is not None
policy = model.policy if direct_policy else model
policy.n_callback_calls = 0
_, episode_lengths = evaluate_policy(