Switch to pyproject.toml and ruff (#1361)

* Switch to `pyproject.toml` and `ruff`

* Fix for Atari ROMs and mypy

* Switch order in CI, lint first
This commit is contained in:
Antonin RAFFIN 2023-03-11 22:15:26 +01:00 committed by GitHub
parent f0382a25bd
commit 10e83865ec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 92 additions and 89 deletions

View file

@ -14,6 +14,8 @@ jobs:
env:
TERM: xterm-256color
FORCE_COLOR: 1
ATARI_ROMS: ${{ secrets.ATARI_ROMS }}
# Skip CI if [ci skip] in the commit message
if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
runs-on: ubuntu-latest
@ -32,21 +34,28 @@ jobs:
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
# Install Atari Roms
pip install autorom
wget $ATARI_ROMS
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
pip install .[extra,tests,docs]
# Use headless version
pip install opencv-python-headless
- name: Lint with ruff
run: |
make lint
- name: Build the doc
run: |
make doc
- name: Type check
run: |
make type
- name: Check codestyle
run: |
make check-codestyle
- name: Lint with flake8
- name: Type check
run: |
make lint
make type
- name: Test with pytest
run: |
make pytest

View file

@ -85,7 +85,7 @@ Type checking with `pytype` and `mypy`:
make type
```
Codestyle check with `black`, `isort` and `flake8`:
Codestyle check with `black`, `isort` and `ruff`:
```
make check-codestyle

View file

@ -14,29 +14,22 @@ type: pytype mypy
lint:
# stop the build if there are Python syntax errors or undefined names
# see https://lintlyci.github.io/Flake8Rules/
flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings.
flake8 ${LINT_PATHS} --count --exit-zero --statistics
ruff:
# stop the build if there are Python syntax errors or undefined names
# see https://lintlyci.github.io/Flake8Rules/
# see https://www.flake8rules.com/
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
# exit-zero treats all errors as warnings.
ruff ${LINT_PATHS} --exit-zero --line-length 127
ruff ${LINT_PATHS} --exit-zero
format:
# Sort imports
isort ${LINT_PATHS}
# Reformat using black
black -l 127 ${LINT_PATHS}
black ${LINT_PATHS}
check-codestyle:
# Sort imports
isort --check ${LINT_PATHS}
# Reformat using black
black --check -l 127 ${LINT_PATHS}
black --check ${LINT_PATHS}
commit-checks: format type lint

View file

@ -206,9 +206,9 @@ pip install pytype mypy
make type
```
Codestyle check with `flake8`:
Codestyle check with `ruff`:
```
pip install flake8
pip install ruff
make lint
```

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.8.0a7 (WIP)
Release 1.8.0a8 (WIP)
--------------------------
@ -43,6 +43,9 @@ Others:
- Fixed ``stable_baselines3/common/monitor.py`` type hint
- Added tests for StackedObservations
- Removed Gitlab CI file
- Moved from ``setup.cg`` to ``pyproject.toml`` configuration file
- Switched from ``flake8`` to ``ruff``
- Upgraded AutoROM to latest version
Documentation:
^^^^^^^^^^^^^^

View file

@ -1,35 +1,39 @@
[metadata]
# This includes the license file in the wheel.
license_files = LICENSE
project_urls =
Code = https://github.com/DLR-RM/stable-baselines3
Documentation = https://stable-baselines3.readthedocs.io/
[tool.ruff]
# Same as Black.
line-length = 127
# Assume Python 3.7
target-version = "py37"
# TODO(antonin): activate "RUF" https://beta.ruff.rs/docs/rules/#ruff-specific-rules-ruf
select = ["E", "F", "B", "UP", "C90"]
ignore = []
[tool:pytest]
# Deterministic ordering for tests; useful for pytest-xdist.
env =
PYTHONHASHSEED=0
filterwarnings =
# Tensorboard warnings
ignore::DeprecationWarning:tensorboard
# Gym warnings
ignore:Parameters to load are deprecated.:DeprecationWarning
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
ignore::UserWarning:gym
ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning
ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning
markers =
expensive: marks tests as expensive (deselect with '-m "not expensive"')
[tool.ruff.per-file-ignores]
# Default implementation in abstract methods
"./stable_baselines3/common/callbacks.py"= ["B027"]
"./stable_baselines3/common/noise.py"= ["B027"]
[pytype]
inputs = stable_baselines3
disable = pyi-error
[mypy]
ignore_missing_imports = True
follow_imports = silent
show_error_codes = True
exclude = (?x)(
[tool.ruff.mccabe]
# Unlike Flake8, default to a complexity level of 10.
max-complexity = 15
[tool.black]
line-length = 127
[tool.isort]
profile = "black"
line_length = 127
src_paths = ["stable_baselines3"]
[tool.pytype]
inputs = ["stable_baselines3"]
disable = ["pyi-error"]
[tool.mypy]
ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
stable_baselines3/a2c/a2c.py$
| stable_baselines3/common/base_class.py$
| stable_baselines3/common/buffers.py$
@ -66,34 +70,24 @@ exclude = (?x)(
| stable_baselines3/td3/td3.py$
| tests/test_logger.py$
| tests/test_train_eval_mode.py$
)
)"""
[flake8]
# line breaks before and after binary operators
# ignore explicit stack level
ignore = W503,W504,E203,E231,B028
# Ignore import not used when aliases are defined
per-file-ignores =
# Default implementation in abstract methods
./stable_baselines3/common/callbacks.py:B027
./stable_baselines3/common/noise.py:B027
exclude =
# No need to traverse our git directory
.git,
# There's no value in checking cache directories
__pycache__,
# Don't check the doc
docs/
# This contains our built documentation
build,
# This contains builds of flake8 that we don't want to check
dist
*.egg-info
max-complexity = 15
# The GitHub editor is 127 chars wide
max-line-length = 127
[tool.pytest.ini_options]
# Deterministic ordering for tests; useful for pytest-xdist.
env = [
"PYTHONHASHSEED=0"
]
[isort]
profile = black
line_length = 127
src_paths = stable_baselines3
filterwarnings = [
# Tensorboard warnings
"ignore::DeprecationWarning:tensorboard",
# Gym warnings
"ignore:Parameters to load are deprecated.:DeprecationWarning",
"ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning",
"ignore::UserWarning:gym",
"ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning",
"ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning",
]
markers = [
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
]

View file

@ -86,7 +86,7 @@ setup(
"pandas",
# Plotting learning curves
"matplotlib",
# gym and flake8 not compatible with importlib-metadata>5.0
# gym not compatible with importlib-metadata>5.0
"importlib-metadata~=4.13",
],
extras_require={
@ -99,10 +99,8 @@ setup(
# Type check
"pytype",
"mypy",
# Lint code
"flake8>=3.8",
# Find likely bugs
"flake8-bugbear",
# Lint code (flake8 replacement)
"ruff",
# Sort imports
"isort>=5.0",
# Reformat
@ -126,7 +124,7 @@ setup(
"opencv-python",
# For atari games,
"ale-py==0.7.4",
"autorom[accept-rom-license]~=0.4.2",
"autorom[accept-rom-license]~=0.5.5",
"pillow",
# Tensorboard support
"tensorboard>=2.9.1",
@ -149,6 +147,12 @@ setup(
version=__version__,
python_requires=">=3.7",
# PyPI package information.
project_urls={
"Code": "https://github.com/DLR-RM/stable-baselines3",
"Documentation": "https://stable-baselines3.readthedocs.io/",
"SB3-Contrib": "https://github.com/Stable-Baselines-Team/stable-baselines3-contrib",
"RL-Zoo": "https://github.com/DLR-RM/rl-baselines3-zoo",
},
classifiers=[
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",

View file

@ -617,8 +617,8 @@ class BaseAlgorithm(ABC):
f"expected {objects_needing_update}, got {updated_objects}"
)
@classmethod # noqa: C901
def load(
@classmethod
def load( # noqa: C901
cls: Type[SelfBaseAlgorithm],
path: Union[str, pathlib.Path, io.BufferedIOBase],
env: Optional[GymEnv] = None,

View file

@ -161,7 +161,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36):
self.max_length = max_length
if isinstance(filename_or_file, str):
self.file = open(filename_or_file, "wt")
self.file = open(filename_or_file, "w")
self.own_file = True
else:
assert hasattr(filename_or_file, "write"), f"Expected file or str, got {filename_or_file}"
@ -283,7 +283,7 @@ class JSONOutputFormat(KVWriter):
"""
def __init__(self, filename: str):
self.file = open(filename, "wt")
self.file = open(filename, "w")
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
def cast_to_json_serializable(value: Any):

View file

@ -1 +1 @@
1.8.0a7
1.8.0a8