Use MPS device when available

This commit is contained in:
Antonin Raffin 2022-07-04 14:51:46 +02:00
parent d64bcb401a
commit ace05162c5
No known key found for this signature in database
GPG key ID: B8B48F65CAD6232C
3 changed files with 28 additions and 9 deletions

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.5.1a9 (WIP)
Release 1.5.1a10 (WIP)
---------------------------
Breaking Changes:
@ -17,6 +17,8 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Use MacOS Metal "mps" device when available
- Save cloudpickle version
SB3-Contrib
^^^^^^^^^^^

View file

@ -6,6 +6,7 @@ from collections import deque
from itertools import zip_longest
from typing import Dict, Iterable, Optional, Tuple, Union
import cloudpickle
import gym
import numpy as np
import torch as th
@ -135,19 +136,20 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device:
"""
Retrieve PyTorch device.
It checks that the requested device is available first.
For now, it supports only cpu and cuda.
By default, it tries to use the gpu.
For now, it supports only CPU and CUDA.
By default, it tries to use the GPU.
:param device: One for 'auto', 'cuda', 'cpu'
:param device: One of "auto", "cuda", "cpu",
or any PyTorch supported device (for instance "mps")
:return:
"""
# Cuda by default
# MPS/CUDA by default
if device == "auto":
device = "cuda"
device = get_available_accelerator()
# Force conversion to th.device
device = th.device(device)
# Cuda not available
# CUDA not available
if device.type == th.device("cuda").type and not th.cuda.is_available():
return th.device("cpu")
@ -483,6 +485,20 @@ def should_collect_more_steps(
)
def get_available_accelerator() -> str:
"""
Return the available accelerator
(currently checking only for CUDA and MPS device)
"""
if hasattr(th, "has_mps") and th.backends.mps.is_available():
# MacOS Metal GPU
return "mps"
elif th.cuda.is_available():
return "cuda"
else:
return "cpu"
def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
"""
Retrieve system and python env info for the current system.
@ -496,9 +512,10 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
"Python": platform.python_version(),
"Stable-Baselines3": sb3.__version__,
"PyTorch": th.__version__,
"GPU Enabled": str(th.cuda.is_available()),
"Accelerator": get_available_accelerator(),
"Numpy": np.__version__,
"Gym": gym.__version__,
"Cloudpickle": cloudpickle.__version__,
}
env_info_str = ""
for key, value in env_info.items():

View file

@ -1 +1 @@
1.5.1a9
1.5.1a10