From ace05162c5b133b1b706ed5fcceecea6ccfdb004 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 4 Jul 2022 14:51:46 +0200 Subject: [PATCH] Use MPS device when available --- docs/misc/changelog.rst | 4 +++- stable_baselines3/common/utils.py | 31 ++++++++++++++++++++++++------- stable_baselines3/version.txt | 2 +- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 52bf3e4..576fbc8 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.5.1a9 (WIP) +Release 1.5.1a10 (WIP) --------------------------- Breaking Changes: @@ -17,6 +17,8 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Use MacOS Metal "mps" device when available +- Save cloudpickle version SB3-Contrib ^^^^^^^^^^^ diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 94cd658..fdd33a4 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -6,6 +6,7 @@ from collections import deque from itertools import zip_longest from typing import Dict, Iterable, Optional, Tuple, Union +import cloudpickle import gym import numpy as np import torch as th @@ -135,19 +136,20 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device: """ Retrieve PyTorch device. It checks that the requested device is available first. - For now, it supports only cpu and cuda. - By default, it tries to use the gpu. + For now, it supports only CPU and CUDA. + By default, it tries to use the GPU. - :param device: One for 'auto', 'cuda', 'cpu' + :param device: One of "auto", "cuda", "cpu", + or any PyTorch supported device (for instance "mps") :return: """ - # Cuda by default + # MPS/CUDA by default if device == "auto": - device = "cuda" + device = get_available_accelerator() # Force conversion to th.device device = th.device(device) - # Cuda not available + # CUDA not available if device.type == th.device("cuda").type and not th.cuda.is_available(): return th.device("cpu") @@ -483,6 +485,20 @@ def should_collect_more_steps( ) +def get_available_accelerator() -> str: + """ + Return the available accelerator + (currently checking only for CUDA and MPS device) + """ + if hasattr(th, "has_mps") and th.backends.mps.is_available(): + # MacOS Metal GPU + return "mps" + elif th.cuda.is_available(): + return "cuda" + else: + return "cpu" + + def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: """ Retrieve system and python env info for the current system. @@ -496,9 +512,10 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: "Python": platform.python_version(), "Stable-Baselines3": sb3.__version__, "PyTorch": th.__version__, - "GPU Enabled": str(th.cuda.is_available()), + "Accelerator": get_available_accelerator(), "Numpy": np.__version__, "Gym": gym.__version__, + "Cloudpickle": cloudpickle.__version__, } env_info_str = "" for key, value in env_info.items(): diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 125ec27..c43063f 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a9 +1.5.1a10