mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Mypy type checking (#1143)
* Install and configure mypy * Test if github CI uses setup.cfg for mypy * force color output * tab to space * Try to fix regex * follow_imports silent * use space as indentation * fix indentation setup.cfg * Show error code * Update doc * Udate changelog * Ignore mypy cache files from commit * Update gitlab CI * Add pytype and mypy entry in Makefile * Make mypy happy Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
8641b05b09
commit
abffa16198
10 changed files with 73 additions and 9 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -4,6 +4,7 @@
|
|||
*.py~
|
||||
*.bak
|
||||
.pytest_cache
|
||||
.mypy_cache
|
||||
.DS_Store
|
||||
.idea
|
||||
.vscode
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ image: stablebaselines/stable-baselines3-cpu:1.4.1a0
|
|||
|
||||
type-check:
|
||||
script:
|
||||
- pip install pytype --upgrade
|
||||
- pip install pytype mypy --upgrade
|
||||
- make type
|
||||
|
||||
pytest:
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ To run tests with `pytest`:
|
|||
make pytest
|
||||
```
|
||||
|
||||
Type checking with `pytype`:
|
||||
Type checking with `pytype` and `mypy`:
|
||||
|
||||
```
|
||||
make type
|
||||
|
|
@ -91,7 +91,7 @@ make check-codestyle
|
|||
make lint
|
||||
```
|
||||
|
||||
To run `pytype`, `format` and `lint` in one command:
|
||||
To run `type`, `format` and `lint` in one command:
|
||||
```
|
||||
make commit-checks
|
||||
```
|
||||
|
|
|
|||
7
Makefile
7
Makefile
|
|
@ -4,9 +4,14 @@ LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py
|
|||
pytest:
|
||||
./scripts/run_tests.sh
|
||||
|
||||
type:
|
||||
pytype:
|
||||
pytype -j auto
|
||||
|
||||
mypy:
|
||||
MYPY_FORCE_COLOR=1 mypy ${LINT_PATHS}
|
||||
|
||||
type: pytype mypy
|
||||
|
||||
lint:
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
# see https://lintlyci.github.io/Flake8Rules/
|
||||
|
|
|
|||
|
|
@ -198,9 +198,9 @@ pip install pytest pytest-cov
|
|||
make pytest
|
||||
```
|
||||
|
||||
You can also do a static type check using `pytype`:
|
||||
You can also do a static type check using `pytype` and `mypy`:
|
||||
```
|
||||
pip install pytype
|
||||
pip install pytype mypy
|
||||
make type
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
#
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
|
||||
|
|
@ -37,7 +38,7 @@ sys.path.insert(0, os.path.abspath(".."))
|
|||
|
||||
|
||||
class Mock(MagicMock):
|
||||
__subclasses__ = []
|
||||
__subclasses__ = [] # type: ignore
|
||||
|
||||
@classmethod
|
||||
def __getattr__(cls, name):
|
||||
|
|
@ -48,7 +49,7 @@ class Mock(MagicMock):
|
|||
# Note: because of that we cannot test examples using CI
|
||||
# 'torch', 'torch.nn', 'torch.nn.functional',
|
||||
# DO not mock modules for now, we will need to do that for read the docs later
|
||||
MOCK_MODULES = []
|
||||
MOCK_MODULES: List[str] = []
|
||||
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
|
||||
|
||||
# Read version from file
|
||||
|
|
@ -171,7 +172,7 @@ htmlhelp_basename = "StableBaselines3doc"
|
|||
|
||||
# -- Options for LaTeX output ------------------------------------------------
|
||||
|
||||
latex_elements = {
|
||||
latex_elements: Dict[str, str] = {
|
||||
# The paper size ('letterpaper' or 'a4paper').
|
||||
#
|
||||
# 'papersize': 'letterpaper',
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ Breaking Changes:
|
|||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Introduced mypy type checking
|
||||
|
||||
SB3-Contrib
|
||||
^^^^^^^^^^^
|
||||
|
|
|
|||
54
setup.cfg
54
setup.cfg
|
|
@ -25,6 +25,60 @@ markers =
|
|||
inputs = stable_baselines3
|
||||
disable = pyi-error
|
||||
|
||||
[mypy]
|
||||
ignore_missing_imports = True
|
||||
follow_imports = silent
|
||||
show_error_codes = True
|
||||
exclude = (?x)(
|
||||
stable_baselines3/a2c/a2c.py$
|
||||
| stable_baselines3/common/atari_wrappers.py$
|
||||
| stable_baselines3/common/base_class.py$
|
||||
| stable_baselines3/common/buffers.py$
|
||||
| stable_baselines3/common/callbacks.py$
|
||||
| stable_baselines3/common/distributions.py$
|
||||
| stable_baselines3/common/env_util.py$
|
||||
| stable_baselines3/common/envs/bit_flipping_env.py$
|
||||
| stable_baselines3/common/envs/identity_env.py$
|
||||
| stable_baselines3/common/envs/multi_input_envs.py$
|
||||
| stable_baselines3/common/logger.py$
|
||||
| stable_baselines3/common/monitor.py$
|
||||
| stable_baselines3/common/off_policy_algorithm.py$
|
||||
| stable_baselines3/common/on_policy_algorithm.py$
|
||||
| stable_baselines3/common/policies.py$
|
||||
| stable_baselines3/common/preprocessing.py$
|
||||
| stable_baselines3/common/save_util.py$
|
||||
| stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$
|
||||
| stable_baselines3/common/torch_layers.py$
|
||||
| stable_baselines3/common/type_aliases.py$
|
||||
| stable_baselines3/common/utils.py$
|
||||
| stable_baselines3/common/vec_env/__init__.py$
|
||||
| stable_baselines3/common/vec_env/base_vec_env.py$
|
||||
| stable_baselines3/common/vec_env/dummy_vec_env.py$
|
||||
| stable_baselines3/common/vec_env/stacked_observations.py$
|
||||
| stable_baselines3/common/vec_env/subproc_vec_env.py$
|
||||
| stable_baselines3/common/vec_env/util.py$
|
||||
| stable_baselines3/common/vec_env/vec_check_nan.py$
|
||||
| stable_baselines3/common/vec_env/vec_extract_dict_obs.py$
|
||||
| stable_baselines3/common/vec_env/vec_frame_stack.py$
|
||||
| stable_baselines3/common/vec_env/vec_monitor.py$
|
||||
| stable_baselines3/common/vec_env/vec_normalize.py$
|
||||
| stable_baselines3/common/vec_env/vec_transpose.py$
|
||||
| stable_baselines3/common/vec_env/vec_video_recorder.py$
|
||||
| stable_baselines3/dqn/dqn.py$
|
||||
| stable_baselines3/dqn/policies.py$
|
||||
| stable_baselines3/her/her_replay_buffer.py$
|
||||
| stable_baselines3/ppo/ppo.py$
|
||||
| stable_baselines3/sac/policies.py$
|
||||
| stable_baselines3/sac/sac.py$
|
||||
| stable_baselines3/td3/policies.py$
|
||||
| stable_baselines3/td3/td3.py$
|
||||
| tests/test_distributions.py$
|
||||
| tests/test_logger.py$
|
||||
| tests/test_tensorboard.py$
|
||||
| tests/test_train_eval_mode.py$
|
||||
| tests/test_vec_normalize.py$
|
||||
)
|
||||
|
||||
[flake8]
|
||||
ignore = W503,W504,E203,E231 # line breaks before and after binary operators
|
||||
# Ignore import not used when aliases are defined
|
||||
|
|
|
|||
1
setup.py
1
setup.py
|
|
@ -95,6 +95,7 @@ setup(
|
|||
"pytest-xdist",
|
||||
# Type check
|
||||
"pytype",
|
||||
"mypy",
|
||||
# Lint code
|
||||
"flake8>=3.8",
|
||||
# Find likely bugs
|
||||
|
|
|
|||
|
|
@ -149,6 +149,7 @@ def test_evaluate_policy(direct_policy: bool):
|
|||
def dummy_callback(locals_, _globals):
|
||||
locals_["model"].n_callback_calls += 1
|
||||
|
||||
assert model.policy is not None
|
||||
policy = model.policy if direct_policy else model
|
||||
policy.n_callback_calls = 0
|
||||
_, episode_lengths = evaluate_policy(
|
||||
|
|
|
|||
Loading…
Reference in a new issue