From 10e83865ecb3cdcf0e203ee5e1e807d15539ce91 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 11 Mar 2023 22:15:26 +0100 Subject: [PATCH] Switch to `pyproject.toml` and `ruff` (#1361) * Switch to `pyproject.toml` and `ruff` * Fix for Atari ROMs and mypy * Switch order in CI, lint first --- .github/workflows/ci.yml | 19 +++-- CONTRIBUTING.md | 2 +- Makefile | 15 +--- README.md | 4 +- docs/misc/changelog.rst | 5 +- setup.cfg => pyproject.toml | 110 ++++++++++++------------- setup.py | 16 ++-- stable_baselines3/common/base_class.py | 4 +- stable_baselines3/common/logger.py | 4 +- stable_baselines3/version.txt | 2 +- 10 files changed, 92 insertions(+), 89 deletions(-) rename setup.cfg => pyproject.toml (52%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a834fdc..fc47e9d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,6 +14,8 @@ jobs: env: TERM: xterm-256color FORCE_COLOR: 1 + ATARI_ROMS: ${{ secrets.ATARI_ROMS }} + # Skip CI if [ci skip] in the commit message if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')" runs-on: ubuntu-latest @@ -32,21 +34,28 @@ jobs: python -m pip install --upgrade pip # cpu version of pytorch pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html + + # Install Atari Roms + pip install autorom + wget $ATARI_ROMS + base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz + AutoROM --accept-license --source-file Roms.tar.gz + pip install .[extra,tests,docs] # Use headless version pip install opencv-python-headless + - name: Lint with ruff + run: | + make lint - name: Build the doc run: | make doc - - name: Type check - run: | - make type - name: Check codestyle run: | make check-codestyle - - name: Lint with flake8 + - name: Type check run: | - make lint + make type - name: Test with pytest run: | make pytest diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index eb1d08f..62f47ce 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -85,7 +85,7 @@ Type checking with `pytype` and `mypy`: make type ``` -Codestyle check with `black`, `isort` and `flake8`: +Codestyle check with `black`, `isort` and `ruff`: ``` make check-codestyle diff --git a/Makefile b/Makefile index 6351162..29ac5e7 100644 --- a/Makefile +++ b/Makefile @@ -14,29 +14,22 @@ type: pytype mypy lint: # stop the build if there are Python syntax errors or undefined names - # see https://lintlyci.github.io/Flake8Rules/ - flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. - flake8 ${LINT_PATHS} --count --exit-zero --statistics - -ruff: - # stop the build if there are Python syntax errors or undefined names - # see https://lintlyci.github.io/Flake8Rules/ + # see https://www.flake8rules.com/ ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source # exit-zero treats all errors as warnings. - ruff ${LINT_PATHS} --exit-zero --line-length 127 + ruff ${LINT_PATHS} --exit-zero format: # Sort imports isort ${LINT_PATHS} # Reformat using black - black -l 127 ${LINT_PATHS} + black ${LINT_PATHS} check-codestyle: # Sort imports isort --check ${LINT_PATHS} # Reformat using black - black --check -l 127 ${LINT_PATHS} + black --check ${LINT_PATHS} commit-checks: format type lint diff --git a/README.md b/README.md index b487bd6..a77dad4 100644 --- a/README.md +++ b/README.md @@ -206,9 +206,9 @@ pip install pytype mypy make type ``` -Codestyle check with `flake8`: +Codestyle check with `ruff`: ``` -pip install flake8 +pip install ruff make lint ``` diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index bf2e60a..c6f8211 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.8.0a7 (WIP) +Release 1.8.0a8 (WIP) -------------------------- @@ -43,6 +43,9 @@ Others: - Fixed ``stable_baselines3/common/monitor.py`` type hint - Added tests for StackedObservations - Removed Gitlab CI file +- Moved from ``setup.cg`` to ``pyproject.toml`` configuration file +- Switched from ``flake8`` to ``ruff`` +- Upgraded AutoROM to latest version Documentation: ^^^^^^^^^^^^^^ diff --git a/setup.cfg b/pyproject.toml similarity index 52% rename from setup.cfg rename to pyproject.toml index 11bb464..461e727 100644 --- a/setup.cfg +++ b/pyproject.toml @@ -1,35 +1,39 @@ -[metadata] -# This includes the license file in the wheel. -license_files = LICENSE -project_urls = - Code = https://github.com/DLR-RM/stable-baselines3 - Documentation = https://stable-baselines3.readthedocs.io/ +[tool.ruff] +# Same as Black. +line-length = 127 +# Assume Python 3.7 +target-version = "py37" +# TODO(antonin): activate "RUF" https://beta.ruff.rs/docs/rules/#ruff-specific-rules-ruf +select = ["E", "F", "B", "UP", "C90"] +ignore = [] -[tool:pytest] -# Deterministic ordering for tests; useful for pytest-xdist. -env = - PYTHONHASHSEED=0 -filterwarnings = - # Tensorboard warnings - ignore::DeprecationWarning:tensorboard - # Gym warnings - ignore:Parameters to load are deprecated.:DeprecationWarning - ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning - ignore::UserWarning:gym - ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning - ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning -markers = - expensive: marks tests as expensive (deselect with '-m "not expensive"') +[tool.ruff.per-file-ignores] +# Default implementation in abstract methods +"./stable_baselines3/common/callbacks.py"= ["B027"] +"./stable_baselines3/common/noise.py"= ["B027"] -[pytype] -inputs = stable_baselines3 -disable = pyi-error -[mypy] -ignore_missing_imports = True -follow_imports = silent -show_error_codes = True -exclude = (?x)( +[tool.ruff.mccabe] +# Unlike Flake8, default to a complexity level of 10. +max-complexity = 15 + +[tool.black] +line-length = 127 + +[tool.isort] +profile = "black" +line_length = 127 +src_paths = ["stable_baselines3"] + +[tool.pytype] +inputs = ["stable_baselines3"] +disable = ["pyi-error"] + +[tool.mypy] +ignore_missing_imports = true +follow_imports = "silent" +show_error_codes = true +exclude = """(?x)( stable_baselines3/a2c/a2c.py$ | stable_baselines3/common/base_class.py$ | stable_baselines3/common/buffers.py$ @@ -66,34 +70,24 @@ exclude = (?x)( | stable_baselines3/td3/td3.py$ | tests/test_logger.py$ | tests/test_train_eval_mode.py$ - ) + )""" -[flake8] -# line breaks before and after binary operators -# ignore explicit stack level -ignore = W503,W504,E203,E231,B028 -# Ignore import not used when aliases are defined -per-file-ignores = - # Default implementation in abstract methods - ./stable_baselines3/common/callbacks.py:B027 - ./stable_baselines3/common/noise.py:B027 -exclude = - # No need to traverse our git directory - .git, - # There's no value in checking cache directories - __pycache__, - # Don't check the doc - docs/ - # This contains our built documentation - build, - # This contains builds of flake8 that we don't want to check - dist - *.egg-info -max-complexity = 15 -# The GitHub editor is 127 chars wide -max-line-length = 127 +[tool.pytest.ini_options] +# Deterministic ordering for tests; useful for pytest-xdist. +env = [ + "PYTHONHASHSEED=0" +] -[isort] -profile = black -line_length = 127 -src_paths = stable_baselines3 +filterwarnings = [ + # Tensorboard warnings + "ignore::DeprecationWarning:tensorboard", + # Gym warnings + "ignore:Parameters to load are deprecated.:DeprecationWarning", + "ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning", + "ignore::UserWarning:gym", + "ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning", + "ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning", +] +markers = [ + "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')" +] diff --git a/setup.py b/setup.py index 612b2be..dcff0a2 100644 --- a/setup.py +++ b/setup.py @@ -86,7 +86,7 @@ setup( "pandas", # Plotting learning curves "matplotlib", - # gym and flake8 not compatible with importlib-metadata>5.0 + # gym not compatible with importlib-metadata>5.0 "importlib-metadata~=4.13", ], extras_require={ @@ -99,10 +99,8 @@ setup( # Type check "pytype", "mypy", - # Lint code - "flake8>=3.8", - # Find likely bugs - "flake8-bugbear", + # Lint code (flake8 replacement) + "ruff", # Sort imports "isort>=5.0", # Reformat @@ -126,7 +124,7 @@ setup( "opencv-python", # For atari games, "ale-py==0.7.4", - "autorom[accept-rom-license]~=0.4.2", + "autorom[accept-rom-license]~=0.5.5", "pillow", # Tensorboard support "tensorboard>=2.9.1", @@ -149,6 +147,12 @@ setup( version=__version__, python_requires=">=3.7", # PyPI package information. + project_urls={ + "Code": "https://github.com/DLR-RM/stable-baselines3", + "Documentation": "https://stable-baselines3.readthedocs.io/", + "SB3-Contrib": "https://github.com/Stable-Baselines-Team/stable-baselines3-contrib", + "RL-Zoo": "https://github.com/DLR-RM/rl-baselines3-zoo", + }, classifiers=[ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.7", diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 17be67a..0ab03d8 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -617,8 +617,8 @@ class BaseAlgorithm(ABC): f"expected {objects_needing_update}, got {updated_objects}" ) - @classmethod # noqa: C901 - def load( + @classmethod + def load( # noqa: C901 cls: Type[SelfBaseAlgorithm], path: Union[str, pathlib.Path, io.BufferedIOBase], env: Optional[GymEnv] = None, diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index a8aa766..c379388 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -161,7 +161,7 @@ class HumanOutputFormat(KVWriter, SeqWriter): def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36): self.max_length = max_length if isinstance(filename_or_file, str): - self.file = open(filename_or_file, "wt") + self.file = open(filename_or_file, "w") self.own_file = True else: assert hasattr(filename_or_file, "write"), f"Expected file or str, got {filename_or_file}" @@ -283,7 +283,7 @@ class JSONOutputFormat(KVWriter): """ def __init__(self, filename: str): - self.file = open(filename, "wt") + self.file = open(filename, "w") def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: def cast_to_json_serializable(value: Any): diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 61ecd05..8daa30f 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.8.0a7 +1.8.0a8