mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
Use MPS device when available
This commit is contained in:
parent
d64bcb401a
commit
ace05162c5
3 changed files with 28 additions and 9 deletions
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.5.1a9 (WIP)
|
||||
Release 1.5.1a10 (WIP)
|
||||
---------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -17,6 +17,8 @@ Breaking Changes:
|
|||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Use MacOS Metal "mps" device when available
|
||||
- Save cloudpickle version
|
||||
|
||||
SB3-Contrib
|
||||
^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from collections import deque
|
|||
from itertools import zip_longest
|
||||
from typing import Dict, Iterable, Optional, Tuple, Union
|
||||
|
||||
import cloudpickle
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -135,19 +136,20 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device:
|
|||
"""
|
||||
Retrieve PyTorch device.
|
||||
It checks that the requested device is available first.
|
||||
For now, it supports only cpu and cuda.
|
||||
By default, it tries to use the gpu.
|
||||
For now, it supports only CPU and CUDA.
|
||||
By default, it tries to use the GPU.
|
||||
|
||||
:param device: One for 'auto', 'cuda', 'cpu'
|
||||
:param device: One of "auto", "cuda", "cpu",
|
||||
or any PyTorch supported device (for instance "mps")
|
||||
:return:
|
||||
"""
|
||||
# Cuda by default
|
||||
# MPS/CUDA by default
|
||||
if device == "auto":
|
||||
device = "cuda"
|
||||
device = get_available_accelerator()
|
||||
# Force conversion to th.device
|
||||
device = th.device(device)
|
||||
|
||||
# Cuda not available
|
||||
# CUDA not available
|
||||
if device.type == th.device("cuda").type and not th.cuda.is_available():
|
||||
return th.device("cpu")
|
||||
|
||||
|
|
@ -483,6 +485,20 @@ def should_collect_more_steps(
|
|||
)
|
||||
|
||||
|
||||
def get_available_accelerator() -> str:
|
||||
"""
|
||||
Return the available accelerator
|
||||
(currently checking only for CUDA and MPS device)
|
||||
"""
|
||||
if hasattr(th, "has_mps") and th.backends.mps.is_available():
|
||||
# MacOS Metal GPU
|
||||
return "mps"
|
||||
elif th.cuda.is_available():
|
||||
return "cuda"
|
||||
else:
|
||||
return "cpu"
|
||||
|
||||
|
||||
def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
|
||||
"""
|
||||
Retrieve system and python env info for the current system.
|
||||
|
|
@ -496,9 +512,10 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
|
|||
"Python": platform.python_version(),
|
||||
"Stable-Baselines3": sb3.__version__,
|
||||
"PyTorch": th.__version__,
|
||||
"GPU Enabled": str(th.cuda.is_available()),
|
||||
"Accelerator": get_available_accelerator(),
|
||||
"Numpy": np.__version__,
|
||||
"Gym": gym.__version__,
|
||||
"Cloudpickle": cloudpickle.__version__,
|
||||
}
|
||||
env_info_str = ""
|
||||
for key, value in env_info.items():
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.5.1a9
|
||||
1.5.1a10
|
||||
|
|
|
|||
Loading…
Reference in a new issue