diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 822e0cb..cb90552 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -31,18 +31,21 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
+          # Use uv for faster downloads
+          pip install uv
           # cpu version of pytorch
-          pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
+          # See https://github.com/astral-sh/uv/issues/1497
+          uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu

           # Install Atari Roms
-          pip install autorom
+          uv pip install --system autorom
           wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
           base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
           AutoROM --accept-license --source-file Roms.tar.gz

-          pip install .[extra_no_roms,tests,docs]
+          uv pip install --system .[extra_no_roms,tests,docs]
           # Use headless version
-          pip install opencv-python-headless
+          uv pip install --system opencv-python-headless
       - name: Lint with ruff
         run: |
           make lint
diff --git a/docs/guide/sb3_contrib.rst b/docs/guide/sb3_contrib.rst
index 445832c..8ec912e 100644
--- a/docs/guide/sb3_contrib.rst
+++ b/docs/guide/sb3_contrib.rst
@@ -42,6 +42,7 @@ See documentation for the full list of included features.
 - `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) <https://sb3-contrib.readthedocs.io/en/master/modules/ppo_recurrent.html>`_
 - `Truncated Quantile Critics (TQC)`_
 - `Trust Region Policy Optimization (TRPO) <https://sb3-contrib.readthedocs.io/en/master/modules/trpo.html>`_
+- `Batch Normalization in Deep Reinforcement Learning (CrossQ) <https://sb3-contrib.readthedocs.io/en/master/modules/crossq.html>`_

 **Gym Wrappers**:

diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index af83d23..2c0974a 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -6,6 +6,8 @@ Changelog
 Release 2.4.0a10 (WIP)
 --------------------------

+**New algorithm: CrossQ in SB3 Contrib**
+
 .. note::

   DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about
@@ -43,6 +45,10 @@ Bug Fixes:

 `SB3-Contrib`_
 ^^^^^^^^^^^^^^
+- Added ``CrossQ`` algorithm, from the "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen)
+- Added ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen)
+- Updated QR-DQN optimizer input to only include ``quantile_net`` parameters (@corentinlger)
+- Fixed loading QRDQN changing ``target_update_interval`` (@jak3122)

 `RL Zoo`_
 ^^^^^^^^^
@@ -61,6 +67,7 @@ Others:
 - Remove unnecessary SDE noise resampling in PPO update (@brn-dev)
 - Updated PyTorch version on CI to 2.3.1
 - Added a warning to recommend using CPU with on-policy algorithms (A2C/PPO) and ``MlpPolicy``
+- Switched to uv to download packages faster on GitHub CI

 Bug Fixes:
 ^^^^^^^^^^
diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py
index bcde1cf..4e9fbc2 100644
--- a/stable_baselines3/common/utils.py
+++ b/stable_baselines3/common/utils.py
@@ -46,7 +46,7 @@ def set_random_seed(seed: int, using_cuda: bool = False) -> None:


 # From stable baselines
-def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
+def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float:
     """
     Computes fraction of variance that ypred explains about y.
     Returns 1 - Var[y-ypred] / Var[y]
@@ -62,7 +62,7 @@ def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
     """
     assert y_true.ndim == 1 and y_pred.ndim == 1
     var_y = np.var(y_true)
-    return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
+    return np.nan if var_y == 0 else float(1 - np.var(y_true - y_pred) / var_y)


 def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None:
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4cc8b7e..81f1341 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -177,7 +177,7 @@ def test_custom_vec_env(tmp_path):


 @pytest.mark.parametrize("direct_policy", [False, True])
-def test_evaluate_policy(direct_policy: bool):
+def test_evaluate_policy(direct_policy):
     model = A2C("MlpPolicy", "Pendulum-v1", seed=0)
     n_steps_per_episode, n_eval_episodes = 200, 2
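
Note on the explained_variance hunk above: the change makes the helper return a plain Python float instead of an np.float64 (previously annotated as np.ndarray). A quick check of the patched behaviour; the snippet calls the real helper, but the input arrays are illustrative values, not taken from the test suite:

import numpy as np

from stable_baselines3.common.utils import explained_variance

y_true = np.array([1.0, 2.0, 3.0, 4.0])

# A perfect prediction explains all the variance.
assert explained_variance(y_true, y_true) == 1.0

# After this patch the result is a builtin float (not np.float64),
# so it serializes and logs cleanly.
ev = explained_variance(np.array([1.0, 2.5, 2.5, 4.5]), y_true)
assert type(ev) is float

# Degenerate case: zero variance in y_true yields NaN rather than a division error.
assert np.isnan(explained_variance(y_true, np.ones(4)))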
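Note on the CrossQ entries: the algorithm itself lives in SB3-Contrib, not in this repository; this patch only adds the documentation link and changelog entries. A minimal quick-start sketch, assuming CrossQ follows the standard SB3 algorithm API (constructor plus .learn()/.save(), like SAC or TQC); this is not taken from the patch:

from sb3_contrib import CrossQ  # assumes sb3-contrib >= 2.4.0

# Hypothetical usage, assuming the standard SB3 interface:
# policy name + env id in the constructor, then .learn().
model = CrossQ("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=10_000)
model.save("crossq_pendulum")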