mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "[BE][accelerator] formalize API name {current,set}_device_{idx => index} (#140542)"
This reverts commit fb02b40d27.
Reverted https://github.com/pytorch/pytorch/pull/140542 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but I need to revert this in order to revert https://github.com/pytorch/pytorch/pull/133572#issuecomment-2537204202 due to a conflict ([comment](https://github.com/pytorch/pytorch/pull/140542#issuecomment-2537253665))
This commit is contained in:
parent
de313f1155
commit
cd50bd8477
4 changed files with 31 additions and 46 deletions
|
|
@ -10,9 +10,7 @@ torch.accelerator
|
|||
device_count
|
||||
is_available
|
||||
current_accelerator
|
||||
set_device_index
|
||||
set_device_idx
|
||||
current_device_index
|
||||
current_device_idx
|
||||
set_stream
|
||||
current_stream
|
||||
|
|
|
|||
|
|
@ -27,23 +27,23 @@ class TestAccelerator(TestCase):
|
|||
with self.assertRaisesRegex(
|
||||
ValueError, "doesn't match the current accelerator"
|
||||
):
|
||||
torch.accelerator.set_device_index("cpu")
|
||||
torch.accelerator.set_device_idx("cpu")
|
||||
|
||||
@unittest.skipIf(not TEST_MULTIACCELERATOR, "only one accelerator detected")
|
||||
def test_generic_multi_device_behavior(self):
|
||||
orig_device = torch.accelerator.current_device_index()
|
||||
orig_device = torch.accelerator.current_device_idx()
|
||||
target_device = (orig_device + 1) % torch.accelerator.device_count()
|
||||
|
||||
torch.accelerator.set_device_index(target_device)
|
||||
self.assertEqual(target_device, torch.accelerator.current_device_index())
|
||||
torch.accelerator.set_device_index(orig_device)
|
||||
self.assertEqual(orig_device, torch.accelerator.current_device_index())
|
||||
torch.accelerator.set_device_idx(target_device)
|
||||
self.assertEqual(target_device, torch.accelerator.current_device_idx())
|
||||
torch.accelerator.set_device_idx(orig_device)
|
||||
self.assertEqual(orig_device, torch.accelerator.current_device_idx())
|
||||
|
||||
s1 = torch.Stream(target_device)
|
||||
torch.accelerator.set_stream(s1)
|
||||
self.assertEqual(target_device, torch.accelerator.current_device_index())
|
||||
self.assertEqual(target_device, torch.accelerator.current_device_idx())
|
||||
torch.accelerator.synchronize(orig_device)
|
||||
self.assertEqual(target_device, torch.accelerator.current_device_index())
|
||||
self.assertEqual(target_device, torch.accelerator.current_device_idx())
|
||||
|
||||
def test_generic_stream_behavior(self):
|
||||
s1 = torch.Stream()
|
||||
|
|
|
|||
|
|
@ -2,27 +2,11 @@ r"""
|
|||
This package introduces support for the current :ref:`accelerator<accelerators>` in python.
|
||||
"""
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
|
||||
from ._utils import _device_t, _get_device_index
|
||||
|
||||
|
||||
__all__ = [
|
||||
"current_accelerator",
|
||||
"current_device_idx", # deprecated
|
||||
"current_device_index",
|
||||
"current_stream",
|
||||
"device_count",
|
||||
"is_available",
|
||||
"set_device_idx", # deprecated
|
||||
"set_device_index",
|
||||
"set_stream",
|
||||
"synchronize",
|
||||
]
|
||||
|
||||
|
||||
def device_count() -> int:
|
||||
r"""Return the number of current :ref:`accelerator<accelerators>` available.
|
||||
|
||||
|
|
@ -53,7 +37,7 @@ def current_accelerator() -> torch.device:
|
|||
torch.device: return the current accelerator as :class:`torch.device`.
|
||||
|
||||
.. note:: The index of the returned :class:`torch.device` will be ``None``, please use
|
||||
:func:`torch.accelerator.current_device_index` to know the current index being used.
|
||||
:func:`torch.accelerator.current_device_idx` to know the current index being used.
|
||||
And ensure to use :func:`torch.accelerator.is_available` to check if there is an available
|
||||
accelerator. If there is no available accelerator, this function will raise an exception.
|
||||
|
||||
|
|
@ -74,7 +58,7 @@ def current_accelerator() -> torch.device:
|
|||
return torch._C._accelerator_getAccelerator()
|
||||
|
||||
|
||||
def current_device_index() -> int:
|
||||
def current_device_idx() -> int:
|
||||
r"""Return the index of a currently selected device for the current :ref:`accelerator<accelerators>`.
|
||||
|
||||
Returns:
|
||||
|
|
@ -83,13 +67,7 @@ def current_device_index() -> int:
|
|||
return torch._C._accelerator_getDeviceIndex()
|
||||
|
||||
|
||||
current_device_idx = deprecated(
|
||||
"Use `current_device_index` instead.",
|
||||
category=FutureWarning,
|
||||
)(current_device_index)
|
||||
|
||||
|
||||
def set_device_index(device: _device_t, /) -> None:
|
||||
def set_device_idx(device: _device_t, /) -> None:
|
||||
r"""Set the current device index to a given device.
|
||||
|
||||
Args:
|
||||
|
|
@ -102,19 +80,13 @@ def set_device_index(device: _device_t, /) -> None:
|
|||
torch._C._accelerator_setDeviceIndex(device_index)
|
||||
|
||||
|
||||
set_device_idx = deprecated(
|
||||
"Use `set_device_index` instead.",
|
||||
category=FutureWarning,
|
||||
)(set_device_index)
|
||||
|
||||
|
||||
def current_stream(device: _device_t = None, /) -> torch.Stream:
|
||||
r"""Return the currently selected stream for a given device.
|
||||
|
||||
Args:
|
||||
device (:class:`torch.device`, str, int, optional): a given device that must match the current
|
||||
:ref:`accelerator<accelerators>` device type. If not given,
|
||||
use :func:`torch.accelerator.current_device_index` by default.
|
||||
use :func:`torch.accelerator.current_device_idx` by default.
|
||||
|
||||
Returns:
|
||||
torch.Stream: the currently selected stream for a given device.
|
||||
|
|
@ -140,7 +112,7 @@ def synchronize(device: _device_t = None, /) -> None:
|
|||
Args:
|
||||
device (:class:`torch.device`, str, int, optional): device for which to synchronize. It must match
|
||||
the current :ref:`accelerator<accelerators>` device type. If not given,
|
||||
use :func:`torch.accelerator.current_device_index` by default.
|
||||
use :func:`torch.accelerator.current_device_idx` by default.
|
||||
|
||||
.. note:: This function is a no-op if the current :ref:`accelerator<accelerators>` is not initialized.
|
||||
|
||||
|
|
@ -159,3 +131,15 @@ def synchronize(device: _device_t = None, /) -> None:
|
|||
"""
|
||||
device_index = _get_device_index(device, True)
|
||||
torch._C._accelerator_synchronizeDevice(device_index)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"current_accelerator",
|
||||
"current_device_idx",
|
||||
"current_stream",
|
||||
"device_count",
|
||||
"is_available",
|
||||
"set_device_idx",
|
||||
"set_stream",
|
||||
"synchronize",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.types import Device as _device_t
|
||||
from torch import device as _device
|
||||
|
||||
|
||||
_device_t = Union[_device, str, int, None]
|
||||
|
||||
|
||||
def _get_device_index(device: _device_t, optional: bool = False) -> int:
|
||||
|
|
@ -21,5 +24,5 @@ def _get_device_index(device: _device_t, optional: bool = False) -> int:
|
|||
raise ValueError(
|
||||
f"Expected a torch.device with a specified index or an integer, but got:{device}"
|
||||
)
|
||||
return torch.accelerator.current_device_index()
|
||||
return torch.accelerator.current_device_idx()
|
||||
return device_index
|
||||
|
|
|
|||
Loading…
Reference in a new issue