diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 32b22855f93..b50f0479e2f 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -69,6 +69,8 @@ class TORCH_API Context { return at::detail::getMPSHooks(); } else if (device_type == at::kPrivateUse1) { return at::detail::getPrivateUse1Hooks(); + } else if (device_type == at::kMTIA) { + return at::detail::getMTIAHooks(); } else { AT_ERROR( c10::DeviceTypeName(device_type), " device type not an accelerator."); @@ -156,6 +158,9 @@ class TORCH_API Context { void lazyInitXPU() { c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); }); } + void lazyInitMTIA() { + c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); }); + } void lazyInitPrivateUse1() { c10::call_once(thp_init, [&] { if (isPrivateUse1HooksRegistered()) { @@ -349,6 +354,7 @@ class TORCH_API Context { c10::once_flag thc_init; c10::once_flag thh_init; c10::once_flag thx_init; + c10::once_flag th_mtia_init; c10::once_flag thp_init; bool enabled_cudnn = true; bool deterministic_cudnn = false; diff --git a/aten/src/ATen/DeviceAccelerator.cpp b/aten/src/ATen/DeviceAccelerator.cpp index 05327cc219e..ec3cd2a2f55 100644 --- a/aten/src/ATen/DeviceAccelerator.cpp +++ b/aten/src/ATen/DeviceAccelerator.cpp @@ -10,6 +10,9 @@ C10_API std::optional getAccelerator(bool checked) { #define CHECK_NO_PU1 \ TORCH_CHECK(!is_privateuse1_backend_registered(), "Cannot have both CUDA and PrivateUse1"); +#define CHECK_NO_MTIA \ + TORCH_CHECK(!at::hasMTIA(), "Cannot have MTIA with other devices"); + if (is_privateuse1_backend_registered()) { // We explicitly allow PrivateUse1 and another device at the same time // as we use this for testing. @@ -17,7 +20,12 @@ C10_API std::optional getAccelerator(bool checked) { return kPrivateUse1; } else if (at::hasCUDA()) { CHECK_NO_PU1 + CHECK_NO_MTIA return kCUDA; + } else if (at::hasMTIA()) { + CHECK_NO_CUDA + CHECK_NO_PU1 + return kMTIA; } else { TORCH_CHECK(!checked, "Cannot access accelerator device when none is available.") return std::nullopt; diff --git a/aten/src/ATen/detail/AcceleratorHooksInterface.h b/aten/src/ATen/detail/AcceleratorHooksInterface.h index c099c9f59a6..96e15e1f69d 100644 --- a/aten/src/ATen/detail/AcceleratorHooksInterface.h +++ b/aten/src/ATen/detail/AcceleratorHooksInterface.h @@ -1,7 +1,7 @@ #pragma once #include - +#include namespace at { // AcceleratorHooksInterface is a shared interface provided by all @@ -16,6 +16,29 @@ struct TORCH_API AcceleratorHooksInterface { // Whether the device at device_index is fully initialized or not. virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0; + + virtual DeviceIndex deviceCount() const { + return 0; + } + + virtual void setCurrentDevice(DeviceIndex device) const { + TORCH_CHECK(false, "Backend doesn't support setCurrentDevice()"); + } + + virtual DeviceIndex getCurrentDevice() const { + TORCH_CHECK(false, "Backend doesn't support getCurrentDevice()"); + return -1; + } + + virtual DeviceIndex exchangeDevice(DeviceIndex device) const { + TORCH_CHECK(false, "Backend doesn't support exchangeDevice()"); + return -1; + } + + virtual DeviceIndex maybeExchangeDevice(DeviceIndex device) const { + TORCH_CHECK(false, "Backend doesn't support maybeExchangeDevice()"); + return -1; + } }; } // namespace at diff --git a/aten/src/ATen/detail/MTIAHooksInterface.cpp b/aten/src/ATen/detail/MTIAHooksInterface.cpp index 6b69fdb03f3..09638817138 100644 --- a/aten/src/ATen/detail/MTIAHooksInterface.cpp +++ b/aten/src/ATen/detail/MTIAHooksInterface.cpp @@ -8,19 +8,22 @@ namespace at { namespace detail { - -const MTIAHooksInterface &getMTIAHooks() { - static MTIAHooksInterface* MTIA_hooks = nullptr; +const MTIAHooksInterface& getMTIAHooks() { + static std::unique_ptr mtia_hooks = nullptr; static c10::once_flag once; c10::call_once(once, [] { - MTIA_hooks = - MTIAHooksRegistry()->Create("MTIAHooks", MTIAHooksArgs{}).release(); - if (!MTIA_hooks) { - MTIA_hooks = new MTIAHooksInterface(); + mtia_hooks = MTIAHooksRegistry()->Create("MTIAHooks", MTIAHooksArgs{}); + if (!mtia_hooks) { + mtia_hooks = std::make_unique(); } }); - return *MTIA_hooks; + return *mtia_hooks; } + +bool isMTIAHooksBuilt() { + return MTIAHooksRegistry()->Has("MTIAHooks"); +} + } // namespace detail C10_DEFINE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs) diff --git a/aten/src/ATen/detail/MTIAHooksInterface.h b/aten/src/ATen/detail/MTIAHooksInterface.h index c843ca52c2b..1da1bda4e61 100644 --- a/aten/src/ATen/detail/MTIAHooksInterface.h +++ b/aten/src/ATen/detail/MTIAHooksInterface.h @@ -1,7 +1,9 @@ #pragma once +#include #include +#include #include #include @@ -20,33 +22,72 @@ constexpr const char* MTIA_HELP = "to use some MTIA's functionality without MTIA extension included."; struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { +// this fails the implementation if MTIAHooks functions are called, but +// MTIA backend is not present. +#define FAIL_MTIAHOOKS_FUNC(func) \ + TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend."); + virtual ~MTIAHooksInterface() override = default; virtual void initMTIA() const { - TORCH_CHECK( - false, - "Cannot initialize MTIA without MTIA Extension for PyTorch.", - MTIA_HELP); + // Avoid logging here, since MTIA needs init devices first then it will know + // how many devices are available. Make it as no-op if mtia extension is not + // dynamically loaded. + return; } virtual bool hasMTIA() const { return false; } + virtual DeviceIndex deviceCount() const override { + return 0; + } + + virtual void deviceSynchronize(c10::DeviceIndex device_index) const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + virtual std::string showConfig() const { - TORCH_CHECK( - false, - "Cannot query detailed MTIA version without MTIA Extension for PyTorch.", - MTIA_HELP); + FAIL_MTIAHOOKS_FUNC(__func__); } virtual bool hasPrimaryContext(DeviceIndex device_index) const override { - TORCH_CHECK( - false, - "Cannot check MTIA primary context without MTIA Extension for PyTorch.", - MTIA_HELP); + return false; } + virtual void setCurrentDevice(DeviceIndex device) const override { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + virtual DeviceIndex getCurrentDevice() const override { + FAIL_MTIAHOOKS_FUNC(__func__); + return -1; + } + + virtual DeviceIndex exchangeDevice(DeviceIndex device) const override { + FAIL_MTIAHOOKS_FUNC(__func__); + return -1; + } + + virtual DeviceIndex maybeExchangeDevice(DeviceIndex device) const override { + FAIL_MTIAHOOKS_FUNC(__func__); + return -1; + } + + virtual c10::Stream getCurrentStream(DeviceIndex device) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA); + } + + virtual c10::Stream getDefaultStream(DeviceIndex device) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA); + } + + virtual void setCurrentStream(const c10::Stream& stream) const { + FAIL_MTIAHOOKS_FUNC(__func__); + } }; struct TORCH_API MTIAHooksArgs {}; @@ -57,5 +98,6 @@ C10_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs); namespace detail { TORCH_API const MTIAHooksInterface& getMTIAHooks(); +TORCH_API bool isMTIAHooksBuilt(); } // namespace detail } // namespace at diff --git a/build_variables.bzl b/build_variables.bzl index cebda39f4b9..5939da825cc 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -822,6 +822,7 @@ libtorch_python_core_sources = [ "torch/csrc/dynamo/init.cpp", "torch/csrc/functorch/init.cpp", "torch/csrc/mps/Module.cpp", + "torch/csrc/mtia/Module.cpp", "torch/csrc/inductor/aoti_runner/pybind.cpp", "torch/csrc/jit/backends/backend_init.cpp", "torch/csrc/jit/python/init.cpp", diff --git a/docs/source/index.rst b/docs/source/index.rst index 9e7cc6a9a6d..a7afe60bc28 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -69,6 +69,7 @@ Features described in this documentation are classified by release status: torch.cuda.memory mps xpu + mtia meta torch.backends torch.export diff --git a/docs/source/mtia.rst b/docs/source/mtia.rst new file mode 100644 index 00000000000..f2f5b5195dc --- /dev/null +++ b/docs/source/mtia.rst @@ -0,0 +1,34 @@ +torch.mtia +=================================== + +The MTIA backend is implemented out of the tree, only interfaces are be defined here. + +.. automodule:: torch.mtia +.. currentmodule:: torch.mtia + +.. autosummary:: + :toctree: generated + :nosignatures: + + StreamContext + current_device + current_stream + default_stream + device_count + init + is_available + is_initialized + set_stream + stream + synchronize + device + DeferredMtiaCallError + +Streams and events +------------------ +.. autosummary:: + :toctree: generated + :nosignatures: + + Event + Stream diff --git a/docs/source/torch.rst b/docs/source/torch.rst index b65a7a52398..32bcadc1545 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -684,6 +684,7 @@ Utilities set_float32_matmul_precision get_float32_matmul_precision set_warn_always + get_device_module is_warn_always_enabled vmap _assert diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 8b23117704d..34e49e15d85 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1719,6 +1719,24 @@ _TensorBase = TensorBase # Defined in torch/csrc/multiprocessing/init.cpp def _multiprocessing_init() -> None: ... +# Defined in torch/csrc/Module.cpp +def _accelerator_hooks_device_count() -> _int: ... +def _accelerator_hooks_set_current_device(device_index: _int) -> None: ... +def _accelerator_hooks_get_current_device() -> _int: ... +def _accelerator_hooks_exchange_device(device_index: _int) -> _int: ... +def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int: ... +def _get_accelerator(check: _bool = False) -> _device: ... + +# Defined in torch/csrc/mtia/Module.cpp +def _mtia_init() -> None: ... +def _mtia_isBuilt() -> _bool: ... +def _mtia_isInBadFork() -> _bool: ... +def _mtia_deviceSynchronize() -> None: ... +def _mtia_getCurrentStream(device: _int) -> Stream: ... +def _mtia_setCurrentStream(stream: Stream) -> None: ... +def _mtia_getDefaultStream(device: _int) -> Stream: ... + + # Defined in torch/csrc/mps/Module.cpp def _mps_deviceSynchronize() -> None: ... def _mps_get_default_generator() -> Generator: ... diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index 34eb451be08..118d913f681 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -24,6 +24,7 @@ class DeviceType(Enum): FPGA = ... MAIA = ... XLA = ... + MTIA = ... MPS = ... HPU = ... Meta = ... diff --git a/torch/__init__.py b/torch/__init__.py index 9a7249f2202..846038e3510 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -58,6 +58,7 @@ __all__ = [ 'SymBool', 'sym_not', 'unravel_index', 'sym_int', 'sym_float', 'sym_max', 'sym_min', 'sym_ite', 'compile', 'vmap', 'export', 'autocast', 'cond', 'GradScaler', + 'get_device_module', ] ################################################################################ @@ -1579,6 +1580,7 @@ from torch import cuda as cuda from torch import cpu as cpu from torch import mps as mps from torch import xpu as xpu +from torch import mtia as mtia from torch import autograd as autograd from torch.autograd import ( no_grad as no_grad, @@ -2016,6 +2018,27 @@ else: raise AttributeError(f"module '{__name__}' has no attribute '{name}'") +def get_device_module(device: Optional[Union[torch.device, str]] = None): + """ + Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...). + If no device is given, return the module for the current accelerator or CPU if none is present. + """ + if isinstance(device, torch.device): + device_module_name = device.type + elif isinstance(device, str): + device_module_name = torch.device(device).type + elif device is None: + # Using default accelerator type. If no accelerator is available, it automatically returns CPU device. + device_module_name = torch._C._get_accelerator().type + else: + raise RuntimeError(f"Invalid value of device '{device}', expect torch.device, str, or None") + device_module = getattr(torch, device_module_name, None) + if device_module is None: + raise RuntimeError( + f"Device '{device_module_name}' does not have a corresponding module registered as 'torch.{device_module_name}'." + ) + return device_module + def _constrain_as_value(symbol, min: Optional[builtins.int] = None, max: Optional[builtins.int] = None): """ diff --git a/torch/_utils.py b/torch/_utils.py index 7f9a1af43fe..43c6284d241 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -713,6 +713,8 @@ def _get_available_device_type(): return "cuda" if hasattr(torch, "xpu") and torch.xpu.is_available(): # type: ignore[attr-defined] return "xpu" + if hasattr(torch, "mtia") and torch.mtia.is_available(): + return "mtia" custom_backend_name = torch._C._get_privateuse1_backend_name() custom_device_mod = getattr(torch, custom_backend_name, None) if custom_device_mod and custom_device_mod.is_available(): @@ -727,6 +729,8 @@ def _get_device_attr(get_member): return get_member(torch.cuda) if device_type and device_type.lower() == "xpu": return get_member(torch.xpu) # type: ignore[attr-defined] + if device_type and device_type.lower() == "mtia": + return get_member(torch.mtia) if device_type == torch._C._get_privateuse1_backend_name(): return get_member(getattr(torch, device_type)) # add more available device types here diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 5723a024e7a..b446d293957 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -16,10 +17,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include #include @@ -72,6 +75,7 @@ #include #include #include +#include #include #include #include @@ -1641,6 +1645,7 @@ PyObject* initModule() { #ifdef USE_XPU torch::xpu::initModule(module); #endif + torch::mtia::initModule(module); torch::cpu::initModule(module); torch::initVerboseBindings(module); ASSERT_TRUE(THPStorage_init(module)); @@ -1975,6 +1980,70 @@ Call this whenever a new thread is created in order to propagate values from return at::impl::ThreadLocalPythonObjects::get_state().contains(key); }); + py_module.def("_accelerator_hooks_device_count", []() { + auto device_type = at::getAccelerator(); + if (device_type.has_value()) { + return at::globalContext() + .getAcceleratorHooksInterface(device_type.value()) + .deviceCount(); + } + return c10::DeviceIndex(-1); + }); + + py_module.def( + "_accelerator_hooks_set_current_device", + [](c10::DeviceIndex device_index) { + auto device_type = at::getAccelerator(); + if (device_type.has_value()) { + at::globalContext() + .getAcceleratorHooksInterface(device_type.value()) + .setCurrentDevice(device_index); + } + }); + + py_module.def("_accelerator_hooks_get_current_device", []() { + auto device_type = at::getAccelerator(); + if (device_type.has_value()) { + return at::globalContext() + .getAcceleratorHooksInterface(device_type.value()) + .getCurrentDevice(); + } + return c10::DeviceIndex(-1); + }); + + py_module.def( + "_accelerator_hooks_exchange_device", [](c10::DeviceIndex device_index) { + auto device_type = at::getAccelerator(); + if (device_type.has_value()) { + return at::globalContext() + .getAcceleratorHooksInterface(device_type.value()) + .exchangeDevice(device_index); + } + return c10::DeviceIndex(-1); + }); + + py_module.def( + "_accelerator_hooks_maybe_exchange_device", + [](c10::DeviceIndex device_index) { + auto device_type = at::getAccelerator(); + if (device_type.has_value()) { + return at::globalContext() + .getAcceleratorHooksInterface(device_type.value()) + .maybeExchangeDevice(device_index); + } + return c10::DeviceIndex(-1); + }); + + py_module.def( + "_get_accelerator", + [](c10::optional check = c10::nullopt) { + return c10::Device( + at::getAccelerator(check.value_or(false)) + .value_or(c10::DeviceType::CPU), + -1); + }, + py::arg("check") = nullptr); + #ifdef USE_CUDA PyObject* has_cuda = Py_True; #else diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp new file mode 100644 index 00000000000..84cc11f7187 --- /dev/null +++ b/torch/csrc/mtia/Module.cpp @@ -0,0 +1,81 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#ifndef WIN32 +#include +#endif + +namespace torch { +namespace mtia { + +static bool in_bad_fork = false; // True for children forked after mtia init + +#ifndef WIN32 +// Called in the forked child if mtia has already been initialized +static void forked_child() { + in_bad_fork = true; + torch::utils::set_requires_device_init(at::kMTIA, true); +} +#endif + +// Should be called before the first mtia call. +// Note: This is distinct from initExtension because a stub mtia implementation +// has some working functions (e.g. device_count) but cannot fully initialize. +static void poison_fork() { +#ifndef WIN32 + static c10::once_flag flag; + c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); }); +#endif +} + +void initModule(PyObject* module) { + auto m = py::handle(module).cast(); + + m.def("_mtia_init", []() { + TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level + poison_fork(); + at::globalContext().lazyInitMTIA(); + }); + + m.def("_mtia_isBuilt", []() { + // Check if the MTIAHooks class has been registered with the registry. + return at::detail::isMTIAHooksBuilt(); + }); + + m.def("_mtia_isInBadFork", []() { return in_bad_fork; }); + + m.def("_mtia_getCurrentStream", [](c10::DeviceIndex device_index) { + torch::utils::device_lazy_init(at::kMTIA); + return at::detail::getMTIAHooks().getCurrentStream(device_index); + }); + + m.def("_mtia_deviceSynchronize", [](c10::DeviceIndex device_index) { + torch::utils::device_lazy_init(at::kMTIA); + at::detail::getMTIAHooks().deviceSynchronize( + at::detail::getMTIAHooks().getCurrentDevice()); + }); + + m.def("_mtia_getDefaultStream", [](c10::DeviceIndex device_index) { + torch::utils::device_lazy_init(at::kMTIA); + return at::detail::getMTIAHooks().getDefaultStream(device_index); + }); + + m.def("_mtia_setCurrentStream", [](const c10::Stream& stream) { + torch::utils::device_lazy_init(at::kMTIA); + auto device = at::detail::getMTIAHooks().getCurrentDevice(); + if (device != stream.device_index()) { + at::detail::getMTIAHooks().setCurrentDevice(stream.device_index()); + } + at::detail::getMTIAHooks().setCurrentStream(stream); + }); +} + +} // namespace mtia +} // namespace torch diff --git a/torch/csrc/mtia/Module.h b/torch/csrc/mtia/Module.h new file mode 100644 index 00000000000..96a98ed448e --- /dev/null +++ b/torch/csrc/mtia/Module.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +namespace torch { +namespace mtia { + +// PyMethodDef* python_functions(); +void initModule(PyObject* module); + +} // namespace mtia +} // namespace torch diff --git a/torch/csrc/utils/pybind.h b/torch/csrc/utils/pybind.h index 36cb83659aa..1a4e7bb26fc 100644 --- a/torch/csrc/utils/pybind.h +++ b/torch/csrc/utils/pybind.h @@ -194,6 +194,12 @@ struct type_caster { // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) PYBIND11_TYPE_CASTER(c10::Stream, _("torch.Stream")); + // PYBIND11_TYPE_CASTER defines a member field called value. Since c10::Stream + // cannot be default-initialized, we provide this constructor to explicitly + // initialize that field. The value doesn't matter as it will be overwritten + // after a successful call to load. + type_caster() : value(c10::Stream::DEFAULT, c10::Device(c10::kCPU, 0)) {} + bool load(handle src, bool) { PyObject* obj = src.ptr(); if (THPStream_Check(obj)) { diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py new file mode 100644 index 00000000000..4007f0e584f --- /dev/null +++ b/torch/mtia/__init__.py @@ -0,0 +1,262 @@ +r""" +This package enables an interface for accessing MTIA backend in python +""" + +import threading +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch + +from torch.types import Device + +from .. import device as _device, Tensor +from .._utils import _dummy_type, _LazySeedTracker, classproperty +from ._utils import _get_device_index + +_device_t = Union[_device, str, int, None] + +# torch.mtia.Event/Stream is alias of torch.Event/Stream +Event = torch.Event +Stream = torch.Stream + +_initialized = False +_queued_calls: List[ + Tuple[Callable[[], None], List[str]] +] = [] # don't invoke these until initialization occurs +_tls = threading.local() +_initialization_lock = threading.Lock() +_lazy_seed_tracker = _LazySeedTracker() + + +def init(): + _lazy_init() + + +def is_initialized(): + r"""Return whether PyTorch's MTIA state has been initialized.""" + return _initialized and not _is_in_bad_fork() + + +def _is_in_bad_fork() -> bool: + return torch._C._mtia_isInBadFork() + + +def _lazy_init() -> None: + global _initialized, _queued_calls + if is_initialized() or hasattr(_tls, "is_initializing"): + return + with _initialization_lock: + # We be double-checked locking, boys! This is OK because + # the above test was GIL protected anyway. The inner test + # is for when a thread blocked on some other thread which was + # doing the initialization; when they get the lock, they will + # find there is nothing left to do. + if is_initialized(): + return + # It is important to prevent other threads from entering _lazy_init + # immediately, while we are still guaranteed to have the GIL, because some + # of the C calls we make below will release the GIL + if _is_in_bad_fork(): + raise RuntimeError( + "Cannot re-initialize MTIA in forked subprocess. To use MTIA with " + "multiprocessing, you must use the 'spawn' start method" + ) + if not _is_compiled(): + raise AssertionError("Torch not compiled with MTIA enabled") + + torch._C._mtia_init() + # Some of the queued calls may reentrantly call _lazy_init(); + # we need to just return without initializing in that case. + # However, we must not let any *other* threads in! + _tls.is_initializing = True + + for calls in _lazy_seed_tracker.get_calls(): + if calls: + _queued_calls.append(calls) + + try: + for queued_call, orig_traceback in _queued_calls: + try: + queued_call() + except Exception as e: + msg = ( + f"MTIA call failed lazily at initialization with error: {str(e)}\n\n" + f"MTIA call was originally invoked at:\n\n{''.join(orig_traceback)}" + ) + raise DeferredMtiaCallError(msg) from e + finally: + delattr(_tls, "is_initializing") + _initialized = True + + +class DeferredMtiaCallError(Exception): + pass + + +def _is_compiled() -> bool: + r"""Return true if compiled with MTIA support.""" + return torch._C._mtia_isBuilt() + + +def is_available() -> bool: + r"""Return true if MTIA device is available""" + if not _is_compiled(): + return False + # MTIA has to init devices first to know if there is any devices available. + return device_count() > 0 + + +def synchronize() -> None: + r"""Waits for all jobs in all streams on a MTIA device to complete.""" + return torch._C._mtia_deviceSynchronize() + + +def device_count() -> int: + r"""Return the number of MTIA devices available.""" + return torch._C._accelerator_hooks_device_count() + + +def current_device() -> int: + r"""Return the index of a currently selected device.""" + return torch._C._accelerator_hooks_get_current_device() + + +def current_stream(device: Optional[_device_t] = None) -> Stream: + r"""Return the currently selected :class:`Stream` for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + the currently selected :class:`Stream` for the current device, given + by :func:`~torch.mtia.current_device`, if :attr:`device` is ``None`` + (default). + """ + return torch._C._mtia_getCurrentStream(_get_device_index(device, optional=True)) + + +def default_stream(device: Optional[_device_t] = None) -> Stream: + r"""Return the default :class:`Stream` for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + the default :class:`Stream` for the current device, given by + :func:`~torch.mtia.current_device`, if :attr:`device` is ``None`` + (default). + """ + return torch._C._mtia_getDefaultStream(_get_device_index(device, optional=True)) + + +def set_stream(stream: Stream): + r"""Set the current stream.This is a wrapper API to set the stream. + Usage of this function is discouraged in favor of the ``stream`` + context manager. + + Args: + stream (Stream): selected stream. This function is a no-op + if this argument is ``None``. + """ + if stream is None: + return + torch._C._mtia_setCurrentStream(stream) + + +class device: + r"""Context-manager that changes the selected device. + + Args: + device (torch.device or int): device index to select. It's a no-op if + this argument is a negative integer or ``None``. + """ + + def __init__(self, device: Any): + self.idx = _get_device_index(device, optional=True) + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch._C._accelerator_hooks_maybe_exchange_device(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): + self.idx = torch._C._accelerator_hooks_maybe_exchange_device(self.prev_idx) + return False + + +class StreamContext: + r"""Context-manager that selects a given stream. + + All MTIA kernels queued within its context will be enqueued on a selected + stream. + + Args: + Stream (Stream): selected stream. This manager is a no-op if it's + ``None``. + .. note:: Streams are per-device. + """ + + cur_stream: Optional["torch.mtia.Stream"] + + def __init__(self, stream: Optional["torch.mtia.Stream"]): + self.stream = stream + self.idx = _get_device_index(None, True) + if not torch.jit.is_scripting(): + if self.idx is None: + self.idx = -1 + + self.src_prev_stream = ( + None if not torch.jit.is_scripting() else torch.mtia.default_stream(None) + ) + self.dst_prev_stream = ( + None if not torch.jit.is_scripting() else torch.mtia.default_stream(None) + ) + + def __enter__(self): + # Local cur_stream variable for type refinement + cur_stream = self.stream + # Return if stream is None or MTIA device not available + if cur_stream is None or self.idx == -1: + return + self.src_prev_stream = torch.mtia.current_stream(None) + + # If the stream is not on the current device, then + # set the current stream on the device + if self.src_prev_stream.device != cur_stream.device: + with device(cur_stream.device): + self.dst_prev_stream = torch.mtia.current_stream(cur_stream.device) + torch.mtia.set_stream(cur_stream) + + def __exit__(self, type: Any, value: Any, traceback: Any): + # Local cur_stream variable for type refinement + cur_stream = self.stream + # If stream is None or no MTIA device available, return + if cur_stream is None or self.idx == -1: + return + + # Reset the stream on the original device + # and destination device + if self.src_prev_stream.device != cur_stream.device: # type: ignore[union-attr] + torch.mtia.set_stream(self.dst_prev_stream) # type: ignore[arg-type] + torch.mtia.set_stream(self.src_prev_stream) # type: ignore[arg-type] + + +def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext: + r"""Wrap around the Context-manager StreamContext that selects a given stream. + + Arguments: + stream (Stream): selected stream. This manager is a no-op if it's + ``None``. + ..Note:: In eager mode stream is of type Stream class while in JIT it doesn't support torch.mtia.stream + """ + return StreamContext(stream) + + +__all__ = [ + "init", + "is_available", + "is_initialized", + "synchronize", + "device_count", + "current_device", + "current_stream", + "default_stream", + "set_stream", + "stream", + "device", +] diff --git a/torch/mtia/_utils.py b/torch/mtia/_utils.py new file mode 100644 index 00000000000..090e26f3212 --- /dev/null +++ b/torch/mtia/_utils.py @@ -0,0 +1,38 @@ +from typing import Any + +import torch + +# The _get_device_index has been moved to torch.utils._get_device_index +from torch._utils import _get_device_index as _torch_get_device_index + + +def _get_device_index( + device: Any, optional: bool = False, allow_cpu: bool = False +) -> int: + r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``. + + If :attr:`device` is a torch.device object, returns the device index if it + is a MTIA device. Note that for a MTIA device without a specified index, + i.e., ``torch.device('mtia')``, this will return the current default MTIA + device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``, + CPU devices will be accepted and ``-1`` will be returned in this case. + + If :attr:`device` is a Python integer, it is returned as is. + + If :attr:`device` is ``None``, this will return the current default MTIA + device if :attr:`optional` is ``True``. + """ + if isinstance(device, int): + return device + if isinstance(device, str): + device = torch.device(device) + if isinstance(device, torch.device): + if allow_cpu: + if device.type not in ["mtia", "cpu"]: + raise ValueError(f"Expected a mtia or cpu device, but got: {device}") + elif device.type != "mtia": + raise ValueError(f"Expected a mtia device, but got: {device}") + if not torch.jit.is_scripting(): + if isinstance(device, torch.mtia.device): + return device.idx + return _torch_get_device_index(device, optional, allow_cpu) diff --git a/torch/overrides.py b/torch/overrides.py index 728c75c090b..6c521bc7003 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -283,6 +283,7 @@ def get_ignored_functions() -> Set[Callable]: torch.use_deterministic_algorithms, torch.is_deterministic_algorithms_warn_only_enabled, torch.set_deterministic_debug_mode, + torch.get_device_module, torch.get_deterministic_debug_mode, torch.set_float32_matmul_precision, torch.get_float32_matmul_precision,