mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Add warning when using PPO on GPU and update doc (#2017)
* Update documentation Added comment to PPO documentation that CPU should primarily be used unless using CNN as well as sample code. Added warning to user for both PPO and A2C that CPU should be used if the user is running GPU without using a CNN, reference Issue #1245. * Add warning to base class and add test --------- Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
512eea923a
commit
56c153f048
5 changed files with 56 additions and 4 deletions
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.4.0a9 (WIP)
|
||||
Release 2.4.0a10 (WIP)
|
||||
--------------------------
|
||||
|
||||
.. note::
|
||||
|
|
@ -60,12 +60,14 @@ Others:
|
|||
- Fixed various typos (@cschindlbeck)
|
||||
- Remove unnecessary SDE noise resampling in PPO update (@brn-dev)
|
||||
- Updated PyTorch version on CI to 2.3.1
|
||||
- Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy``
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
- Updated PPO doc to recommend using CPU with ``MlpPolicy``
|
||||
|
||||
Release 2.3.2 (2024-04-27)
|
||||
--------------------------
|
||||
|
|
|
|||
|
|
@ -88,6 +88,23 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
|
|||
vec_env.render("human")
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
PPO is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``:
|
||||
|
||||
.. code-block::
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
|
||||
if __name__=="__main__":
|
||||
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
|
||||
model = PPO("MlpPolicy", env, device="cpu")
|
||||
model.learn(total_timesteps=25_000)
|
||||
|
||||
For more information, see :ref:`Vectorized Environments <vec_env>`, `Issue #1245 <https://github.com/DLR-RM/stable-baselines3/issues/1245#issuecomment-1435766949>`_ or the `Multiprocessing notebook <https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb>`_.
|
||||
|
||||
Results
|
||||
-------
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -135,6 +136,28 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
|
||||
)
|
||||
self.policy = self.policy.to(self.device)
|
||||
# Warn when not using CPU with MlpPolicy
|
||||
self._maybe_recommend_cpu()
|
||||
|
||||
def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None:
|
||||
"""
|
||||
Recommend to use CPU only when using A2C/PPO with MlpPolicy.
|
||||
|
||||
:param: The name of the class for the default MlpPolicy.
|
||||
"""
|
||||
policy_class_name = self.policy_class.__name__
|
||||
if self.device != th.device("cpu") and policy_class_name == mlp_class_name:
|
||||
warnings.warn(
|
||||
f"You are trying to run {self.__class__.__name__} on the GPU, "
|
||||
"but it is primarily intended to run on the CPU when not using a CNN policy "
|
||||
f"(you are using {policy_class_name} which should be a MlpPolicy). "
|
||||
"See https://github.com/DLR-RM/stable-baselines3/issues/1245 "
|
||||
"for more info. "
|
||||
"You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU."
|
||||
"Note: The model will train, but the GPU utilization will be poor and "
|
||||
"the training might take longer than on CPU.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
def collect_rollouts(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.4.0a9
|
||||
2.4.0a10
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch as th
|
||||
|
||||
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
|
|
@ -211,8 +212,11 @@ def test_warn_dqn_multi_env():
|
|||
|
||||
|
||||
def test_ppo_warnings():
|
||||
"""Test that PPO warns and errors correctly on
|
||||
problematic rollout buffer sizes"""
|
||||
"""
|
||||
Test that PPO warns and errors correctly on
|
||||
problematic rollout buffer sizes,
|
||||
and recommend using CPU.
|
||||
"""
|
||||
|
||||
# Only 1 step: advantage normalization will return NaN
|
||||
with pytest.raises(AssertionError):
|
||||
|
|
@ -234,3 +238,9 @@ def test_ppo_warnings():
|
|||
loss = model.logger.name_to_value["train/loss"]
|
||||
assert loss > 0
|
||||
assert not np.isnan(loss) # check not nan (since nan does not equal nan)
|
||||
|
||||
with pytest.warns(UserWarning, match="You are trying to run PPO on the GPU"):
|
||||
model = PPO("MlpPolicy", "Pendulum-v1")
|
||||
# Pretend to be on the GPU
|
||||
model.device = th.device("cuda")
|
||||
model._maybe_recommend_cpu()
|
||||
|
|
|
|||
Loading…
Reference in a new issue