From d70b7029c896de89f5e38c0dedcc03db079d15db Mon Sep 17 00:00:00 2001 From: Hyunho Yeo Date: Thu, 28 Nov 2024 02:24:19 +0000 Subject: [PATCH] [MTIA] Support torch.mtia.empty_cache() (#141533) Summary: As title Test Plan: Passed a local unit test: `buck2 test //mtia/host_runtime/torch_mtia/tests:test_torch_mtia_api` https://www.internalfb.com/intern/testinfra/testrun/4785074861101240 Reviewed By: nautsimon Differential Revision: D66481778 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141533 Approved by: https://github.com/nautsimon --- aten/src/ATen/detail/MTIAHooksInterface.h | 5 +++++ docs/source/mtia.rst | 1 + torch/_C/__init__.pyi.in | 2 +- torch/csrc/mtia/Module.cpp | 2 ++ torch/mtia/__init__.py | 6 ++++++ 5 files changed, 15 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/detail/MTIAHooksInterface.h b/aten/src/ATen/detail/MTIAHooksInterface.h index 035680e9a33..bcb26320eed 100644 --- a/aten/src/ATen/detail/MTIAHooksInterface.h +++ b/aten/src/ATen/detail/MTIAHooksInterface.h @@ -109,6 +109,11 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { FAIL_MTIAHOOKS_FUNC(__func__); return nullptr; } + + virtual void emptyCache() const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + }; struct TORCH_API MTIAHooksArgs {}; diff --git a/docs/source/mtia.rst b/docs/source/mtia.rst index 548c6463d08..c25972d003d 100644 --- a/docs/source/mtia.rst +++ b/docs/source/mtia.rst @@ -20,6 +20,7 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined is_initialized memory_stats get_device_capability + empty_cache set_device set_stream stream diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 3cdf304e632..f858020f6e1 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1789,7 +1789,7 @@ def _mtia_setCurrentStream(stream: Stream) -> None: ... def _mtia_getDefaultStream(device: _int) -> Stream: ... def _mtia_memoryStats(device: _int) -> Dict[str, Any]: ... def _mtia_getDeviceCapability(device: _int) -> Tuple[_int, _int]: ... - +def _mtia_emptyCache() -> None: ... # Defined in torch/csrc/mps/Module.cpp def _mps_deviceSynchronize() -> None: ... diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp index 37624b3737d..d77dc3b95a2 100644 --- a/torch/csrc/mtia/Module.cpp +++ b/torch/csrc/mtia/Module.cpp @@ -85,6 +85,8 @@ void initModule(PyObject* module) { at::detail::getMTIAHooks().getDeviceCapability(device_index); return py::reinterpret_steal(raw_pyobject); }); + + m.def("_mtia_emptyCache", []() { at::detail::getMTIAHooks().emptyCache(); }); } } // namespace torch::mtia diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index 36cc6dab0c0..6ab681a0401 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -175,6 +175,11 @@ def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int] return torch._C._mtia_getDeviceCapability(_get_device_index(device, optional=True)) +def empty_cache() -> None: + r"""Empty the MTIA device cache.""" + return torch._C._mtia_emptyCache() + + def set_stream(stream: Stream): r"""Set the current stream.This is a wrapper API to set the stream. Usage of this function is discouraged in favor of the ``stream`` @@ -333,6 +338,7 @@ __all__ = [ "default_stream", "memory_stats", "get_device_capability", + "empty_cache", "set_device", "set_stream", "stream",