diff --git a/docs/source/accelerator.rst b/docs/source/accelerator.rst index 760056806b1..6e4d7a541ee 100644 --- a/docs/source/accelerator.rst +++ b/docs/source/accelerator.rst @@ -10,9 +10,7 @@ torch.accelerator device_count is_available current_accelerator - set_device_index set_device_idx - current_device_index current_device_idx set_stream current_stream diff --git a/test/test_accelerator.py b/test/test_accelerator.py index 51f2602a923..7864e0ca39b 100644 --- a/test/test_accelerator.py +++ b/test/test_accelerator.py @@ -27,23 +27,23 @@ class TestAccelerator(TestCase): with self.assertRaisesRegex( ValueError, "doesn't match the current accelerator" ): - torch.accelerator.set_device_index("cpu") + torch.accelerator.set_device_idx("cpu") @unittest.skipIf(not TEST_MULTIACCELERATOR, "only one accelerator detected") def test_generic_multi_device_behavior(self): - orig_device = torch.accelerator.current_device_index() + orig_device = torch.accelerator.current_device_idx() target_device = (orig_device + 1) % torch.accelerator.device_count() - torch.accelerator.set_device_index(target_device) - self.assertEqual(target_device, torch.accelerator.current_device_index()) - torch.accelerator.set_device_index(orig_device) - self.assertEqual(orig_device, torch.accelerator.current_device_index()) + torch.accelerator.set_device_idx(target_device) + self.assertEqual(target_device, torch.accelerator.current_device_idx()) + torch.accelerator.set_device_idx(orig_device) + self.assertEqual(orig_device, torch.accelerator.current_device_idx()) s1 = torch.Stream(target_device) torch.accelerator.set_stream(s1) - self.assertEqual(target_device, torch.accelerator.current_device_index()) + self.assertEqual(target_device, torch.accelerator.current_device_idx()) torch.accelerator.synchronize(orig_device) - self.assertEqual(target_device, torch.accelerator.current_device_index()) + self.assertEqual(target_device, torch.accelerator.current_device_idx()) def test_generic_stream_behavior(self): s1 = torch.Stream() diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index 8f912b468a3..f4d7593175b 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -2,27 +2,11 @@ r""" This package introduces support for the current :ref:`accelerator` in python. """ -from typing_extensions import deprecated - import torch from ._utils import _device_t, _get_device_index -__all__ = [ - "current_accelerator", - "current_device_idx", # deprecated - "current_device_index", - "current_stream", - "device_count", - "is_available", - "set_device_idx", # deprecated - "set_device_index", - "set_stream", - "synchronize", -] - - def device_count() -> int: r"""Return the number of current :ref:`accelerator` available. @@ -53,7 +37,7 @@ def current_accelerator() -> torch.device: torch.device: return the current accelerator as :class:`torch.device`. .. note:: The index of the returned :class:`torch.device` will be ``None``, please use - :func:`torch.accelerator.current_device_index` to know the current index being used. + :func:`torch.accelerator.current_device_idx` to know the current index being used. And ensure to use :func:`torch.accelerator.is_available` to check if there is an available accelerator. If there is no available accelerator, this function will raise an exception. @@ -74,7 +58,7 @@ def current_accelerator() -> torch.device: return torch._C._accelerator_getAccelerator() -def current_device_index() -> int: +def current_device_idx() -> int: r"""Return the index of a currently selected device for the current :ref:`accelerator`. Returns: @@ -83,13 +67,7 @@ def current_device_index() -> int: return torch._C._accelerator_getDeviceIndex() -current_device_idx = deprecated( - "Use `current_device_index` instead.", - category=FutureWarning, -)(current_device_index) - - -def set_device_index(device: _device_t, /) -> None: +def set_device_idx(device: _device_t, /) -> None: r"""Set the current device index to a given device. Args: @@ -102,19 +80,13 @@ def set_device_index(device: _device_t, /) -> None: torch._C._accelerator_setDeviceIndex(device_index) -set_device_idx = deprecated( - "Use `set_device_index` instead.", - category=FutureWarning, -)(set_device_index) - - def current_stream(device: _device_t = None, /) -> torch.Stream: r"""Return the currently selected stream for a given device. Args: device (:class:`torch.device`, str, int, optional): a given device that must match the current :ref:`accelerator` device type. If not given, - use :func:`torch.accelerator.current_device_index` by default. + use :func:`torch.accelerator.current_device_idx` by default. Returns: torch.Stream: the currently selected stream for a given device. @@ -140,7 +112,7 @@ def synchronize(device: _device_t = None, /) -> None: Args: device (:class:`torch.device`, str, int, optional): device for which to synchronize. It must match the current :ref:`accelerator` device type. If not given, - use :func:`torch.accelerator.current_device_index` by default. + use :func:`torch.accelerator.current_device_idx` by default. .. note:: This function is a no-op if the current :ref:`accelerator` is not initialized. @@ -159,3 +131,15 @@ def synchronize(device: _device_t = None, /) -> None: """ device_index = _get_device_index(device, True) torch._C._accelerator_synchronizeDevice(device_index) + + +__all__ = [ + "current_accelerator", + "current_device_idx", + "current_stream", + "device_count", + "is_available", + "set_device_idx", + "set_stream", + "synchronize", +] diff --git a/torch/accelerator/_utils.py b/torch/accelerator/_utils.py index 1828715bcdf..abaa00c44b5 100644 --- a/torch/accelerator/_utils.py +++ b/torch/accelerator/_utils.py @@ -1,7 +1,10 @@ -from typing import Optional +from typing import Optional, Union import torch -from torch.types import Device as _device_t +from torch import device as _device + + +_device_t = Union[_device, str, int, None] def _get_device_index(device: _device_t, optional: bool = False) -> int: @@ -21,5 +24,5 @@ def _get_device_index(device: _device_t, optional: bool = False) -> int: raise ValueError( f"Expected a torch.device with a specified index or an integer, but got:{device}" ) - return torch.accelerator.current_device_index() + return torch.accelerator.current_device_idx() return device_index