diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 8928233..97ad152 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -521,7 +521,7 @@ def get_available_accelerator() -> str: Return the available accelerator (currently checking only for CUDA and MPS device) """ - if hasattr(th, "has_mps") and th.backends.mps.is_available(): + if hasattr(th, "has_mps") and th.backends.mps.is_built(): # MacOS Metal GPU return "mps" elif th.cuda.is_available():