diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index 0f4bf91edc2..b1c99145311 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -405,6 +405,22 @@ class TestWithNCCL(MultiProcessTestCase): assert output.eq(expect).all() assert output.completed + @skip_if_lt_x_gpu(2) + def test_wait_tensor(self) -> None: + self._init_process_group() + + input = torch.full((10, 10), float(self.rank), device=self.device) + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0) + output = torch.ops._c10d_functional.all_reduce( + input, + "avg", + "default", + ) + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1) + torch.ops._c10d_functional.wait_tensor(output) + # `wait_tensor(output)` will pop the work from the work registry immediately + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0) + @skip_if_lt_x_gpu(2) def test_unwaited(self) -> None: # Verify that the process can terminate gracefully @@ -412,11 +428,13 @@ class TestWithNCCL(MultiProcessTestCase): self._init_process_group() input = torch.full((10, 10), float(self.rank), device=self.device) + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0) output = torch.ops._c10d_functional.all_reduce( input, "avg", "default", ) + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1) @skip_if_lt_x_gpu(2) def test_py_work(self) -> None: diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 64a210ed3b6..312a6acccad 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -20,6 +20,7 @@ from unittest import mock, SkipTest import torch import torch.distributed as c10d +import torch.distributed._functional_collectives as _functional_collectives if not c10d.is_available() or not c10d.is_nccl_available(): @@ -3218,6 +3219,86 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase): with self.assertRaisesRegex(TypeError, "Invalid function argument"): c10d.barrier(device_ids=self.rank) + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_unwaited(self) -> None: + # Verify that the process can terminate gracefully + # even with unwaited tensors + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="nccl", rank=self.rank, world_size=self.world_size, store=store + ) + + # Case 1: Run collectives under context manager, and don't call wait on them. + with _functional_collectives.allow_inflight_collective_as_graph_input_ctx(): + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0) + input = torch.full( + (10240, 10240), float(self.rank), device=f"cuda:{self.rank}" + ) + dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True) + # Non-functional collectives run under the context manager is registered in the work registry. + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1) + # Running another collective on the same tensor should still work + dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True) + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2) + + # Case 2: Run collectives not under context manager, and don't call wait on them. + # NOTE: Here we intentionally test memory-stressed case. 
+ self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2) + for _ in range(50000): + input = torch.full( + (1024, 1024), float(self.rank), device=f"cuda:{self.rank}" + ) + dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True) + # The work registry size is unchanged, since non-functional collectives that are not run under + # the context manager are not registered in the work registry. + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_wait_tensor(self) -> None: + # Verify that c10d_functional.wait_tensor() can be invoked on + # the output tensor of a non-functional collective + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="nccl", rank=self.rank, world_size=self.world_size, store=store + ) + + # Case 1: under context manager (i.e. work is registered in registry) + with _functional_collectives.allow_inflight_collective_as_graph_input_ctx(): + input1 = torch.full((10, 10), float(self.rank), device=f"cuda:{self.rank}") + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0) + dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True) + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1) + torch.ops.c10d_functional.wait_tensor(input1) + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0) + + input2 = torch.full((10, 10), float(self.rank), device=f"cuda:{self.rank}") + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0) + work = dist.all_reduce(input2, op=dist.ReduceOp.SUM, async_op=True) + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1) + work.wait() + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0) + self.assertEqual(input1, input2) + + # Case 2: not under context manager (i.e.
work is not registered in registry) + input1 = torch.full((10, 10), float(self.rank), device=f"cuda:{self.rank}") + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0) + dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True) + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0) + # this does not take effect, since the underlying wait_tensor() logic would not + # be able to find the corresponding work object (because it's not registered in registry) + torch.ops.c10d_functional.wait_tensor(input1) + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0) + + input2 = torch.full((10, 10), float(self.rank), device=f"cuda:{self.rank}") + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0) + work = dist.all_reduce(input2, op=dist.ReduceOp.SUM, async_op=True) + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0) + work.wait() + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0) + self.assertEqual(input1, input2) + @requires_nccl() @skip_if_lt_x_gpu(2) @with_dist_debug_levels(levels=["DETAIL"]) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index f59c471a0f9..18c5f844a42 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -1,4 +1,5 @@ # Owner(s): ["module: dynamo"] +import datetime import functools import unittest from unittest.mock import patch @@ -28,6 +29,7 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, requires_cuda, + skipIfRocm, ) from torch.testing._internal.inductor_utils import HAS_GPU @@ -245,6 +247,90 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase): ) self.assertTrue(same(eager_out, inductor_out, tol=0.001)) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @skip_if_lt_x_gpu(2) + @skipIfRocm + def test_eager_async_allreduce_inductor_wait(self): + import torch.distributed as dist + from torch._inductor.utils import run_and_get_code + + def all_reduce_non_functional_eager(x): + y = x * x + work = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True) + assert isinstance(work, torch.distributed.Work) + return work, y + + def all_reduce_wait(work, y): # potentially compiled + if torch.compiler.is_dynamo_compiling(): + torch.ops.c10d_functional.wait_tensor(y) + else: + work.wait(datetime.timedelta(seconds=10)) + # Under compile, if `wait_tensor(y)` above is correctly executed, + # `y`'s data is in its final form and the output of this function will match eager; + # otherwise, `y * y` will run in parallel with `all_reduce(y)` and the output of this function + # will not match eager. + return y * y + + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + x = torch.ones(12800, 12800, device="cuda") + self.rank + self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0) + + # NOTE: We run for 10 iterations each, to ensure that the GPU execution is way behind CPU + # and that `y * y` on CPU side will be issued before `all_reduce(y)` on GPU side is done, + # thus guaranteeing that in the bad case `y * y` on GPU side will run in parallel with `all_reduce(y)` + # thus will produce the wrong result that fails the unit test. 
+ + def _run_loop_collective_wait(x, wait_fn, expected_registry_size): + for _ in range(10): + self.assertEqual( + torch._C._distributed_c10d._get_work_registry_size(), 0 + ) + work, y = all_reduce_non_functional_eager(x) + self.assertEqual( + torch._C._distributed_c10d._get_work_registry_size(), + expected_registry_size, + ) + out = wait_fn(work, y) + self.assertEqual( + torch._C._distributed_c10d._get_work_registry_size(), 0 + ) + return work, y, out + + # Test: Pure-eager + all_reduce_wait_eager = all_reduce_wait + work, y, out_ref = _run_loop_collective_wait( + x, + wait_fn=all_reduce_wait_eager, + expected_registry_size=0, + ) + + all_reduce_wait_compiled = torch.compile( + all_reduce_wait, + backend="inductor", + fullgraph=True, + ) + + # Test: Issue comm in eager -> wait for comm in compile. Use the context manager. + with _functional_collectives.allow_inflight_collective_as_graph_input_ctx(): + work, y, out_compiled = _run_loop_collective_wait( + x, wait_fn=all_reduce_wait_compiled, expected_registry_size=1 + ) + self.assertEqual(out_ref, out_compiled) + + # Check that `wait_tensor()` is in the Inductor generated code + _, triton_codes = run_and_get_code(all_reduce_wait_compiled, work, y) + FileCheck().check("torch.ops._c10d_functional.wait_tensor.default(").run( + triton_codes[0] + ) + + # Failure Case: Issue comm in eager -> wait for comm in compile. Doesn't use the context manager. + _, _, out_compiled = _run_loop_collective_wait( + x, wait_fn=all_reduce_wait_compiled, expected_registry_size=0 + ) + # In this case `.wait_tensor(y)` in compiled region will not be able to find the corresponding work object + # to invoke the wait, thus the result will not match eager. + self.assertNotEqual(out_ref, out_compiled) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @patch.object(torch._inductor.config, "allow_buffer_reuse", True) diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index fea0f54f538..1dde77d756e 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -628,6 +628,11 @@ def _register_process_group( ) -> None: ... def _resolve_process_group(group_name: str) -> ProcessGroup: ... def _register_work(tensor: torch.Tensor, work: Work) -> ProcessGroup: ... +def _get_work_registry_size() -> int: ... +def _set_allow_inflight_collective_as_graph_input( + value: bool, +) -> None: ... +def _allow_inflight_collective_as_graph_input() -> bool: ... def _unregister_all_process_groups() -> None: ... def _unregister_process_group(group_name: str) -> None: ... 
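NOTE (illustrative, not part of the patch): the tests above exercise the new work-registry plumbing end to end. Below is a minimal standalone sketch of the usage pattern they cover — issue a non-functional async collective in eager mode under `allow_inflight_collective_as_graph_input_ctx()`, then wait on its output tensor inside a compiled region. It assumes a multi-rank NCCL job launched with `torchrun`; the `LOCAL_RANK` lookup, tensor sizes, and function names are placeholders, not APIs introduced by this PR (only `_get_work_registry_size`, the context manager, and `wait_tensor` behavior on eagerly issued collectives are).

```python
# Hypothetical standalone script; assumes `torchrun --nproc-per-node=2 script.py`
# on a machine with 2+ GPUs. Not part of this PR.
import os

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


@torch.compile(fullgraph=True)
def wait_and_square(y):
    # Wait, inside the compiled region, on a collective that was issued in eager mode.
    # This only finds the in-flight work if the collective ran under
    # allow_inflight_collective_as_graph_input_ctx() and was registered in the work registry.
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y


def main():
    rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl")

    x = torch.ones(1024, 1024, device=f"cuda:{rank}") * (rank + 1)

    with funcol.allow_inflight_collective_as_graph_input_ctx():
        # Non-functional async collective issued in eager mode; its Work object
        # is registered against y's storage while the context manager is active.
        y = x * x
        dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
        assert torch._C._distributed_c10d._get_work_registry_size() == 1

        out = wait_and_square(y)
        # wait_tensor() pops the Work from the registry once it has been waited on.
        assert torch._C._distributed_c10d._get_work_registry_size() == 0

    dist.destroy_process_group()
    return out


if __name__ == "__main__":
    main()
```

Without the context manager, `wait_tensor()` in the compiled region cannot find a registered work object for `y` and effectively becomes a no-op, which is the failure mode `test_eager_async_allreduce_inductor_wait` checks for in its last case.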
diff --git a/torch/csrc/distributed/c10d/Functional.cpp b/torch/csrc/distributed/c10d/Functional.cpp index 1117718ee50..30088a29c82 100644 --- a/torch/csrc/distributed/c10d/Functional.cpp +++ b/torch/csrc/distributed/c10d/Functional.cpp @@ -6,80 +6,10 @@ #include #include #include -#include #include namespace { -class WorkRegistry { - public: - void register_work( - const at::Tensor& tensor, - const c10::intrusive_ptr& work) { - auto storage = tensor.storage().getWeakStorageImpl(); - std::unique_lock lock(lock_); - auto [it, inserted] = registry_.try_emplace(std::move(storage), work); - TORCH_CHECK( - inserted || it->second != work, - "The tensor storage is already associated with another work."); - } - - c10::intrusive_ptr pop_work(const at::Tensor& tensor) { - const auto storage = tensor.storage().getWeakStorageImpl(); - std::unique_lock lock(lock_); - auto it = registry_.find(storage); - if (it == registry_.end()) { - return nullptr; - } - auto work = it->second; - registry_.erase(it); - return work; - } - - ~WorkRegistry() { - // If there are still unwaited work objects, their corresponding process - // groups should have already been destroyed at this stage. Any attempts to - // wait for these work objects or to destroy them will only result in - // confusing errors. Therefore, we simply issue a warning and intentionally - // allow the unwaited work objects to leak. - if (!registry_.empty()) { - TORCH_WARN( - "At the time of process termination, there are still ", - registry_.size(), - " unwaited c10d_functional collective calls. " - "Please review your program to ensure c10d_functional.wait_tensor() " - "is invoked on all tensors returned from c10d_functional collective " - "ops before they are used."); - } - for (auto& it : registry_) { - it.second.release(); - } - } - - private: - std::unordered_map< - c10::weak_intrusive_ptr, - c10::intrusive_ptr> - registry_; - std::mutex lock_; -}; - -static WorkRegistry process_registry; - -} // namespace - -namespace c10d { - -void register_work( - const at::Tensor& tensor, - const c10::intrusive_ptr& work) { - RankLocal::get().register_work(tensor, work); -} - -} // namespace c10d - -namespace { - const std::unordered_map str_to_reduce_op = { {"sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::SUM)}, {"avg", c10d::ReduceOp(c10d::ReduceOp::RedOpType::AVG)}, @@ -296,14 +226,6 @@ at::Tensor broadcast( return broadcast_(output, src, std::move(group_name)); } -at::Tensor wait_tensor(const at::Tensor& tensor) { - auto work = c10d::RankLocal::get().pop_work(tensor); - if (work != nullptr) { - work->wait(); - } - return tensor; -} - } // namespace TORCH_LIBRARY(_c10d_functional, m) { @@ -389,7 +311,7 @@ TORCH_LIBRARY(_c10d_functional, m) { m.def( "wait_tensor(Tensor tensor) -> Tensor", torch::dispatch( - c10::DispatchKey::CompositeExplicitAutograd, ::wait_tensor), + c10::DispatchKey::CompositeExplicitAutograd, c10d::wait_tensor), {at::Tag::pt2_compliant_tag}); } @@ -438,7 +360,7 @@ class AllToAllSingle : public torch::autograd::Function { // TODO: track active cuda stream in wait out = c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::wait_tensor", "") - .typed() + .typed() .call(out); return {out, at::Tensor(), at::Tensor(), at::Tensor()}; @@ -493,7 +415,7 @@ class ReduceScatterTensor // TODO: track active cuda stream in wait out = c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::wait_tensor", "") - .typed() + .typed() .call(out); return { @@ -549,7 +471,7 @@ class AllGatherIntoTensor // TODO: track active cuda 
stream in wait out = c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::wait_tensor", "") - .typed() + .typed() .call(out); return { diff --git a/torch/csrc/distributed/c10d/Functional.hpp b/torch/csrc/distributed/c10d/Functional.hpp index cbb19e68609..e81d44b8dbd 100644 --- a/torch/csrc/distributed/c10d/Functional.hpp +++ b/torch/csrc/distributed/c10d/Functional.hpp @@ -1,11 +1,3 @@ #pragma once -#include - -namespace c10d { - -C10_EXPORT void register_work( - const at::Tensor& tensor, - const c10::intrusive_ptr& work); - -} // namespace c10d +#include diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index 63d64447dfd..48816b88fd2 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -11,7 +12,6 @@ #include #include #include -#include namespace c10d { @@ -159,3 +159,172 @@ void ProcessGroup::release_resources() { } } // namespace c10d + +namespace { + +class WorkRegistry { + public: + void register_work( + const at::Tensor& tensor, + const c10::intrusive_ptr& work) { + if (!tensor.has_storage()) { + TORCH_WARN_ONCE( + "Registering collective work for tensor without storage is not supported. " + "Calling c10d_functional.wait_tensor() on this tensor will not wait for the collective to complete. " + "Unsupported tensor type: " + + tensor.toString()); + return; + } + auto storage = tensor.storage().getWeakStorageImpl(); + std::unique_lock lock(lock_); + + auto it = registry_.find(storage); + if (it == registry_.end()) { + registry_.emplace( + std::move(storage), + std::vector>{work}); + } else { + // There is no guarantee that the previous work object for this + // tensor storage is completed before the new work object is registered. + // Therefore we need to maintain a list of work objects for each tensor + // storage. 
+ + // Check if work is already in the list + bool work_exists = false; + for (const auto& existing_work : it->second) { + if (existing_work == work) { + work_exists = true; + break; + } + } + + // Only append if work is not already in the list + if (!work_exists) { + it->second.push_back(work); + } + } + } + + std::vector> pop_works( + const at::Tensor& tensor) { + const auto storage = tensor.storage().getWeakStorageImpl(); + std::unique_lock lock(lock_); + auto it = registry_.find(storage); + if (it == registry_.end()) { + return {}; + } + auto works = it->second; + registry_.erase(it); + return works; + } + + void unregister_work(const c10::intrusive_ptr& work) { + std::unique_lock lock(lock_); + for (auto it = registry_.begin(); it != registry_.end();) { + std::vector> nonmatching_works; + for (const auto& _work : it->second) { + if (_work != work) { + nonmatching_works.push_back(_work); + } + } + if (nonmatching_works.empty()) { + it = registry_.erase(it); + } else { + it->second = std::move(nonmatching_works); + ++it; + } + } + } + + size_t get_work_registry_size() { + std::unique_lock lock(lock_); + size_t total_size = 0; + for (const auto& [storage, works] : registry_) { + total_size += works.size(); + } + return total_size; + } + + void set_allow_inflight_collective_as_graph_input(bool value) { + std::unique_lock lock(lock_); + allow_inflight_collective_as_graph_input_ = value; + } + + bool allow_inflight_collective_as_graph_input() { + std::unique_lock lock(lock_); + return allow_inflight_collective_as_graph_input_; + } + + ~WorkRegistry() { + // If there are still unwaited work objects, their corresponding process + // groups should have already been destroyed at this stage. Any attempts to + // wait for these work objects or to destroy them will only result in + // confusing errors. Therefore, we simply issue a warning and intentionally + // allow the unwaited work objects to leak. + size_t registry_size = get_work_registry_size(); + if (registry_size > 0) { + TORCH_WARN( + "At the time of process termination, there are still ", + registry_size, + " unwaited collective calls. " + "Please review your program to ensure that:\n" + "1. c10d_functional.wait_tensor() is invoked on all tensors returned from c10d_functional collective,\n" + "2. 
c10d_functional.wait_tensor() is invoked on all output tensors of async_op=True torch.distributed collective " + "called under `with allow_inflight_collective_as_graph_input_ctx():`,\n" + "before the output tensors of the collective are used."); + } + for (auto& it : registry_) { + for (auto& work : it.second) { + work.release(); + } + } + } + + private: + std::unordered_map< + c10::weak_intrusive_ptr, + std::vector>> + registry_; + bool allow_inflight_collective_as_graph_input_ = false; + std::mutex lock_; +}; + +static WorkRegistry process_registry; + +} // namespace + +namespace c10d { + +void register_work( + const at::Tensor& tensor, + const c10::intrusive_ptr& work) { + RankLocal::get().register_work(tensor, work); +} + +at::Tensor wait_tensor(const at::Tensor& tensor) { + auto works = RankLocal::get().pop_works(tensor); + for (const auto& work : works) { + work->wait(); + } + return tensor; +} + +void unregister_work(const c10::intrusive_ptr& work) { + RankLocal::get().unregister_work(work); +} + +size_t get_work_registry_size() { + return RankLocal::get().get_work_registry_size(); +} + +void set_allow_inflight_collective_as_graph_input(bool value) { + return RankLocal::get() + .set_allow_inflight_collective_as_graph_input(value); +} + +bool allow_inflight_collective_as_graph_input() { + return RankLocal::get() + .allow_inflight_collective_as_graph_input(); +} + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index 463d1f046db..1f28f5442d0 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -23,6 +24,31 @@ constexpr auto kProcessGroupDefaultTimeout = namespace c10d { +// We only call `register_work()` in two cases: +// 1. If the work object is created from a functional collective call. +// 2. If the work object is created from a non-functional collective call within +// the `with allow_inflight_collective_as_graph_input_ctx()` context manager. +C10_EXPORT void register_work( + const at::Tensor& tensor, + const c10::intrusive_ptr& work); + +C10_EXPORT at::Tensor wait_tensor(const at::Tensor& tensor); + +// We only call `unregister_work()` in one case: +// 1. If the work object is created from a non-functional collective call within +// the `with allow_inflight_collective_as_graph_input_ctx()` context manager. +// +// Q: What about the functional collective case? +// A: The unregistration of work object for functional collective is done in +// the required user-side explicit call to `wait_tensor()`. +C10_EXPORT void unregister_work(const c10::intrusive_ptr& work); + +C10_EXPORT size_t get_work_registry_size(); + +C10_EXPORT void set_allow_inflight_collective_as_graph_input(bool value); + +C10_EXPORT bool allow_inflight_collective_as_graph_input(); + // ProcessGroup is a base class that captures collective and point to // point communication in a fixed set of processes. // @@ -158,13 +184,20 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { // It's awakward to unbox the opts here and box them again in the custom C++ // op. But it's also complicated to make opts as a CustomClassHolder. Leave // it as it is now. 
- return std::get<1>(op.call( + auto work = std::get<1>(op.call( tensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), opts.rootRank, opts.rootTensor, opts.asyncOp, opts.timeout.count())); + + if (c10d::allow_inflight_collective_as_graph_input()) { + for (const auto& tensor : tensors) { + c10d::register_work(tensor, work); + } + } + return work; } virtual c10::intrusive_ptr allreduce( @@ -181,12 +214,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const std::optional& sparse_indices, int64_t)>(); - return std::get<1>(op.call( + auto work = std::get<1>(op.call( tensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive(opts.reduceOp), opts.sparseIndices, opts.timeout.count())); + + if (c10d::allow_inflight_collective_as_graph_input()) { + for (const auto& tensor : tensors) { + c10d::register_work(tensor, work); + } + } + return work; } virtual c10::intrusive_ptr allreduce_coalesced( @@ -200,11 +240,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ReduceOp>&, int64_t)>(); - return op.call( + auto work = op.call( tensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive(opts.reduceOp), opts.timeout.count()); + + if (c10d::allow_inflight_collective_as_graph_input()) { + for (const auto& tensor : tensors) { + c10d::register_work(tensor, work); + } + } + return work; } virtual c10::intrusive_ptr reduce( @@ -219,13 +266,20 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { int64_t, int64_t, int64_t)>(); - return op.call( + auto work = op.call( tensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive(opts.reduceOp), opts.rootRank, opts.rootTensor, opts.timeout.count()); + + if (c10d::allow_inflight_collective_as_graph_input()) { + for (const auto& tensor : tensors) { + c10d::register_work(tensor, work); + } + } + return work; } virtual c10::intrusive_ptr allgather( @@ -242,11 +296,20 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ProcessGroup>&, int64_t)>(); - return std::get<1>(op.call( + auto work = std::get<1>(op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), opts.timeout.count())); + + if (c10d::allow_inflight_collective_as_graph_input()) { + for (const auto& tensor_list : outputTensors) { + for (const auto& tensor : tensor_list) { + c10d::register_work(tensor, work); + } + } + } + return work; } // Gathers a single tensor inputBuffer into a single buffer outputBuffer that @@ -267,12 +330,17 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { bool, int64_t)>(); - return std::get<1>(op.call( + auto work = std::get<1>(op.call( outputBuffer, inputBuffer, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), opts.asyncOp, opts.timeout.count())); + + if (c10d::allow_inflight_collective_as_graph_input()) { + c10d::register_work(outputBuffer, work); + } + return work; } // This function is deprecated and will be moved out of ProcessGroup to comms: @@ -291,10 +359,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const at::TensorList&, const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); - return op.call( + auto work = op.call( outputTensorLists, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this)); + + if (c10d::allow_inflight_collective_as_graph_input()) { + for (const auto& tensor_list : outputTensorLists) { + for (const auto& tensor : tensor_list) { + 
c10d::register_work(tensor, work); + } + } + } + return work; } // This function is a coalesced version of `allgather_into_tensor` (currently @@ -312,10 +389,17 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const at::TensorList, const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); - return op.call( + auto work = op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this)); + + if (c10d::allow_inflight_collective_as_graph_input()) { + for (const auto& tensor : outputTensors) { + c10d::register_work(tensor, work); + } + } + return work; } virtual c10::intrusive_ptr gather( @@ -330,12 +414,21 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ProcessGroup>&, int64_t, int64_t)>(); - return op.call( + auto work = op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), opts.rootRank, opts.timeout.count()); + + if (c10d::allow_inflight_collective_as_graph_input()) { + for (const auto& tensor_list : outputTensors) { + for (const auto& tensor : tensor_list) { + c10d::register_work(tensor, work); + } + } + } + return work; } virtual c10::intrusive_ptr scatter( @@ -353,13 +446,20 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { int64_t, bool, int64_t)>(); - return std::get<1>(op.call( + auto work = std::get<1>(op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), opts.rootRank, opts.asyncOp, opts.timeout.count())); + + if (c10d::allow_inflight_collective_as_graph_input()) { + for (const auto& tensor : outputTensors) { + c10d::register_work(tensor, work); + } + } + return work; } virtual c10::intrusive_ptr reduce_scatter( @@ -376,12 +476,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ProcessGroup>&, const c10::intrusive_ptr<::c10d::ReduceOp>&, int64_t)>(); - return std::get<1>(op.call( + auto work = std::get<1>(op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp), opts.timeout.count())); + + if (c10d::allow_inflight_collective_as_graph_input()) { + for (const auto& tensor : outputTensors) { + c10d::register_work(tensor, work); + } + } + return work; } virtual c10::intrusive_ptr _reduce_scatter_base( @@ -398,13 +505,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ReduceOp>&, bool, int64_t)>(); - return std::get<1>(op.call( + auto work = std::get<1>(op.call( outputBuffer, inputBuffer, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp), opts.asyncOp, opts.timeout.count())); + + if (c10d::allow_inflight_collective_as_graph_input()) { + c10d::register_work(outputBuffer, work); + } + return work; } // This function is a coalesced version of `reduce_scatter_tensor` (currently @@ -424,12 +536,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ReduceOp>&, int64_t)>(); - return op.call( + auto work = op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp), opts.timeout.count()); + + if (c10d::allow_inflight_collective_as_graph_input()) { + for (const auto& tensor : outputTensors) { + c10d::register_work(tensor, work); + } + } + return work; } virtual c10::intrusive_ptr alltoall_base( @@ -447,13 +566,18 @@ 
class TORCH_API ProcessGroup : public torch::CustomClassHolder { std::vector, std::vector, int64_t)>(); - return op.call( + auto work = op.call( outputBuffer, inputBuffer, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), outputSplitSizes, inputSplitSizes, opts.timeout.count()); + + if (c10d::allow_inflight_collective_as_graph_input()) { + c10d::register_work(outputBuffer, work); + } + return work; } virtual c10::intrusive_ptr alltoall( @@ -469,11 +593,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const at::TensorList&, const c10::intrusive_ptr<::c10d::ProcessGroup>&, int64_t)>(); - return std::get<1>(op.call( + auto work = std::get<1>(op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), opts.timeout.count())); + + if (c10d::allow_inflight_collective_as_graph_input()) { + for (const auto& tensor : outputTensors) { + c10d::register_work(tensor, work); + } + } + return work; } virtual void monitoredBarrier( @@ -549,11 +680,17 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ProcessGroup>&, int64_t, int64_t)>(); - return op.call( + auto work = op.call( tensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), dstRank, tag); + if (c10d::allow_inflight_collective_as_graph_input()) { + for (const auto& tensor : tensors) { + c10d::register_work(tensor, work); + } + } + return work; } virtual c10::intrusive_ptr recv( @@ -567,11 +704,17 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ProcessGroup>&, int64_t, int64_t)>(); - return op.call( + auto work = op.call( tensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), srcRank, tag); + if (c10d::allow_inflight_collective_as_graph_input()) { + for (const auto& tensor : tensors) { + c10d::register_work(tensor, work); + } + } + return work; } virtual c10::intrusive_ptr recvAnysource( @@ -583,10 +726,16 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { at::TensorList, const c10::intrusive_ptr<::c10d::ProcessGroup>&, int64_t)>(); - return op.call( + auto work = op.call( tensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), tag); + if (c10d::allow_inflight_collective_as_graph_input()) { + for (const auto& tensor : tensors) { + c10d::register_work(tensor, work); + } + } + return work; } virtual c10::intrusive_ptr barrier( @@ -618,11 +767,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const std::vector&, int64_t)>(); - return op.call( + auto work = op.call( tensor, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), opts.device_ids, opts.timeout.count()); + if (c10d::allow_inflight_collective_as_graph_input()) { + c10d::register_work(tensor, work); + } + return work; } bool hasBackends() { diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index 3cb765a6589..2d8d15af543 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -574,6 +575,11 @@ bool ProcessGroupGloo::SendWork::wait(std::chrono::milliseconds timeout) { // Completes the Work object and throws the exception. 
finishAndThrow(exception); + if (c10d::allow_inflight_collective_as_graph_input()) { + c10d::unregister_work( + c10::intrusive_ptr< + ProcessGroupGloo::SendWork>::unsafe_reclaim_from_nonowning(this)); + } return sendCompleted; } @@ -621,6 +627,11 @@ bool ProcessGroupGloo::RecvWork::wait(std::chrono::milliseconds timeout) { // Completes the Work object and throws the exception. finishAndThrow(exception); + if (c10d::allow_inflight_collective_as_graph_input()) { + c10d::unregister_work( + c10::intrusive_ptr< + ProcessGroupGloo::RecvWork>::unsafe_reclaim_from_nonowning(this)); + } return recvCompleted; } diff --git a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp index 91e9f938f1d..a46e216179c 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp @@ -7,6 +7,7 @@ #include #include +#include #if defined(OPEN_MPI) && OPEN_MPI #include // Needed for CUDA-aware check @@ -198,6 +199,11 @@ bool ProcessGroupMPI::AsyncWork::wait(std::chrono::milliseconds /* unused */) { populateException(); std::rethrow_exception(exception_); } + if (c10d::allow_inflight_collective_as_graph_input()) { + c10d::unregister_work( + c10::intrusive_ptr< + ProcessGroupMPI::AsyncWork>::unsafe_reclaim_from_nonowning(this)); + } // Always return true, because abort API is not implemented. return true; } diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index c98e76b1fd7..55c757504d8 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -725,6 +725,11 @@ void ProcessGroupNCCL::WorkNCCL::handleException( void ProcessGroupNCCL::WorkNCCL::synchronize() { synchronizeStream(); + if (c10d::allow_inflight_collective_as_graph_input()) { + c10d::unregister_work( + c10::intrusive_ptr< + ProcessGroupNCCL::WorkNCCL>::unsafe_reclaim_from_nonowning(this)); + } } void ProcessGroupNCCL::WorkNCCL::synchronizeStream() { diff --git a/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp b/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp index dab6aa6d26e..d177a1fa6d1 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -273,6 +274,11 @@ bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) { Work::recordFunctionEndCallback_(); Work::recordFunctionEndCallback_ = nullptr; } + if (c10d::allow_inflight_collective_as_graph_input()) { + c10d::unregister_work( + c10::intrusive_ptr< + ProcessGroupUCC::WorkUCC>::unsafe_reclaim_from_nonowning(this)); + } return true; } diff --git a/torch/csrc/distributed/c10d/Work.cpp b/torch/csrc/distributed/c10d/Work.cpp index d7890566acb..4502e4aa235 100644 --- a/torch/csrc/distributed/c10d/Work.cpp +++ b/torch/csrc/distributed/c10d/Work.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -70,7 +71,12 @@ std::vector Work::result() { TORCH_CHECK(false, "result() not implemented."); } -void Work::synchronize() {} +void Work::synchronize() { + if (c10d::allow_inflight_collective_as_graph_input()) { + c10d::unregister_work( + c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this)); + } +} bool Work::wait(std::chrono::milliseconds timeout) { std::unique_lock lock(mutex_); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 2c3f836a394..51deeadbea1 100644 --- 
a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -937,6 +937,21 @@ This class does not support ``__members__`` property.)"); py::arg("tensor"), py::arg("work")); + module.def("_get_work_registry_size", []() { + return ::c10d::get_work_registry_size(); + }); + + module.def( + "_set_allow_inflight_collective_as_graph_input", + [](bool value) { + return ::c10d::set_allow_inflight_collective_as_graph_input(value); + }, + py::arg("value")); + + module.def("_allow_inflight_collective_as_graph_input", []() { + return ::c10d::allow_inflight_collective_as_graph_input(); + }); + // Remove a group from the native registry module.def( "_unregister_process_group", diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 4127885dccc..ebb9ec1e4b2 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import contextlib import sys import warnings from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union @@ -816,6 +817,43 @@ def _maybe_wrap_tensor(self) -> torch.Tensor: return cast(torch.Tensor, res) +@contextlib.contextmanager +def allow_inflight_collective_as_graph_input_ctx(value: bool = True): + """ + Context manager to temporarily set whether inflight collectives are allowed as torch.compile graph inputs. + A common use case is when the collective is issued eagerly (with `async_op=True`) but waited on in a compiled region: + ``` + def all_reduce_eager(x): + y = x * x + req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True) + return y + + @torch.compile(fullgraph=True) + def all_reduce_wait_compiled(y): + torch.ops.c10d_functional.wait_tensor(y) + return y * y + + x = torch.ones(1280, 1280, device="cuda") + dist.get_rank() + # the context manager ensures that `wait_tensor(y)` will wait on the correct work object + with allow_inflight_collective_as_graph_input_ctx(): + y = all_reduce_eager(x) + z = all_reduce_wait_compiled(y) + ``` + Within this context manager, the work object of each collective call is registered in the work registry + under the hood, so that `wait_tensor()` calls in the compiled region on the output tensors of those + collectives will wait on the correct work objects. + """ + previous = torch._C._distributed_c10d._allow_inflight_collective_as_graph_input() + + try: + torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(value) + yield + finally: + torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input( + previous + ) + + def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size): def mk_out_tensor(shard): out_size = list(shard.size())