mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-21 22:00:21 +00:00
Switch to pyproject.toml and ruff (#1361)
* Switch to `pyproject.toml` and `ruff` * Fix for Atari ROMs and mypy * Switch order in CI, lint first
This commit is contained in:
parent
f0382a25bd
commit
10e83865ec
10 changed files with 92 additions and 89 deletions
19
.github/workflows/ci.yml
vendored
19
.github/workflows/ci.yml
vendored
|
|
@ -14,6 +14,8 @@ jobs:
|
|||
env:
|
||||
TERM: xterm-256color
|
||||
FORCE_COLOR: 1
|
||||
ATARI_ROMS: ${{ secrets.ATARI_ROMS }}
|
||||
|
||||
# Skip CI if [ci skip] in the commit message
|
||||
if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
|
||||
runs-on: ubuntu-latest
|
||||
|
|
@ -32,21 +34,28 @@ jobs:
|
|||
python -m pip install --upgrade pip
|
||||
# cpu version of pytorch
|
||||
pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
|
||||
# Install Atari Roms
|
||||
pip install autorom
|
||||
wget $ATARI_ROMS
|
||||
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
|
||||
AutoROM --accept-license --source-file Roms.tar.gz
|
||||
|
||||
pip install .[extra,tests,docs]
|
||||
# Use headless version
|
||||
pip install opencv-python-headless
|
||||
- name: Lint with ruff
|
||||
run: |
|
||||
make lint
|
||||
- name: Build the doc
|
||||
run: |
|
||||
make doc
|
||||
- name: Type check
|
||||
run: |
|
||||
make type
|
||||
- name: Check codestyle
|
||||
run: |
|
||||
make check-codestyle
|
||||
- name: Lint with flake8
|
||||
- name: Type check
|
||||
run: |
|
||||
make lint
|
||||
make type
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make pytest
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ Type checking with `pytype` and `mypy`:
|
|||
make type
|
||||
```
|
||||
|
||||
Codestyle check with `black`, `isort` and `flake8`:
|
||||
Codestyle check with `black`, `isort` and `ruff`:
|
||||
|
||||
```
|
||||
make check-codestyle
|
||||
|
|
|
|||
15
Makefile
15
Makefile
|
|
@ -14,29 +14,22 @@ type: pytype mypy
|
|||
|
||||
lint:
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
# see https://lintlyci.github.io/Flake8Rules/
|
||||
flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||
# exit-zero treats all errors as warnings.
|
||||
flake8 ${LINT_PATHS} --count --exit-zero --statistics
|
||||
|
||||
ruff:
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
# see https://lintlyci.github.io/Flake8Rules/
|
||||
# see https://www.flake8rules.com/
|
||||
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
|
||||
# exit-zero treats all errors as warnings.
|
||||
ruff ${LINT_PATHS} --exit-zero --line-length 127
|
||||
ruff ${LINT_PATHS} --exit-zero
|
||||
|
||||
format:
|
||||
# Sort imports
|
||||
isort ${LINT_PATHS}
|
||||
# Reformat using black
|
||||
black -l 127 ${LINT_PATHS}
|
||||
black ${LINT_PATHS}
|
||||
|
||||
check-codestyle:
|
||||
# Sort imports
|
||||
isort --check ${LINT_PATHS}
|
||||
# Reformat using black
|
||||
black --check -l 127 ${LINT_PATHS}
|
||||
black --check ${LINT_PATHS}
|
||||
|
||||
commit-checks: format type lint
|
||||
|
||||
|
|
|
|||
|
|
@ -206,9 +206,9 @@ pip install pytype mypy
|
|||
make type
|
||||
```
|
||||
|
||||
Codestyle check with `flake8`:
|
||||
Codestyle check with `ruff`:
|
||||
```
|
||||
pip install flake8
|
||||
pip install ruff
|
||||
make lint
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.8.0a7 (WIP)
|
||||
Release 1.8.0a8 (WIP)
|
||||
--------------------------
|
||||
|
||||
|
||||
|
|
@ -43,6 +43,9 @@ Others:
|
|||
- Fixed ``stable_baselines3/common/monitor.py`` type hint
|
||||
- Added tests for StackedObservations
|
||||
- Removed Gitlab CI file
|
||||
- Moved from ``setup.cg`` to ``pyproject.toml`` configuration file
|
||||
- Switched from ``flake8`` to ``ruff``
|
||||
- Upgraded AutoROM to latest version
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -1,35 +1,39 @@
|
|||
[metadata]
|
||||
# This includes the license file in the wheel.
|
||||
license_files = LICENSE
|
||||
project_urls =
|
||||
Code = https://github.com/DLR-RM/stable-baselines3
|
||||
Documentation = https://stable-baselines3.readthedocs.io/
|
||||
[tool.ruff]
|
||||
# Same as Black.
|
||||
line-length = 127
|
||||
# Assume Python 3.7
|
||||
target-version = "py37"
|
||||
# TODO(antonin): activate "RUF" https://beta.ruff.rs/docs/rules/#ruff-specific-rules-ruf
|
||||
select = ["E", "F", "B", "UP", "C90"]
|
||||
ignore = []
|
||||
|
||||
[tool:pytest]
|
||||
# Deterministic ordering for tests; useful for pytest-xdist.
|
||||
env =
|
||||
PYTHONHASHSEED=0
|
||||
filterwarnings =
|
||||
# Tensorboard warnings
|
||||
ignore::DeprecationWarning:tensorboard
|
||||
# Gym warnings
|
||||
ignore:Parameters to load are deprecated.:DeprecationWarning
|
||||
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
|
||||
ignore::UserWarning:gym
|
||||
ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning
|
||||
ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning
|
||||
markers =
|
||||
expensive: marks tests as expensive (deselect with '-m "not expensive"')
|
||||
[tool.ruff.per-file-ignores]
|
||||
# Default implementation in abstract methods
|
||||
"./stable_baselines3/common/callbacks.py"= ["B027"]
|
||||
"./stable_baselines3/common/noise.py"= ["B027"]
|
||||
|
||||
[pytype]
|
||||
inputs = stable_baselines3
|
||||
disable = pyi-error
|
||||
|
||||
[mypy]
|
||||
ignore_missing_imports = True
|
||||
follow_imports = silent
|
||||
show_error_codes = True
|
||||
exclude = (?x)(
|
||||
[tool.ruff.mccabe]
|
||||
# Unlike Flake8, default to a complexity level of 10.
|
||||
max-complexity = 15
|
||||
|
||||
[tool.black]
|
||||
line-length = 127
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 127
|
||||
src_paths = ["stable_baselines3"]
|
||||
|
||||
[tool.pytype]
|
||||
inputs = ["stable_baselines3"]
|
||||
disable = ["pyi-error"]
|
||||
|
||||
[tool.mypy]
|
||||
ignore_missing_imports = true
|
||||
follow_imports = "silent"
|
||||
show_error_codes = true
|
||||
exclude = """(?x)(
|
||||
stable_baselines3/a2c/a2c.py$
|
||||
| stable_baselines3/common/base_class.py$
|
||||
| stable_baselines3/common/buffers.py$
|
||||
|
|
@ -66,34 +70,24 @@ exclude = (?x)(
|
|||
| stable_baselines3/td3/td3.py$
|
||||
| tests/test_logger.py$
|
||||
| tests/test_train_eval_mode.py$
|
||||
)
|
||||
)"""
|
||||
|
||||
[flake8]
|
||||
# line breaks before and after binary operators
|
||||
# ignore explicit stack level
|
||||
ignore = W503,W504,E203,E231,B028
|
||||
# Ignore import not used when aliases are defined
|
||||
per-file-ignores =
|
||||
# Default implementation in abstract methods
|
||||
./stable_baselines3/common/callbacks.py:B027
|
||||
./stable_baselines3/common/noise.py:B027
|
||||
exclude =
|
||||
# No need to traverse our git directory
|
||||
.git,
|
||||
# There's no value in checking cache directories
|
||||
__pycache__,
|
||||
# Don't check the doc
|
||||
docs/
|
||||
# This contains our built documentation
|
||||
build,
|
||||
# This contains builds of flake8 that we don't want to check
|
||||
dist
|
||||
*.egg-info
|
||||
max-complexity = 15
|
||||
# The GitHub editor is 127 chars wide
|
||||
max-line-length = 127
|
||||
[tool.pytest.ini_options]
|
||||
# Deterministic ordering for tests; useful for pytest-xdist.
|
||||
env = [
|
||||
"PYTHONHASHSEED=0"
|
||||
]
|
||||
|
||||
[isort]
|
||||
profile = black
|
||||
line_length = 127
|
||||
src_paths = stable_baselines3
|
||||
filterwarnings = [
|
||||
# Tensorboard warnings
|
||||
"ignore::DeprecationWarning:tensorboard",
|
||||
# Gym warnings
|
||||
"ignore:Parameters to load are deprecated.:DeprecationWarning",
|
||||
"ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning",
|
||||
"ignore::UserWarning:gym",
|
||||
"ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning",
|
||||
"ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning",
|
||||
]
|
||||
markers = [
|
||||
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
|
||||
]
|
||||
16
setup.py
16
setup.py
|
|
@ -86,7 +86,7 @@ setup(
|
|||
"pandas",
|
||||
# Plotting learning curves
|
||||
"matplotlib",
|
||||
# gym and flake8 not compatible with importlib-metadata>5.0
|
||||
# gym not compatible with importlib-metadata>5.0
|
||||
"importlib-metadata~=4.13",
|
||||
],
|
||||
extras_require={
|
||||
|
|
@ -99,10 +99,8 @@ setup(
|
|||
# Type check
|
||||
"pytype",
|
||||
"mypy",
|
||||
# Lint code
|
||||
"flake8>=3.8",
|
||||
# Find likely bugs
|
||||
"flake8-bugbear",
|
||||
# Lint code (flake8 replacement)
|
||||
"ruff",
|
||||
# Sort imports
|
||||
"isort>=5.0",
|
||||
# Reformat
|
||||
|
|
@ -126,7 +124,7 @@ setup(
|
|||
"opencv-python",
|
||||
# For atari games,
|
||||
"ale-py==0.7.4",
|
||||
"autorom[accept-rom-license]~=0.4.2",
|
||||
"autorom[accept-rom-license]~=0.5.5",
|
||||
"pillow",
|
||||
# Tensorboard support
|
||||
"tensorboard>=2.9.1",
|
||||
|
|
@ -149,6 +147,12 @@ setup(
|
|||
version=__version__,
|
||||
python_requires=">=3.7",
|
||||
# PyPI package information.
|
||||
project_urls={
|
||||
"Code": "https://github.com/DLR-RM/stable-baselines3",
|
||||
"Documentation": "https://stable-baselines3.readthedocs.io/",
|
||||
"SB3-Contrib": "https://github.com/Stable-Baselines-Team/stable-baselines3-contrib",
|
||||
"RL-Zoo": "https://github.com/DLR-RM/rl-baselines3-zoo",
|
||||
},
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
|
|
|
|||
|
|
@ -617,8 +617,8 @@ class BaseAlgorithm(ABC):
|
|||
f"expected {objects_needing_update}, got {updated_objects}"
|
||||
)
|
||||
|
||||
@classmethod # noqa: C901
|
||||
def load(
|
||||
@classmethod
|
||||
def load( # noqa: C901
|
||||
cls: Type[SelfBaseAlgorithm],
|
||||
path: Union[str, pathlib.Path, io.BufferedIOBase],
|
||||
env: Optional[GymEnv] = None,
|
||||
|
|
|
|||
|
|
@ -161,7 +161,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
|
|||
def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36):
|
||||
self.max_length = max_length
|
||||
if isinstance(filename_or_file, str):
|
||||
self.file = open(filename_or_file, "wt")
|
||||
self.file = open(filename_or_file, "w")
|
||||
self.own_file = True
|
||||
else:
|
||||
assert hasattr(filename_or_file, "write"), f"Expected file or str, got {filename_or_file}"
|
||||
|
|
@ -283,7 +283,7 @@ class JSONOutputFormat(KVWriter):
|
|||
"""
|
||||
|
||||
def __init__(self, filename: str):
|
||||
self.file = open(filename, "wt")
|
||||
self.file = open(filename, "w")
|
||||
|
||||
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
|
||||
def cast_to_json_serializable(value: Any):
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.8.0a7
|
||||
1.8.0a8
|
||||
|
|
|
|||
Loading…
Reference in a new issue