diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 822e0cb..cb90552 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -31,18 +31,21 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
+ # Use uv for faster downloads
+ pip install uv
# cpu version of pytorch
- pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
+ # See https://github.com/astral-sh/uv/issues/1497
+ uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu
# Install Atari Roms
- pip install autorom
+ uv pip install --system autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
- pip install .[extra_no_roms,tests,docs]
+ uv pip install --system .[extra_no_roms,tests,docs]
# Use headless version
- pip install opencv-python-headless
+ uv pip install --system opencv-python-headless
- name: Lint with ruff
run: |
make lint
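
As a sanity check that the wheel pulled from the CPU index is the expected build, a short (hypothetical, not part of this PR) verification step could run after the install; `+cpu` wheels report their local version tag in `torch.__version__` and ship without CUDA support:

    import torch

    # Wheels from https://download.pytorch.org/whl/cpu carry a "+cpu" local
    # version tag and are built without CUDA support.
    print(torch.__version__)              # expected: "2.3.1+cpu"
    assert not torch.cuda.is_available()  # CPU-only build
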
diff --git a/docs/guide/sb3_contrib.rst b/docs/guide/sb3_contrib.rst
index 445832c..8ec912e 100644
--- a/docs/guide/sb3_contrib.rst
+++ b/docs/guide/sb3_contrib.rst
@@ -42,6 +42,7 @@ See documentation for the full list of included features.
- `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) `_
- `Truncated Quantile Critics (TQC)`_
- `Trust Region Policy Optimization (TRPO) `_
+- `Batch Normalization in Deep Reinforcement Learning (CrossQ) `_
**Gym Wrappers**:
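
CrossQ follows the same API as the other SB3-contrib algorithms; a minimal usage sketch (assuming an sb3-contrib release that ships CrossQ, with Gymnasium's Pendulum-v1 as an example continuous-control task):

    from sb3_contrib import CrossQ

    # CrossQ is an off-policy algorithm for continuous action spaces,
    # so a Box-action environment such as Pendulum-v1 is used here.
    model = CrossQ("MlpPolicy", "Pendulum-v1", verbose=1)
    model.learn(total_timesteps=5_000)
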
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index af83d23..2c0974a 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -6,6 +6,8 @@ Changelog
Release 2.4.0a10 (WIP)
--------------------------
+**New algorithm: CrossQ in SB3 Contrib**
+
.. note::
DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about
@@ -43,6 +45,10 @@ Bug Fixes:
`SB3-Contrib`_
^^^^^^^^^^^^^^
+- Added the ``CrossQ`` algorithm, from the "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen)
+- Added the ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen)
+- Updated the QR-DQN optimizer input to include only the ``quantile_net`` parameters (@corentinlger)
+- Fixed a bug where loading a QR-DQN model changed ``target_update_interval`` (@jak3122)
`RL Zoo`_
^^^^^^^^^
@@ -61,6 +67,7 @@ Others:
- Remove unnecessary SDE noise resampling in PPO update (@brn-dev)
- Updated PyTorch version on CI to 2.3.1
- Added a warning to recommend using CPU with on-policy algorithms (A2C/PPO) and ``MlpPolicy``
+- Switched to uv to download packages faster on GitHub CI
Bug Fixes:
^^^^^^^^^^
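
The QR-DQN loading fix above restores the invariant that hyperparameters survive a save/load round trip; a minimal round-trip sketch (hypothetical file name and value, assuming an sb3-contrib build that includes the fix):

    from sb3_contrib import QRDQN

    # target_update_interval should be identical after saving and reloading.
    model = QRDQN("MlpPolicy", "CartPole-v1", target_update_interval=5_000)
    model.save("qrdqn_cartpole")
    loaded = QRDQN.load("qrdqn_cartpole")
    assert loaded.target_update_interval == 5_000
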
diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py
index bcde1cf..4e9fbc2 100644
--- a/stable_baselines3/common/utils.py
+++ b/stable_baselines3/common/utils.py
@@ -46,7 +46,7 @@ def set_random_seed(seed: int, using_cuda: bool = False) -> None:
# From stable baselines
-def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
+def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float:
"""
Computes fraction of variance that ypred explains about y.
Returns 1 - Var[y-ypred] / Var[y]
@@ -62,7 +62,7 @@ def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
"""
assert y_true.ndim == 1 and y_pred.ndim == 1
var_y = np.var(y_true)
- return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
+ return np.nan if var_y == 0 else float(1 - np.var(y_true - y_pred) / var_y)
def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None:
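
With the explicit ``float()`` cast added above, ``explained_variance`` now returns a plain Python float (matching the corrected return annotation) rather than a NumPy scalar; a quick illustration with made-up values:

    import numpy as np
    from stable_baselines3.common.utils import explained_variance

    y_true = np.array([1.0, 2.0, 3.0])
    y_pred = np.array([0.9, 2.1, 3.0])

    ev = explained_variance(y_pred, y_true)
    assert type(ev) is float  # previously an np.float64 scalar
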
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4cc8b7e..81f1341 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -177,7 +177,7 @@ def test_custom_vec_env(tmp_path):
@pytest.mark.parametrize("direct_policy", [False, True])
-def test_evaluate_policy(direct_policy):
+def test_evaluate_policy(direct_policy: bool):
model = A2C("MlpPolicy", "Pendulum-v1", seed=0)
n_steps_per_episode, n_eval_episodes = 200, 2