Revert "[BE][accelerator] formalize API name {current,set}_device_{idx => index} (#140542)"

This reverts commit fb02b40d27.

Reverted https://github.com/pytorch/pytorch/pull/140542 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but I need to revert this in order to revert https://github.com/pytorch/pytorch/pull/133572#issuecomment-2537204202 due to a conflict ([comment](https://github.com/pytorch/pytorch/pull/140542#issuecomment-2537253665))
This commit is contained in:
PyTorch MergeBot 2024-12-11 21:44:23 +00:00
parent de313f1155
commit cd50bd8477
4 changed files with 31 additions and 46 deletions

View file

@ -10,9 +10,7 @@ torch.accelerator
device_count
is_available
current_accelerator
set_device_index
set_device_idx
current_device_index
current_device_idx
set_stream
current_stream

View file

@ -27,23 +27,23 @@ class TestAccelerator(TestCase):
with self.assertRaisesRegex(
ValueError, "doesn't match the current accelerator"
):
torch.accelerator.set_device_index("cpu")
torch.accelerator.set_device_idx("cpu")
@unittest.skipIf(not TEST_MULTIACCELERATOR, "only one accelerator detected")
def test_generic_multi_device_behavior(self):
orig_device = torch.accelerator.current_device_index()
orig_device = torch.accelerator.current_device_idx()
target_device = (orig_device + 1) % torch.accelerator.device_count()
torch.accelerator.set_device_index(target_device)
self.assertEqual(target_device, torch.accelerator.current_device_index())
torch.accelerator.set_device_index(orig_device)
self.assertEqual(orig_device, torch.accelerator.current_device_index())
torch.accelerator.set_device_idx(target_device)
self.assertEqual(target_device, torch.accelerator.current_device_idx())
torch.accelerator.set_device_idx(orig_device)
self.assertEqual(orig_device, torch.accelerator.current_device_idx())
s1 = torch.Stream(target_device)
torch.accelerator.set_stream(s1)
self.assertEqual(target_device, torch.accelerator.current_device_index())
self.assertEqual(target_device, torch.accelerator.current_device_idx())
torch.accelerator.synchronize(orig_device)
self.assertEqual(target_device, torch.accelerator.current_device_index())
self.assertEqual(target_device, torch.accelerator.current_device_idx())
def test_generic_stream_behavior(self):
s1 = torch.Stream()

View file

@ -2,27 +2,11 @@ r"""
This package introduces support for the current :ref:`accelerator<accelerators>` in python.
"""
from typing_extensions import deprecated
import torch
from ._utils import _device_t, _get_device_index
__all__ = [
"current_accelerator",
"current_device_idx", # deprecated
"current_device_index",
"current_stream",
"device_count",
"is_available",
"set_device_idx", # deprecated
"set_device_index",
"set_stream",
"synchronize",
]
def device_count() -> int:
r"""Return the number of current :ref:`accelerator<accelerators>` available.
@ -53,7 +37,7 @@ def current_accelerator() -> torch.device:
torch.device: return the current accelerator as :class:`torch.device`.
.. note:: The index of the returned :class:`torch.device` will be ``None``, please use
:func:`torch.accelerator.current_device_index` to know the current index being used.
:func:`torch.accelerator.current_device_idx` to know the current index being used.
And ensure to use :func:`torch.accelerator.is_available` to check if there is an available
accelerator. If there is no available accelerator, this function will raise an exception.
@ -74,7 +58,7 @@ def current_accelerator() -> torch.device:
return torch._C._accelerator_getAccelerator()
def current_device_index() -> int:
def current_device_idx() -> int:
r"""Return the index of a currently selected device for the current :ref:`accelerator<accelerators>`.
Returns:
@ -83,13 +67,7 @@ def current_device_index() -> int:
return torch._C._accelerator_getDeviceIndex()
current_device_idx = deprecated(
"Use `current_device_index` instead.",
category=FutureWarning,
)(current_device_index)
def set_device_index(device: _device_t, /) -> None:
def set_device_idx(device: _device_t, /) -> None:
r"""Set the current device index to a given device.
Args:
@ -102,19 +80,13 @@ def set_device_index(device: _device_t, /) -> None:
torch._C._accelerator_setDeviceIndex(device_index)
set_device_idx = deprecated(
"Use `set_device_index` instead.",
category=FutureWarning,
)(set_device_index)
def current_stream(device: _device_t = None, /) -> torch.Stream:
r"""Return the currently selected stream for a given device.
Args:
device (:class:`torch.device`, str, int, optional): a given device that must match the current
:ref:`accelerator<accelerators>` device type. If not given,
use :func:`torch.accelerator.current_device_index` by default.
use :func:`torch.accelerator.current_device_idx` by default.
Returns:
torch.Stream: the currently selected stream for a given device.
@ -140,7 +112,7 @@ def synchronize(device: _device_t = None, /) -> None:
Args:
device (:class:`torch.device`, str, int, optional): device for which to synchronize. It must match
the current :ref:`accelerator<accelerators>` device type. If not given,
use :func:`torch.accelerator.current_device_index` by default.
use :func:`torch.accelerator.current_device_idx` by default.
.. note:: This function is a no-op if the current :ref:`accelerator<accelerators>` is not initialized.
@ -159,3 +131,15 @@ def synchronize(device: _device_t = None, /) -> None:
"""
device_index = _get_device_index(device, True)
torch._C._accelerator_synchronizeDevice(device_index)
__all__ = [
"current_accelerator",
"current_device_idx",
"current_stream",
"device_count",
"is_available",
"set_device_idx",
"set_stream",
"synchronize",
]

View file

@ -1,7 +1,10 @@
from typing import Optional
from typing import Optional, Union
import torch
from torch.types import Device as _device_t
from torch import device as _device
_device_t = Union[_device, str, int, None]
def _get_device_index(device: _device_t, optional: bool = False) -> int:
@ -21,5 +24,5 @@ def _get_device_index(device: _device_t, optional: bool = False) -> int:
raise ValueError(
f"Expected a torch.device with a specified index or an integer, but got:{device}"
)
return torch.accelerator.current_device_index()
return torch.accelerator.current_device_idx()
return device_index