mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Use uv on GitHub CI for faster download and update changelog (#2026)
* Use uv on GitHub CI for faster download and update changelog * Fix new mypy issues
This commit is contained in:
parent
56c153f048
commit
3d59b5c86b
5 changed files with 18 additions and 7 deletions
11
.github/workflows/ci.yml
vendored
11
.github/workflows/ci.yml
vendored
|
|
@ -31,18 +31,21 @@ jobs:
|
|||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# Use uv for faster downloads
|
||||
pip install uv
|
||||
# cpu version of pytorch
|
||||
pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
|
||||
# See https://github.com/astral-sh/uv/issues/1497
|
||||
uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Install Atari Roms
|
||||
pip install autorom
|
||||
uv pip install --system autorom
|
||||
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
|
||||
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
|
||||
AutoROM --accept-license --source-file Roms.tar.gz
|
||||
|
||||
pip install .[extra_no_roms,tests,docs]
|
||||
uv pip install --system .[extra_no_roms,tests,docs]
|
||||
# Use headless version
|
||||
pip install opencv-python-headless
|
||||
uv pip install --system opencv-python-headless
|
||||
- name: Lint with ruff
|
||||
run: |
|
||||
make lint
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ See documentation for the full list of included features.
|
|||
- `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) <https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/>`_
|
||||
- `Truncated Quantile Critics (TQC)`_
|
||||
- `Trust Region Policy Optimization (TRPO) <https://arxiv.org/abs/1502.05477>`_
|
||||
- `Batch Normalization in Deep Reinforcement Learning (CrossQ) <https://openreview.net/forum?id=PczQtTsTIX>`_
|
||||
|
||||
|
||||
**Gym Wrappers**:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ Changelog
|
|||
Release 2.4.0a10 (WIP)
|
||||
--------------------------
|
||||
|
||||
**New algorithm: CrossQ in SB3 Contrib**
|
||||
|
||||
.. note::
|
||||
|
||||
DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about
|
||||
|
|
@ -43,6 +45,10 @@ Bug Fixes:
|
|||
|
||||
`SB3-Contrib`_
|
||||
^^^^^^^^^^^^^^
|
||||
- Added ``CrossQ`` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen)
|
||||
- Added ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen)
|
||||
- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger)
|
||||
- Fixed loading QRDQN changes `target_update_interval` (@jak3122)
|
||||
|
||||
`RL Zoo`_
|
||||
^^^^^^^^^
|
||||
|
|
@ -61,6 +67,7 @@ Others:
|
|||
- Remove unnecessary SDE noise resampling in PPO update (@brn-dev)
|
||||
- Updated PyTorch version on CI to 2.3.1
|
||||
- Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy``
|
||||
- Switched to uv to download packages faster on GitHub CI
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ def set_random_seed(seed: int, using_cuda: bool = False) -> None:
|
|||
|
||||
|
||||
# From stable baselines
|
||||
def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
|
||||
def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float:
|
||||
"""
|
||||
Computes fraction of variance that ypred explains about y.
|
||||
Returns 1 - Var[y-ypred] / Var[y]
|
||||
|
|
@ -62,7 +62,7 @@ def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
|
|||
"""
|
||||
assert y_true.ndim == 1 and y_pred.ndim == 1
|
||||
var_y = np.var(y_true)
|
||||
return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
||||
return np.nan if var_y == 0 else float(1 - np.var(y_true - y_pred) / var_y)
|
||||
|
||||
|
||||
def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None:
|
||||
|
|
|
|||
|
|
@ -177,7 +177,7 @@ def test_custom_vec_env(tmp_path):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("direct_policy", [False, True])
|
||||
def test_evaluate_policy(direct_policy: bool):
|
||||
def test_evaluate_policy(direct_policy):
|
||||
model = A2C("MlpPolicy", "Pendulum-v1", seed=0)
|
||||
n_steps_per_episode, n_eval_episodes = 200, 2
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue