diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h
index 4fff139f277..a7a1f8dcec7 100644
--- a/aten/src/ATen/detail/MPSHooksInterface.h
+++ b/aten/src/ATen/detail/MPSHooksInterface.h
@@ -28,6 +28,10 @@ struct TORCH_API MPSHooksInterface {
     return false;
   }
 
+  virtual bool isOnMacOS13orNewer() const {
+    AT_ERROR("MPS backend is not available.");
+  }
+
   virtual const Generator& getDefaultMPSGenerator() const {
     AT_ERROR("Cannot get default MPS generator without MPS backend.");
   }
@@ -35,6 +39,10 @@ struct TORCH_API MPSHooksInterface {
   virtual Allocator* getMPSDeviceAllocator() const {
     AT_ERROR("MPSDeviceAllocator requires MPS.");
   }
+
+  virtual void deviceSynchronize() const {
+    AT_ERROR("Cannot synchronize MPS device without MPS backend.");
+  }
 };
 
 struct TORCH_API MPSHooksArgs {};
diff --git a/aten/src/ATen/mps/MPSDevice.h b/aten/src/ATen/mps/MPSDevice.h
index 0426f546bb3..1890d6050d9 100644
--- a/aten/src/ATen/mps/MPSDevice.h
+++ b/aten/src/ATen/mps/MPSDevice.h
@@ -79,7 +79,7 @@ class TORCH_API MPSDevice {
 
 TORCH_API bool is_available();
 TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
-
+TORCH_API void device_synchronize();
 TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
 
 } // namespace mps
diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm
index d9306f25ffb..0576f9bb789 100644
--- a/aten/src/ATen/mps/MPSDevice.mm
+++ b/aten/src/ATen/mps/MPSDevice.mm
@@ -3,6 +3,7 @@
 #include
 #include
+#include <ATen/mps/MPSStream.h>
 #include
 #include
 
@@ -122,5 +123,9 @@ bool is_macos_13_or_newer(MacOSVersion version) {
   return MPSDevice::getInstance()->isMacOS13Plus(version);
 }
 
+void device_synchronize() {
+  getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
+}
+
 } // namespace mps
 } // namespace at
diff --git a/aten/src/ATen/mps/MPSHooks.cpp b/aten/src/ATen/mps/MPSHooks.cpp
index 5fde8f3843f..f2b0ea6962e 100644
--- a/aten/src/ATen/mps/MPSHooks.cpp
+++ b/aten/src/ATen/mps/MPSHooks.cpp
@@ -16,6 +16,10 @@ bool MPSHooks::hasMPS() const {
   return at::mps::is_available();
 }
 
+bool MPSHooks::isOnMacOS13orNewer() const {
+  return at::mps::is_macos_13_or_newer();
+}
+
 Allocator* MPSHooks::getMPSDeviceAllocator() const {
   return at::mps::GetMPSAllocator();
 }
@@ -24,6 +28,10 @@ const Generator& MPSHooks::getDefaultMPSGenerator() const {
   return at::mps::detail::getDefaultMPSGenerator();
 }
 
+void MPSHooks::deviceSynchronize() const {
+  at::mps::device_synchronize();
+}
+
 using at::MPSHooksRegistry;
 using at::RegistererMPSHooksRegistry;
diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h
index 2bef3eac426..dfc74936285 100644
--- a/aten/src/ATen/mps/MPSHooks.h
+++ b/aten/src/ATen/mps/MPSHooks.h
@@ -13,8 +13,10 @@ struct MPSHooks : public at::MPSHooksInterface {
   MPSHooks(at::MPSHooksArgs) {}
   void initMPS() const override;
   bool hasMPS() const override;
+  bool isOnMacOS13orNewer() const override;
   Allocator* getMPSDeviceAllocator() const override;
   const Generator& getDefaultMPSGenerator() const override;
+  void deviceSynchronize() const override;
 };
 
 }} // at::mps
diff --git a/build_variables.bzl b/build_variables.bzl
index f16042a814b..59e21c36b54 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -822,6 +822,7 @@ libtorch_python_core_sources = [
     "torch/csrc/dynamo/guards.cpp",
     "torch/csrc/dynamo/init.cpp",
     "torch/csrc/functorch/init.cpp",
+    "torch/csrc/mps/Module.cpp",
     "torch/csrc/jit/backends/backend_init.cpp",
     "torch/csrc/jit/python/init.cpp",
     "torch/csrc/jit/passes/onnx.cpp",
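A rough illustration (not part of the patch) of what the hooks indirection above buys: the Python bindings added later in this diff call through at::detail::getMPSHooks(), so on a build compiled without USE_MPS the stub methods raise a clean error instead of crashing. Hypothetical session on a CPU-only build:

    import torch

    try:
        # routes through MPSHooksInterface::deviceSynchronize(); with no MPS
        # backend registered, the base-class stub raises the AT_ERROR above
        torch._C._mps_synchronize()
    except RuntimeError as err:
        print(err)  # "Cannot synchronize MPS device without MPS backend."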
"torch/csrc/jit/passes/onnx.cpp", diff --git a/docs/source/index.rst b/docs/source/index.rst index a8ce02630d5..59c363d23a0 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -81,6 +81,7 @@ Features described in this documentation are classified by release status: torch.autograd torch.library cuda + mps torch.backends torch.distributed torch.distributed.algorithms.join diff --git a/docs/source/mps.rst b/docs/source/mps.rst new file mode 100644 index 00000000000..9a5c0df5110 --- /dev/null +++ b/docs/source/mps.rst @@ -0,0 +1,14 @@ +torch.mps +=================================== +.. automodule:: torch.mps +.. currentmodule:: torch.mps + +.. autosummary:: + :toctree: generated + :nosignatures: + + synchronize + get_rng_state + set_rng_state + manual_seed + seed \ No newline at end of file diff --git a/test/test_mps.py b/test/test_mps.py index 34ecb2ee608..2ee068cf573 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -5972,6 +5972,45 @@ class TestNLLLoss(TestCase): mps_x = torch.randn(5, device='mps', generator=g_mps) self.assertEqual(mps_x, mps_y) + def test_default_mps_generator(self): + # manual seeding on the "default" MPS generator using + # the global torch.manual_seed() + torch.manual_seed(230) + mps_x = torch.randn(5, device='mps') + # manual seeding using torch.mps.manual_seed() + # which should set the "default" MPS generator + # like the global torch.manual_seed() + torch.mps.manual_seed(230) + mps_y = torch.randn(5, device='mps') + # seed values were the same, so the random tensor contents should match + self.assertEqual(mps_x, mps_y) + + # save the default generator's state to restore it later + g_state = torch.mps.get_rng_state() + + # generate random numbers without seeding + mps_x = torch.randn(5, device='mps') + # in this case, the random results must differ from the last generated random results + self.assertNotEqual(mps_x, mps_y) + + # restore the previously saved state, and the results should match again + torch.mps.set_rng_state(g_state) + mps_x = torch.randn(5, device='mps') + self.assertEqual(mps_x, mps_y) + + def test_device_synchronize(self): + # just running some ops each followed by a synchronize to wait for + # MPS stream to finish running each of them + net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\ + .to(device='mps', dtype=torch.float) + + x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True) + torch.mps.synchronize() + x = net1(x) + torch.mps.synchronize() + x.backward(torch.randn_like(x)) + torch.mps.synchronize() + # Test random_.to and random_.from def test_random(self): def helper(shape, low, high, dtype=torch.int32): diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 28b8d8820c5..9355dbda48b 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -903,8 +903,6 @@ def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: T def _disabled_torch_dispatch_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_dispatch_function def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ... def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ... -def _is_mps_available() -> _bool: ... -def _is_mps_on_macos_13_or_newer() -> _bool: ... 
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 28b8d8820c5..9355dbda48b 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -903,8 +903,6 @@ def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: T
 def _disabled_torch_dispatch_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_dispatch_function
 def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ...
 def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ...
-def _is_mps_available() -> _bool: ...
-def _is_mps_on_macos_13_or_newer() -> _bool: ...
 class _LinalgBackend:
     Default: _LinalgBackend
     Cusolver: _LinalgBackend
@@ -1200,6 +1198,12 @@ class _TensorBase(metaclass=_TensorMeta):
 # Defined in torch/csrc/multiprocessing/init.cpp
 def _multiprocessing_init() -> None: ...
 
+# Defined in torch/csrc/mps/Module.cpp
+def _mps_synchronize() -> None: ...
+def _mps_get_default_generator() -> Generator: ...
+def _is_mps_available() -> _bool: ...
+def _is_mps_on_macos_13_or_newer() -> _bool: ...
+
 # Defined in torch/csrc/cuda/Module.cpp
 def _cuda_getCurrentStream(device: _int) -> Tuple: ...
 def _cuda_getCurrentRawStream(device: _int) -> _int: ...
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 1d9e295c60e..a5ef894e41b 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -60,6 +60,7 @@
 #include
 #include
 #include
+#include <torch/csrc/mps/Module.h>
 #include
 #include
 #include
@@ -87,10 +88,6 @@
 #endif
 #endif
 
-#if defined(USE_MPS)
-#include <ATen/mps/MPSDevice.h>
-#endif
-
 #if defined(USE_VALGRIND)
 #include
 #endif
@@ -1271,6 +1268,7 @@ PyObject* initModule() {
   THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
   THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
   THPUtils_addPyMethodDefs(methods, torch::multiprocessing::python_functions());
+  THPUtils_addPyMethodDefs(methods, torch::mps::python_functions());
 #ifdef USE_CUDA
   THPUtils_addPyMethodDefs(methods, THCPModule_methods());
 #endif
@@ -1593,15 +1591,6 @@ Call this whenever a new thread is created in order to propagate values from
   ASSERT_TRUE(set_module_attr("has_cuda", has_cuda));
   ASSERT_TRUE(set_module_attr("has_mps", has_mps));
 
-  py_module.def("_is_mps_available", []() { return at::hasMPS(); });
-  py_module.def("_is_mps_on_macos_13_or_newer", []() {
-#ifdef USE_MPS
-    return at::mps::is_macos_13_or_newer();
-#else
-    return false;
-#endif
-  });
-
   ASSERT_TRUE(
       set_module_attr("has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False));
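Note that _is_mps_available and _is_mps_on_macos_13_or_newer are not removed, only rehomed: they are deleted from torch/csrc/Module.cpp and re-registered under the same names by the new torch/csrc/mps/Module.cpp below, so existing callers of the private torch._C API keep working. A quick sanity sketch (hypothetical check, assuming the patched build):

    import torch

    # the method defs merged via torch::mps::python_functions() land on torch._C
    for name in ("_mps_synchronize", "_mps_is_in_bad_fork", "_is_mps_available",
                 "_is_mps_on_macos_13_or_newer", "_mps_get_default_generator"):
        assert hasattr(torch._C, name)

One behavioral nuance of the move: the old _is_mps_on_macos_13_or_newer returned False on non-MPS builds via the #ifdef, whereas the new binding calls the MPSHooksInterface stub, which raises on such builds.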
diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp
new file mode 100644
index 00000000000..244aac3a394
--- /dev/null
+++ b/torch/csrc/mps/Module.cpp
@@ -0,0 +1,102 @@
+#include <ATen/detail/MPSHooksInterface.h>
+#include <c10/util/CallOnce.h>
+#include <torch/csrc/Exceptions.h>
+#include <torch/csrc/Generator.h>
+#include <torch/csrc/python_headers.h>
+
+// pthread.h is included for tracking bad forks
+#ifndef WIN32
+#include <pthread.h>
+#endif
+
+namespace torch {
+namespace mps {
+
+namespace {
+// True for children forked after mps init
+static bool in_bad_fork = false;
+
+// Called in the forked child if mps has already been initialized
+static void forked_mps_child() {
+  in_bad_fork = true;
+}
+
+// Should be called before the first mps call.
+static void track_bad_mps_fork() {
+#ifndef WIN32
+  static c10::once_flag flag;
+  c10::call_once(
+      flag, [] { pthread_atfork(nullptr, nullptr, forked_mps_child); });
+#endif
+}
+} // namespace
+
+static PyObject* MPSModule_isInBadFork(PyObject* self, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  return PyBool_FromLong(in_bad_fork);
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_getDefaultMPSGenerator(
+    PyObject* _unused,
+    PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  track_bad_mps_fork();
+  return THPGenerator_initDefaultGenerator(
+      at::detail::getMPSHooks().getDefaultMPSGenerator());
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  track_bad_mps_fork();
+  if (at::detail::getMPSHooks().hasMPS()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_isMacOS13orNewer(
+    PyObject* _unused,
+    PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  if (at::detail::getMPSHooks().isOnMacOS13orNewer()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_synchronize(PyObject* _unused, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  at::detail::getMPSHooks().deviceSynchronize();
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+// NOLINTNEXTLINE(modernize-avoid-c-arrays,
+// cppcoreguidelines-avoid-non-const-global-variables,
+// cppcoreguidelines-avoid-c-arrays)
+static struct PyMethodDef _MPSModule_methods[] = {
+    {"_mps_synchronize", MPSModule_synchronize, METH_NOARGS, nullptr},
+    {"_mps_is_in_bad_fork", MPSModule_isInBadFork, METH_NOARGS, nullptr},
+    {"_is_mps_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
+    {"_is_mps_on_macos_13_or_newer",
+     MPSModule_isMacOS13orNewer,
+     METH_NOARGS,
+     nullptr},
+    {"_mps_get_default_generator",
+     MPSModule_getDefaultMPSGenerator,
+     METH_NOARGS,
+     nullptr},
+    {nullptr}};
+
+PyMethodDef* python_functions() {
+  return _MPSModule_methods;
+}
+
+} // namespace mps
+} // namespace torch
diff --git a/torch/csrc/mps/Module.h b/torch/csrc/mps/Module.h
new file mode 100644
index 00000000000..3759d36d738
--- /dev/null
+++ b/torch/csrc/mps/Module.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include <torch/csrc/python_headers.h>
+
+namespace torch {
+namespace mps {
+
+PyMethodDef* python_functions();
+
+} // namespace mps
+} // namespace torch
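The fork bookkeeping above mirrors what torch.cuda does: pthread_atfork flags any child forked after MPS initialization as "bad", and torch/random.py (patched below) consults that flag before re-seeding. A sketch of the behavior using the torch.mps.manual_seed wrapper added below (hypothetical session; assumes an MPS-enabled macOS build, and os.fork() is POSIX-only):

    import os
    import torch

    # touching the default generator registers the atfork handler
    # via track_bad_mps_fork()
    torch.mps.manual_seed(0)

    pid = os.fork()
    if pid == 0:
        # forked_mps_child() ran in this child, so it is flagged and
        # torch.manual_seed() will skip re-seeding the MPS generator
        assert torch._C._mps_is_in_bad_fork()
        os._exit(0)
    os.waitpid(pid, 0)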
diff --git a/torch/mps/__init__.py b/torch/mps/__init__.py
new file mode 100644
index 00000000000..42e98c9030d
--- /dev/null
+++ b/torch/mps/__init__.py
@@ -0,0 +1,54 @@
+r"""
+This package provides an interface for accessing the MPS backend in Python.
+"""
+import torch
+from .. import Tensor
+
+_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False)
+_default_mps_generator: torch._C.Generator = None  # type: ignore[assignment]
+
+# local helper function (not public or exported)
+def _get_default_mps_generator() -> torch._C.Generator:
+    global _default_mps_generator
+    if _default_mps_generator is None:
+        _default_mps_generator = torch._C._mps_get_default_generator()
+    return _default_mps_generator
+
+def synchronize() -> None:
+    r"""Waits for all kernels in all streams on an MPS device to complete."""
+    return torch._C._mps_synchronize()
+
+def get_rng_state() -> Tensor:
+    r"""Returns the random number generator state as a ByteTensor."""
+    return _get_default_mps_generator().get_state()
+
+def set_rng_state(new_state: Tensor) -> None:
+    r"""Sets the random number generator state.
+
+    Args:
+        new_state (torch.ByteTensor): The desired state
+    """
+    new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
+    _get_default_mps_generator().set_state(new_state_copy)
+
+def manual_seed(seed: int) -> None:
+    r"""Sets the seed for generating random numbers.
+
+    Args:
+        seed (int): The desired seed.
+    """
+    # torch.mps.manual_seed() can be called from the global
+    # torch.manual_seed() in torch/random.py, so we need to make
+    # sure MPS is available (otherwise we just return without
+    # raising an error)
+    if not torch.has_mps:
+        return
+    seed = int(seed)
+    _get_default_mps_generator().manual_seed(seed)
+
+def seed() -> None:
+    r"""Sets the seed for generating random numbers to a random number."""
+    _get_default_mps_generator().seed()
+
+__all__ = [
+    'get_rng_state', 'manual_seed', 'seed', 'set_rng_state', 'synchronize']
diff --git a/torch/random.py b/torch/random.py
index f5156bf4873..e4795907a3a 100644
--- a/torch/random.py
+++ b/torch/random.py
@@ -39,6 +39,10 @@ def manual_seed(seed) -> torch._C.Generator:
         if not torch.cuda._is_in_bad_fork():
             torch.cuda.manual_seed_all(seed)
 
+    import torch.mps
+    if not torch.mps._is_in_bad_fork():
+        torch.mps.manual_seed(seed)
+
     return default_generator.manual_seed(seed)
 
 
@@ -52,6 +56,10 @@ def seed() -> int:
         if not torch.cuda._is_in_bad_fork():
             torch.cuda.manual_seed_all(seed)
 
+    import torch.mps
+    if not torch.mps._is_in_bad_fork():
+        torch.mps.manual_seed(seed)
+
     return seed
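With the torch/random.py hooks in place, one global seed now covers the MPS generator as well, mirroring the existing CUDA path; e.g., on an MPS-enabled build:

    import torch

    torch.manual_seed(123)      # seeds CPU, CUDA, and now MPS in one call
    x = torch.randn(3, device="mps")
    torch.mps.manual_seed(123)  # equivalent for the MPS generator alone
    y = torch.randn(3, device="mps")
    assert torch.equal(x, y)    # matches, as in test_default_mps_generator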