From 8461e7ed9e68d1b7274e69d5396ff343ac120568 Mon Sep 17 00:00:00 2001 From: egienvalue Date: Thu, 25 Apr 2024 11:48:49 -0700 Subject: [PATCH] Add test_cpp_extensions tests for stream_and_event and mtia_backend (#123614) Test the generic torch.Stream/Event with fake device guard and hooks. Since we added a fake device backend, it is mutually exclusive to other backends. Tests will be skipped if TEST_CUDA or TEST_ROCM is true. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123614 Approved by: https://github.com/albanD ghstack dependencies: #123611, #123612 --- test/cpp_extensions/mtia_extension.cpp | 219 +++++++++++++++++++ test/run_test.py | 2 + test/test_cpp_extensions_mtia_backend.py | 155 +++++++++++++ test/test_cpp_extensions_stream_and_event.py | 109 +++++++++ tools/testing/modulefinder_determinator.py | 2 + 5 files changed, 487 insertions(+) create mode 100644 test/cpp_extensions/mtia_extension.cpp create mode 100644 test/test_cpp_extensions_mtia_backend.py create mode 100644 test/test_cpp_extensions_stream_and_event.py diff --git a/test/cpp_extensions/mtia_extension.cpp b/test/cpp_extensions/mtia_extension.cpp new file mode 100644 index 00000000000..3b02d3968e4 --- /dev/null +++ b/test/cpp_extensions/mtia_extension.cpp @@ -0,0 +1,219 @@ +#include +#include +#include +#include +#include +#include +#include +namespace torch::mtia { + +constexpr c10::DeviceType kMTIADeviceType = c10::DeviceType::MTIA; +constexpr c10::DeviceIndex kMTIADeviceCount = 2; +static thread_local c10::DeviceIndex current_device = 0; +static thread_local std::array current_streams = +    {c10::Stream::unpack3(0, 0, c10::DeviceType::MTIA), +     c10::Stream::unpack3(0, 1, c10::DeviceType::MTIA)}; +static int64_t stream_id_gen = 1; +static int64_t event_id_gen = 1; +static std::array default_streams = { +    c10::Stream::unpack3(0, 0, c10::DeviceType::MTIA), +    c10::Stream::unpack3(0, 1, c10::DeviceType::MTIA)}; +struct MTIAGuardImpl final : public c10::impl::DeviceGuardImplInterface { 
+ MTIAGuardImpl() = default; + explicit MTIAGuardImpl(c10::DeviceType t) { + TORCH_INTERNAL_ASSERT(t == kMTIADeviceType); + } + c10::DeviceType type() const override { + return kMTIADeviceType; + } + c10::Device exchangeDevice(c10::Device d) const override { + c10::Device old_device = getDevice(); + if (old_device.index() != d.index()) { + setDevice(d); + } + return old_device; + } + c10::Device getDevice() const override { + return c10::Device(kMTIADeviceType, current_device); + } + + void setDevice(c10::Device d) const override { + c10::Device current_device = getDevice(); + if (current_device.index() != d.index()) { + current_device = d; + } + } + void uncheckedSetDevice(c10::Device d) const noexcept override { + (void)d; + } + c10::Stream getStream(c10::Device d) const noexcept override { + return current_streams[d.index()]; + } + c10::Stream getNewStream(c10::Device d, int priority = 0) const override { + (void)priority; + return c10::Stream::unpack3(stream_id_gen++, d.index(), d.type()); + } + c10::Stream getDefaultStream(c10::Device d) const override { + return default_streams[d.index()]; + } + c10::Stream getStreamFromGlobalPool( + c10::Device d, + bool isHighPriority = false) const override { + return c10::Stream::unpack3(stream_id_gen++, d.index(), d.type()); + } + // NB: These do NOT set the current device + c10::Stream exchangeStream(c10::Stream s) const noexcept override { + c10::Stream old_stream = getStream(s.device()); + return old_stream; + } + c10::DeviceIndex deviceCount() const noexcept override { + return kMTIADeviceCount; + } + + void destroyEvent(void* event, const c10::DeviceIndex device_index) + const noexcept override { + (void)device_index; + } + + void record( + void** event, + const c10::Stream& stream, + const c10::DeviceIndex device_index, + const c10::EventFlag flag) const override { + TORCH_CHECK( + device_index == -1 || device_index == stream.device_index(), + "Event device index ", + device_index, + " does not match recording 
stream's device index ", + stream.device_index(), + "."); + + const auto orig_device = getDevice(); + + setDevice(stream.device()); + + if (*event == nullptr) { + *event = reinterpret_cast(event_id_gen++); + } + setDevice(orig_device); + } + + void block(void* event, const c10::Stream& stream) const override { + (void)event; + (void)stream; + } + + // May be called from any device + bool queryEvent(void* event) const override { + (void)event; + return true; + } + + // Stream-related functions + bool queryStream(const c10::Stream& stream) const override { + (void)stream; + return true; + } + + void synchronizeStream(const c10::Stream& stream) const override { + (void)stream; + } + + void recordDataPtrOnStream( + const c10::DataPtr& data_ptr, + const c10::Stream& stream) const override { + (void)data_ptr; + (void)stream; + } + + double elapsedTime(void* event1, void* event2) const override { + uint64_t elapsed_time = 1e6; + return (double)(elapsed_time / 1e6); + } + + void synchronizeEvent(void* event) const override { + (void)event; + } +}; + +struct MTIAHooks : public at::MTIAHooksInterface { + explicit MTIAHooks(at::MTIAHooksArgs) {} + void initMTIA() const override {} + + bool hasMTIA() const override { + return true; + } + + c10::DeviceIndex deviceCount() const override { + torch::utils::device_lazy_init(at::kMTIA); + return c10::DeviceIndex(2); + } + + void deviceSynchronize(c10::DeviceIndex device_index) const override { + torch::utils::device_lazy_init(at::kMTIA); + (void)device_index; + } + + std::string showConfig() const override { + return "None config"; + } + + c10::DeviceIndex exchangeDevice(c10::DeviceIndex device) const override { + torch::utils::device_lazy_init(at::kMTIA); + auto orig_device = current_device; + if (current_device != device) { + current_device = device; + } + return orig_device; + } + + c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device) const override { + torch::utils::device_lazy_init(at::kMTIA); + + auto orig_device = 
current_device; + if (current_device != device) { + current_device = device; + } + return orig_device; + } + + c10::Stream getDefaultStream(c10::DeviceIndex device) const override { + torch::utils::device_lazy_init(at::kMTIA); + + return default_streams[device]; + } + + c10::Stream getCurrentStream(c10::DeviceIndex device) const override { + torch::utils::device_lazy_init(at::kMTIA); + + return current_streams[device]; + } + + void setCurrentStream(const c10::Stream& stream) const override { + torch::utils::device_lazy_init(at::kMTIA); + + current_streams[stream.device_index()] = stream; + } + + c10::DeviceIndex getCurrentDevice() const override { + torch::utils::device_lazy_init(at::kMTIA); + + return current_device; + } + + void setCurrentDevice(c10::DeviceIndex device) const override { + torch::utils::device_lazy_init(at::kMTIA); + + if (current_device != device) { + current_device = device; + } + } +}; + +using at::MTIAHooksRegistry; +using at::RegistererMTIAHooksRegistry; + +REGISTER_MTIA_HOOKS(MTIAHooks); +C10_REGISTER_GUARD_IMPL(MTIA, MTIAGuardImpl); + +} // namespace torch::mtia diff --git a/test/run_test.py b/test/run_test.py index 3626d31fc28..516dbc753ff 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -191,6 +191,8 @@ XPU_TEST = [ RUN_PARALLEL_BLOCKLIST = [ "test_cpp_extensions_jit", "test_cpp_extensions_open_device_registration", + "test_cpp_extensions_stream_and_event", + "test_cpp_extensions_mtia_backend", "test_jit_disabled", "test_mobile_optimizer", "test_multiprocessing", diff --git a/test/test_cpp_extensions_mtia_backend.py b/test/test_cpp_extensions_mtia_backend.py new file mode 100644 index 00000000000..f1613dcf7da --- /dev/null +++ b/test/test_cpp_extensions_mtia_backend.py @@ -0,0 +1,155 @@ +# Owner(s): ["module: mtia"] + +import os +import shutil +import sys +import tempfile +import unittest + +import torch +import torch.testing._internal.common_utils as common +import torch.utils.cpp_extension +from 
torch.testing._internal.common_utils import ( + IS_ARM64, + IS_LINUX, + skipIfTorchDynamo, + TEST_CUDA, + TEST_PRIVATEUSE1, +) +from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME + + +# define TEST_ROCM before changing TEST_CUDA +TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None +TEST_CUDA = TEST_CUDA and CUDA_HOME is not None + + +def remove_build_path(): + if sys.platform == "win32": + # Not wiping extensions build folder because Windows + return + default_build_root = torch.utils.cpp_extension.get_default_build_root() + if os.path.exists(default_build_root): + shutil.rmtree(default_build_root, ignore_errors=True) + + +@unittest.skipIf( + IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_PRIVATEUSE1 or TEST_ROCM, + "Only on linux platform and mutual exclusive to other backends", +) +@torch.testing._internal.common_utils.markDynamoStrictTest +class TestCppExtensionMTIABackend(common.TestCase): + """Tests MTIA backend with C++ extensions.""" + + module = None + + def setUp(self): + super().setUp() + # cpp extensions use relative paths. Those paths are relative to + # this file, so we'll change the working directory temporarily + self.old_working_dir = os.getcwd() + os.chdir(os.path.dirname(os.path.abspath(__file__))) + + def tearDown(self): + super().tearDown() + # return the working directory (see setUp) + os.chdir(self.old_working_dir) + + @classmethod + def tearDownClass(cls): + remove_build_path() + + @classmethod + def setUpClass(cls): + remove_build_path() + build_dir = tempfile.mkdtemp() + # Load the fake device guard impl. 
+ cls.module = torch.utils.cpp_extension.load( + name="mtia_extension", + sources=["cpp_extensions/mtia_extension.cpp"], + build_directory=build_dir, + extra_include_paths=[ + "cpp_extensions", + "path / with spaces in it", + "path with quote'", + ], + is_python_module=False, + verbose=True, + ) + + @skipIfTorchDynamo("Not a TorchDynamo suitable test") + def test_get_device_module(self): + device = torch.device("mtia:0") + default_stream = torch.get_device_module(device).current_stream() + self.assertEqual( + default_stream.device_type, int(torch._C._autograd.DeviceType.MTIA) + ) + print(torch._C.Stream.__mro__) + print(torch.cuda.Stream.__mro__) + + @skipIfTorchDynamo("Not a TorchDynamo suitable test") + def test_stream_basic(self): + default_stream = torch.mtia.current_stream() + user_stream = torch.mtia.Stream() + self.assertEqual(torch.mtia.current_stream(), default_stream) + self.assertNotEqual(default_stream, user_stream) + # Check mtia_extension.cpp, default stream id starts from 0. 
+ self.assertEqual(default_stream.stream_id, 0) + self.assertNotEqual(user_stream.stream_id, 0) + with torch.mtia.stream(user_stream): + self.assertEqual(torch.mtia.current_stream(), user_stream) + self.assertTrue(user_stream.query()) + default_stream.synchronize() + self.assertTrue(default_stream.query()) + + @skipIfTorchDynamo("Not a TorchDynamo suitable test") + def test_stream_context(self): + mtia_stream_0 = torch.mtia.Stream(device="mtia:0") + mtia_stream_1 = torch.mtia.Stream(device="mtia:0") + print(mtia_stream_0) + print(mtia_stream_1) + with torch.mtia.stream(mtia_stream_0): + current_stream = torch.mtia.current_stream() + msg = f"current_stream {current_stream} should be {mtia_stream_0}" + self.assertTrue(current_stream == mtia_stream_0, msg=msg) + + with torch.mtia.stream(mtia_stream_1): + current_stream = torch.mtia.current_stream() + msg = f"current_stream {current_stream} should be {mtia_stream_1}" + self.assertTrue(current_stream == mtia_stream_1, msg=msg) + + @skipIfTorchDynamo("Not a TorchDynamo suitable test") + def test_stream_context_different_device(self): + device_0 = torch.device("mtia:0") + device_1 = torch.device("mtia:1") + mtia_stream_0 = torch.mtia.Stream(device=device_0) + mtia_stream_1 = torch.mtia.Stream(device=device_1) + print(mtia_stream_0) + print(mtia_stream_1) + orig_current_device = torch.mtia.current_device() + with torch.mtia.stream(mtia_stream_0): + current_stream = torch.mtia.current_stream() + self.assertTrue(torch.mtia.current_device() == device_0.index) + msg = f"current_stream {current_stream} should be {mtia_stream_0}" + self.assertTrue(current_stream == mtia_stream_0, msg=msg) + self.assertTrue(torch.mtia.current_device() == orig_current_device) + with torch.mtia.stream(mtia_stream_1): + current_stream = torch.mtia.current_stream() + self.assertTrue(torch.mtia.current_device() == device_1.index) + msg = f"current_stream {current_stream} should be {mtia_stream_1}" + self.assertTrue(current_stream == mtia_stream_1, 
msg=msg) + self.assertTrue(torch.mtia.current_device() == orig_current_device) + + @skipIfTorchDynamo("Not a TorchDynamo suitable test") + def test_device_context(self): + device_0 = torch.device("mtia:0") + device_1 = torch.device("mtia:1") + with torch.mtia.device(device_0): + self.assertTrue(torch.mtia.current_device() == device_0.index) + + with torch.mtia.device(device_1): + self.assertTrue(torch.mtia.current_device() == device_1.index) + + +if __name__ == "__main__": + common.run_tests() diff --git a/test/test_cpp_extensions_stream_and_event.py b/test/test_cpp_extensions_stream_and_event.py new file mode 100644 index 00000000000..728ac5f9809 --- /dev/null +++ b/test/test_cpp_extensions_stream_and_event.py @@ -0,0 +1,109 @@ +# Owner(s): ["module: mtia"] + +import os +import shutil +import sys +import tempfile +import unittest + +import torch +import torch.testing._internal.common_utils as common +import torch.utils.cpp_extension +from torch.testing._internal.common_utils import ( + IS_ARM64, + IS_LINUX, + skipIfTorchDynamo, + TEST_CUDA, + TEST_PRIVATEUSE1, +) +from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME + + +# define TEST_ROCM before changing TEST_CUDA +TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None +TEST_CUDA = TEST_CUDA and CUDA_HOME is not None + + +def remove_build_path(): + if sys.platform == "win32": + # Not wiping extensions build folder because Windows + return + default_build_root = torch.utils.cpp_extension.get_default_build_root() + if os.path.exists(default_build_root): + shutil.rmtree(default_build_root, ignore_errors=True) + + +# Since we use a fake MTIA device backend to test generic Stream/Event, device backends are mutual exclusive to each other. 
+# The test will be skipped if any of the following conditions are met: +@unittest.skipIf( + IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_PRIVATEUSE1 or TEST_ROCM, + "Only on linux platform and mutual exclusive to other backends", +) +@torch.testing._internal.common_utils.markDynamoStrictTest +class TestCppExtensionStreamAndEvent(common.TestCase): + """Tests Stream and Event with C++ extensions.""" + + module = None + + def setUp(self): + super().setUp() + # cpp extensions use relative paths. Those paths are relative to + # this file, so we'll change the working directory temporarily + self.old_working_dir = os.getcwd() + os.chdir(os.path.dirname(os.path.abspath(__file__))) + + def tearDown(self): + super().tearDown() + # return the working directory (see setUp) + os.chdir(self.old_working_dir) + + @classmethod + def tearDownClass(cls): + remove_build_path() + + @classmethod + def setUpClass(cls): + remove_build_path() + build_dir = tempfile.mkdtemp() + # Load the fake device guard impl. 
+ src = f"{os.path.abspath(os.path.dirname(__file__))}/cpp_extensions/mtia_extension.cpp" + cls.module = torch.utils.cpp_extension.load( + name="mtia_extension", + sources=[src], + build_directory=build_dir, + extra_include_paths=[ + "cpp_extensions", + "path / with spaces in it", + "path with quote'", + ], + is_python_module=False, + verbose=True, + ) + + @skipIfTorchDynamo("Not a TorchDynamo suitable test") + def test_stream_event(self): + s = torch.Stream() + self.assertTrue(s.device_type, int(torch._C._autograd.DeviceType.MTIA)) + e = torch.Event() + self.assertTrue(e.device.type, "mtia") + # Should be nullptr by default + self.assertTrue(e.event_id == 0) + s.record_event(event=e) + print(f"recorded event 1: {e}") + self.assertTrue(e.event_id != 0) + e2 = s.record_event() + print(f"recorded event 2: {e2}") + self.assertTrue(e2.event_id != 0) + self.assertTrue(e2.event_id != e.event_id) + e.synchronize() + e2.synchronize() + time_elapsed = e.elapsed_time(e2) + print(f"time elapsed between e1 and e2: {time_elapsed}") + old_event_id = e.event_id + e.record(stream=s) + print(f"recorded event 1: {e}") + self.assertTrue(e.event_id == old_event_id) + + +if __name__ == "__main__": + common.run_tests() diff --git a/tools/testing/modulefinder_determinator.py b/tools/testing/modulefinder_determinator.py index ce55fdb4245..ba58d75c57f 100644 --- a/tools/testing/modulefinder_determinator.py +++ b/tools/testing/modulefinder_determinator.py @@ -21,6 +21,8 @@ TARGET_DET_LIST = [ "test_cpp_extensions_aot_no_ninja", "test_cpp_extensions_jit", "test_cpp_extensions_open_device_registration", + "test_cpp_extensions_stream_and_event", + "test_cpp_extensions_mtia_backend", "test_cuda", "test_cuda_primary_ctx", "test_dataloader",