diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py
index b1c99145311..0f4bf91edc2 100644
--- a/test/distributed/test_c10d_functional_native.py
+++ b/test/distributed/test_c10d_functional_native.py
@@ -405,22 +405,6 @@ class TestWithNCCL(MultiProcessTestCase):
         assert output.eq(expect).all()
         assert output.completed
 
-    @skip_if_lt_x_gpu(2)
-    def test_wait_tensor(self) -> None:
-        self._init_process_group()
-
-        input = torch.full((10, 10), float(self.rank), device=self.device)
-        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
-        output = torch.ops._c10d_functional.all_reduce(
-            input,
-            "avg",
-            "default",
-        )
-        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
-        torch.ops._c10d_functional.wait_tensor(output)
-        # `wait_tensor(output)` will pop the work from the work registry immediately
-        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
-
     @skip_if_lt_x_gpu(2)
     def test_unwaited(self) -> None:
         # Verify that the process can terminate gracefully
@@ -428,13 +412,11 @@ class TestWithNCCL(MultiProcessTestCase):
         self._init_process_group()
 
         input = torch.full((10, 10), float(self.rank), device=self.device)
-        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
         output = torch.ops._c10d_functional.all_reduce(
             input,
             "avg",
             "default",
         )
-        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
 
     @skip_if_lt_x_gpu(2)
     def test_py_work(self) -> None:
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 2626d694dbc..2dcb0e1d2f0 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -3178,48 +3178,6 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
         with self.assertRaisesRegex(TypeError, "Invalid function argument"):
             c10d.barrier(device_ids=self.rank)
 
-    @requires_nccl()
-    @skip_if_lt_x_gpu(2)
-    def test_unwaited(self) -> None:
-        # Verify that the process can terminate gracefully
-        # even with unwaited tensors
-        store = c10d.FileStore(self.file_name, self.world_size)
-        c10d.init_process_group(
-            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
-        )
-
-        input = torch.full((10240, 10240), float(self.rank), device=f"cuda:{self.rank}")
-        dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
-        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
-        # Running another collective on the same tensor should still work
-        dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
-        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)
-
-    @requires_nccl()
-    @skip_if_lt_x_gpu(2)
-    def test_wait_tensor(self) -> None:
-        # Verify that c10d_functional.wait_tensor() can be invoked on
-        # output tensor of non-functional collective
-        store = c10d.FileStore(self.file_name, self.world_size)
-        c10d.init_process_group(
-            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
-        )
-
-        input1 = torch.full((10, 10), float(self.rank), device=f"cuda:{self.rank}")
-        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
-        dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True)
-        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
-        torch.ops.c10d_functional.wait_tensor(input1)
-        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
-
-        input2 = torch.full((10, 10), float(self.rank), device=f"cuda:{self.rank}")
-        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
-        work = dist.all_reduce(input2, op=dist.ReduceOp.SUM, async_op=True)
-        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
-        work.wait()
-        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
-        self.assertEqual(input1, input2)
-
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
     @with_dist_debug_levels(levels=["DETAIL"])
diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py
index 4c8650c1a88..f59c471a0f9 100644
--- a/test/distributed/test_inductor_collectives.py
+++ b/test/distributed/test_inductor_collectives.py
@@ -1,5 +1,4 @@
 # Owner(s): ["module: dynamo"]
-import datetime
 import functools
 import unittest
 from unittest.mock import patch
@@ -15,7 +14,7 @@ from torch._C import FileCheck
 from torch._dynamo.testing import CompileCounter
 from torch._dynamo.utils import same
 from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
-from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
+from torch._inductor.utils import run_and_get_triton_code
 from torch.distributed.distributed_c10d import GroupMember
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing._internal.common_distributed import (
@@ -29,7 +28,6 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
     requires_cuda,
-    skipIfRocm,
 )
 from torch.testing._internal.inductor_utils import HAS_GPU
 
@@ -247,74 +245,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
         )
         self.assertTrue(same(eager_out, inductor_out, tol=0.001))
 
-    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_lt_x_gpu(2)
-    @skipIfRocm
-    def test_eager_async_allreduce_inductor_wait(self):
-        import torch.distributed as dist
-
-        def all_reduce_non_functional_eager(x):
-            y = x * x
-            work = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
-            assert isinstance(work, torch.distributed.Work)
-            return work, y
-
-        def all_reduce_wait(work, y):  # potentially compiled
-            if torch.compiler.is_dynamo_compiling():
-                torch.ops.c10d_functional.wait_tensor(y)
-            else:
-                work.wait(datetime.timedelta(seconds=10))
-            # Under compile, if `wait_tensor(y)` above is correctly executed,
-            # `y`'s data is in its final form and the output of this function will match eager;
-            # otherwise, `y * y` will run in parallel with `all_reduce(y)` and the output of this function
-            # will not match eager.
-            return y * y
-
-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
-            x = torch.ones(12800, 12800, device="cuda") + self.rank
-            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
-
-            # NOTE: We run for 10 iterations each, to ensure that the GPU execution is way behind CPU
-            # and that `y * y` on CPU side will be issued before `all_reduce(y)` on GPU side is done,
-            # thus guaranteeing that in the bad case `y * y` on GPU side will run in parallel with `all_reduce(y)`
-            # thus will produce the wrong result that fails the unit test.
-
-            # Test: pure-eager
-            all_reduce_wait_eager = all_reduce_wait
-            for _ in range(10):
-                work, y = all_reduce_non_functional_eager(x)
-                self.assertEqual(
-                    torch._C._distributed_c10d._get_work_registry_size(), 1
-                )
-                out_ref = all_reduce_wait_eager(work, y)
-                # `work.wait()` will pop the work from the work registry immediately
-                self.assertEqual(
-                    torch._C._distributed_c10d._get_work_registry_size(), 0
-                )
-
-            # Test: issue comm in eager -> wait for comm in compile
-            all_reduce_wait_compiled = torch.compile(
-                all_reduce_wait,
-                backend="inductor",
-                fullgraph=True,
-            )
-            for _ in range(10):
-                work, y = all_reduce_non_functional_eager(x)
-                self.assertEqual(
-                    torch._C._distributed_c10d._get_work_registry_size(), 1
-                )
-                out_compiled, triton_codes = run_and_get_code(
-                    all_reduce_wait_compiled, work, y
-                )
-                # `wait_tensor(y)` will pop the work from the work registry immediately
-                self.assertEqual(
-                    torch._C._distributed_c10d._get_work_registry_size(), 0
-                )
-                FileCheck().check(
-                    "torch.ops._c10d_functional.wait_tensor.default("
-                ).run(triton_codes[0])
-                self.assertEqual(out_ref, out_compiled)
-
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_lt_x_gpu(2)
     @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
diff --git a/torch/csrc/distributed/c10d/Functional.cpp b/torch/csrc/distributed/c10d/Functional.cpp
index a8c5d018778..1117718ee50 100644
--- a/torch/csrc/distributed/c10d/Functional.cpp
+++ b/torch/csrc/distributed/c10d/Functional.cpp
@@ -6,10 +6,80 @@
 #include
 #include
 #include
+#include <torch/csrc/distributed/c10d/RankLocal.hpp>
 #include
 
 namespace {
 
+class WorkRegistry {
+ public:
+  void register_work(
+      const at::Tensor& tensor,
+      const c10::intrusive_ptr<c10d::Work>& work) {
+    auto storage = tensor.storage().getWeakStorageImpl();
+    std::unique_lock lock(lock_);
+    auto [it, inserted] = registry_.try_emplace(std::move(storage), work);
+    TORCH_CHECK(
+        inserted || it->second != work,
+        "The tensor storage is already associated with another work.");
+  }
+
+  c10::intrusive_ptr<c10d::Work> pop_work(const at::Tensor& tensor) {
+    const auto storage = tensor.storage().getWeakStorageImpl();
+    std::unique_lock lock(lock_);
+    auto it = registry_.find(storage);
+    if (it == registry_.end()) {
+      return nullptr;
+    }
+    auto work = it->second;
+    registry_.erase(it);
+    return work;
+  }
+
+  ~WorkRegistry() {
+    // If there are still unwaited work objects, their corresponding process
+    // groups should have already been destroyed at this stage. Any attempts to
+    // wait for these work objects or to destroy them will only result in
+    // confusing errors. Therefore, we simply issue a warning and intentionally
+    // allow the unwaited work objects to leak.
+    if (!registry_.empty()) {
+      TORCH_WARN(
+          "At the time of process termination, there are still ",
+          registry_.size(),
+          " unwaited c10d_functional collective calls. "
+          "Please review your program to ensure c10d_functional.wait_tensor() "
+          "is invoked on all tensors returned from c10d_functional collective "
+          "ops before they are used.");
+    }
+    for (auto& it : registry_) {
+      it.second.release();
+    }
+  }
+
+ private:
+  std::unordered_map<
+      c10::weak_intrusive_ptr<c10::StorageImpl>,
+      c10::intrusive_ptr<c10d::Work>>
+      registry_;
+  std::mutex lock_;
+};
+
+static WorkRegistry process_registry;
+
+} // namespace
+
+namespace c10d {
+
+void register_work(
+    const at::Tensor& tensor,
+    const c10::intrusive_ptr<c10d::Work>& work) {
+  RankLocal<WorkRegistry>::get().register_work(tensor, work);
+}
+
+} // namespace c10d
+
+namespace {
+
 const std::unordered_map<std::string, c10d::ReduceOp> str_to_reduce_op = {
     {"sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::SUM)},
     {"avg", c10d::ReduceOp(c10d::ReduceOp::RedOpType::AVG)},
@@ -42,6 +112,7 @@ at::Tensor& all_reduce_(
   std::vector<at::Tensor> inputs{input};
   auto group = c10d::resolve_process_group(group_name);
   auto work = group->allreduce(inputs, opts);
+  c10d::register_work(input, work);
   return input;
 }
 
@@ -64,6 +135,9 @@ std::vector<at::Tensor> all_reduce_coalesced_(
 
   auto group = c10d::resolve_process_group(group_name);
   auto work = group->allreduce_coalesced(inputs, opts);
+  for (const auto& tensor : inputs) {
+    c10d::register_work(tensor, work);
+  }
   return inputs;
 }
 
@@ -104,6 +178,9 @@ std::vector<at::Tensor> all_gather_into_tensor_coalesced(
 
   auto group = c10d::resolve_process_group(group_name);
   auto work = group->allgather_into_tensor_coalesced(outputs, inputs);
+  for (const auto& tensor : outputs) {
+    c10d::register_work(tensor, work);
+  }
   return outputs;
 }
 
@@ -125,6 +202,7 @@ at::Tensor& all_gather_into_tensor_out(
 
   auto group = c10d::resolve_process_group(group_name);
   auto work = group->_allgather_base(output, input, opts);
+  c10d::register_work(output, work);
   return output;
 }
 
@@ -160,6 +238,9 @@ std::vector<at::Tensor> reduce_scatter_tensor_coalesced(
 
   auto group = c10d::resolve_process_group(group_name);
   auto work = group->reduce_scatter_tensor_coalesced(outputs, inputs, opts);
+  for (const auto& tensor : outputs) {
+    c10d::register_work(tensor, work);
+  }
   return outputs;
 }
 
@@ -191,6 +272,7 @@ at::Tensor all_to_all_single(
       const_cast<at::Tensor&>(input),
       output_split_sizes,
       input_split_sizes);
+  c10d::register_work(output, work);
   return output;
 }
 
@@ -202,6 +284,7 @@ at::Tensor& broadcast_(at::Tensor& input, int64_t src, std::string group_name) {
 
   auto group = c10d::resolve_process_group(group_name);
   auto work = group->broadcast(inputs, opts);
+  c10d::register_work(input, work);
   return input;
 }
 
@@ -213,6 +296,14 @@ at::Tensor broadcast(
   return broadcast_(output, src, std::move(group_name));
 }
 
+at::Tensor wait_tensor(const at::Tensor& tensor) {
+  auto work = c10d::RankLocal<WorkRegistry>::get().pop_work(tensor);
+  if (work != nullptr) {
+    work->wait();
+  }
+  return tensor;
+}
+
 } // namespace
 
 TORCH_LIBRARY(_c10d_functional, m) {
@@ -298,7 +389,7 @@ TORCH_LIBRARY(_c10d_functional, m) {
   m.def(
       "wait_tensor(Tensor tensor) -> Tensor",
       torch::dispatch(
-          c10::DispatchKey::CompositeExplicitAutograd, c10d::wait_tensor),
+          c10::DispatchKey::CompositeExplicitAutograd, ::wait_tensor),
       {at::Tag::pt2_compliant_tag});
 }
 
@@ -347,7 +438,7 @@ class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {
     // TODO: track active cuda stream in wait
     out = c10::Dispatcher::singleton()
              .findSchemaOrThrow("_c10d_functional::wait_tensor", "")
-             .typed<decltype(c10d::wait_tensor)>()
+             .typed<decltype(wait_tensor)>()
             .call(out);
 
     return {out, at::Tensor(), at::Tensor(), at::Tensor()};
@@ -402,7 +493,7 @@ class ReduceScatterTensor
    // TODO: track active cuda stream in wait
    out = c10::Dispatcher::singleton()
             .findSchemaOrThrow("_c10d_functional::wait_tensor", "")
-            .typed<decltype(c10d::wait_tensor)>()
+            .typed<decltype(wait_tensor)>()
            .call(out);
 
    return {
@@ -458,7 +549,7 @@ class AllGatherIntoTensor
    // TODO: track active cuda stream in wait
    out = c10::Dispatcher::singleton()
             .findSchemaOrThrow("_c10d_functional::wait_tensor", "")
-            .typed<decltype(c10d::wait_tensor)>()
+            .typed<decltype(wait_tensor)>()
            .call(out);
 
    return {
diff --git a/torch/csrc/distributed/c10d/Functional.hpp b/torch/csrc/distributed/c10d/Functional.hpp
index e81d44b8dbd..cbb19e68609 100644
--- a/torch/csrc/distributed/c10d/Functional.hpp
+++ b/torch/csrc/distributed/c10d/Functional.hpp
@@ -1,3 +1,11 @@
 #pragma once
 
-#include
+#include
+
+namespace c10d {
+
+C10_EXPORT void register_work(
+    const at::Tensor& tensor,
+    const c10::intrusive_ptr<c10d::Work>& work);
+
+} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp
index dffe20aebdd..63d64447dfd 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp
@@ -1,6 +1,5 @@
 #include
 #include
-#include
 #include
 #include
 
@@ -160,137 +159,3 @@ void ProcessGroup::release_resources() {
 }
 
 } // namespace c10d
-
-namespace {
-
-class WorkRegistry {
- public:
-  void register_work(
-      const at::Tensor& tensor,
-      const c10::intrusive_ptr<c10d::Work>& work) {
-    if (!tensor.has_storage()) {
-      TORCH_WARN_ONCE(
-          "Registering collective work for tensor without storage is not supported. "
-          "Calling c10d_functional.wait_tensor() on this tensor will not wait for the collective to complete. "
-          "Unsupported tensor type: " +
-          tensor.toString());
-      return;
-    }
-    auto storage = tensor.storage().getWeakStorageImpl();
-    std::unique_lock lock(lock_);
-
-    auto it = registry_.find(storage);
-    if (it == registry_.end()) {
-      registry_.emplace(
-          std::move(storage),
-          std::vector<c10::intrusive_ptr<c10d::Work>>{work});
-    } else {
-      // There is no guarantee that the previous work object for this
-      // tensor storage is completed before the new work object is registered.
-      // Therefore we need to maintain a list of work objects for each tensor
-      // storage.
-      it->second.push_back(work);
-    }
-  }
-
-  std::vector<c10::intrusive_ptr<c10d::Work>> pop_works(
-      const at::Tensor& tensor) {
-    const auto storage = tensor.storage().getWeakStorageImpl();
-    std::unique_lock lock(lock_);
-    auto it = registry_.find(storage);
-    if (it == registry_.end()) {
-      return {};
-    }
-    auto works = it->second;
-    registry_.erase(it);
-    return works;
-  }
-
-  void unregister_work(const c10::intrusive_ptr<c10d::Work>& work) {
-    std::unique_lock lock(lock_);
-    for (auto it = registry_.begin(); it != registry_.end();) {
-      std::vector<c10::intrusive_ptr<c10d::Work>> nonmatching_works;
-      for (const auto& _work : it->second) {
-        if (_work != work) {
-          nonmatching_works.push_back(_work);
-        }
-      }
-      if (nonmatching_works.empty()) {
-        it = registry_.erase(it);
-      } else {
-        it->second = std::move(nonmatching_works);
-        ++it;
-      }
-    }
-  }
-
-  size_t get_work_registry_size() {
-    std::unique_lock lock(lock_);
-    size_t total_size = 0;
-    for (const auto& [storage, works] : registry_) {
-      total_size += works.size();
-    }
-    return total_size;
-  }
-
-  ~WorkRegistry() {
-    // If there are still unwaited work objects, their corresponding process
-    // groups should have already been destroyed at this stage. Any attempts to
-    // wait for these work objects or to destroy them will only result in
-    // confusing errors. Therefore, we simply issue a warning and intentionally
-    // allow the unwaited work objects to leak.
-    size_t registry_size = get_work_registry_size();
-    if (registry_size > 0) {
-      TORCH_WARN(
-          "At the time of process termination, there are still ",
-          registry_size,
-          " unwaited collective calls. "
-          "Please review your program to ensure that:\n"
-          "1. c10d_functional.wait_tensor() is invoked on all tensors returned from c10d_functional collective,\n"
-          "2. work.wait() is invoked on work object returned from torch.distributed collective with async_op=True,\n"
-          "before the output tensors of the collective are used.");
-    }
-    for (auto& it : registry_) {
-      for (auto& work : it.second) {
-        work.release();
-      }
-    }
-  }
-
- private:
-  std::unordered_map<
-      c10::weak_intrusive_ptr<c10::StorageImpl>,
-      std::vector<c10::intrusive_ptr<c10d::Work>>>
-      registry_;
-  std::mutex lock_;
-};
-
-static WorkRegistry process_registry;
-
-} // namespace
-
-namespace c10d {
-
-void register_work(
-    const at::Tensor& tensor,
-    const c10::intrusive_ptr<c10d::Work>& work) {
-  RankLocal<WorkRegistry>::get().register_work(tensor, work);
-}
-
-at::Tensor wait_tensor(const at::Tensor& tensor) {
-  auto works = RankLocal<WorkRegistry>::get().pop_works(tensor);
-  for (const auto& work : works) {
-    work->wait();
-  }
-  return tensor;
-}
-
-void unregister_work(const c10::intrusive_ptr<c10d::Work>& work) {
-  RankLocal<WorkRegistry>::get().unregister_work(work);
-}
-
-size_t get_work_registry_size() {
-  return RankLocal<WorkRegistry>::get().get_work_registry_size();
-}
-
-} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp
index febf885a112..463d1f046db 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -1,7 +1,6 @@
 #pragma once
 
 #include
-#include
 #include
 #include
 #include
@@ -24,16 +23,6 @@ constexpr auto kProcessGroupDefaultTimeout =
 
 namespace c10d {
 
-C10_EXPORT void register_work(
-    const at::Tensor& tensor,
-    const c10::intrusive_ptr<c10d::Work>& work);
-
-C10_EXPORT at::Tensor wait_tensor(const at::Tensor& tensor);
-
-C10_EXPORT void unregister_work(const c10::intrusive_ptr<c10d::Work>& work);
-
-C10_EXPORT size_t get_work_registry_size();
-
 // ProcessGroup is a base class that captures collective and point to
 // point communication in a fixed set of processes.
 //
@@ -169,18 +158,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     // It's awakward to unbox the opts here and box them again in the custom C++
     // op. But it's also complicated to make opts as a CustomClassHolder. Leave
     // it as it is now.
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
         tensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         opts.rootRank,
         opts.rootTensor,
         opts.asyncOp,
         opts.timeout.count()));
-
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> allreduce(
@@ -197,17 +181,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const std::optional<at::Tensor>& sparse_indices,
             int64_t)>();
 
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
         tensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<ReduceOp>(opts.reduceOp),
         opts.sparseIndices,
         opts.timeout.count()));
-
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> allreduce_coalesced(
@@ -221,16 +200,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const c10::intrusive_ptr<::c10d::ReduceOp>&,
             int64_t)>();
 
-    auto work = op.call(
+    return op.call(
         tensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<ReduceOp>(opts.reduceOp),
         opts.timeout.count());
-
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> reduce(
@@ -245,18 +219,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             int64_t,
             int64_t,
             int64_t)>();
-    auto work = op.call(
+    return op.call(
         tensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<ReduceOp>(opts.reduceOp),
         opts.rootRank,
         opts.rootTensor,
         opts.timeout.count());
-
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> allgather(
@@ -273,18 +242,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
             int64_t)>();
 
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
         outputTensors,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         opts.timeout.count()));
-
-    for (const auto& tensor_list : outputTensors) {
-      for (const auto& tensor : tensor_list) {
-        c10d::register_work(tensor, work);
-      }
-    }
-    return work;
   }
 
   // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
@@ -305,15 +267,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             bool,
             int64_t)>();
 
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
         outputBuffer,
         inputBuffer,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         opts.asyncOp,
         opts.timeout.count()));
-
-    c10d::register_work(outputBuffer, work);
-    return work;
   }
 
   // This function is deprecated and will be moved out of ProcessGroup to comms:
@@ -332,17 +291,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const at::TensorList&,
             const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
 
-    auto work = op.call(
+    return op.call(
         outputTensorLists,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
-
-    for (const auto& tensor_list : outputTensorLists) {
-      for (const auto& tensor : tensor_list) {
-        c10d::register_work(tensor, work);
-      }
-    }
-    return work;
   }
 
   // This function is a coalesced version of `allgather_into_tensor` (currently
@@ -360,15 +312,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const at::TensorList,
             const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
 
-    auto work = op.call(
+    return op.call(
         outputTensors,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
-
-    for (const auto& tensor : outputTensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> gather(
@@ -383,19 +330,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const c10::intrusive_ptr<::c10d::ProcessGroup>&,
            int64_t,
            int64_t)>();
-    auto work = op.call(
+    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.timeout.count());
-
-    for (const auto& tensor_list : outputTensors) {
-      for (const auto& tensor : tensor_list) {
-        c10d::register_work(tensor, work);
-      }
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> scatter(
@@ -413,18 +353,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            int64_t,
            bool,
            int64_t)>();
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.asyncOp,
        opts.timeout.count()));
-
-    for (const auto& tensor : outputTensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> reduce_scatter(
@@ -441,17 +376,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const c10::intrusive_ptr<::c10d::ProcessGroup>&,
            const c10::intrusive_ptr<::c10d::ReduceOp>&,
            int64_t)>();
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.timeout.count()));
-
-    for (const auto& tensor : outputTensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
@@ -468,16 +398,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const c10::intrusive_ptr<::c10d::ReduceOp>&,
            bool,
            int64_t)>();
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.asyncOp,
        opts.timeout.count()));
-
-    c10d::register_work(outputBuffer, work);
-    return work;
   }
 
   // This function is a coalesced version of `reduce_scatter_tensor` (currently
@@ -497,17 +424,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const c10::intrusive_ptr<::c10d::ReduceOp>&,
            int64_t)>();
 
-    auto work = op.call(
+    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.timeout.count());
-
-    for (const auto& tensor : outputTensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> alltoall_base(
@@ -525,16 +447,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            std::vector<int64_t>,
            std::vector<int64_t>,
            int64_t)>();
-    auto work = op.call(
+    return op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        outputSplitSizes,
        inputSplitSizes,
        opts.timeout.count());
-
-    c10d::register_work(outputBuffer, work);
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> alltoall(
@@ -550,16 +469,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const at::TensorList&,
            const c10::intrusive_ptr<::c10d::ProcessGroup>&,
            int64_t)>();
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.timeout.count()));
-
-    for (const auto& tensor : outputTensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual void monitoredBarrier(
@@ -635,15 +549,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const c10::intrusive_ptr<::c10d::ProcessGroup>&,
            int64_t,
            int64_t)>();
-    auto work = op.call(
+    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        dstRank,
        tag);
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> recv(
@@ -657,15 +567,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const c10::intrusive_ptr<::c10d::ProcessGroup>&,
            int64_t,
            int64_t)>();
-    auto work = op.call(
+    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        srcRank,
        tag);
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> recvAnysource(
@@ -677,14 +583,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            at::TensorList,
            const c10::intrusive_ptr<::c10d::ProcessGroup>&,
            int64_t)>();
-    auto work = op.call(
+    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        tag);
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> barrier(
@@ -716,13 +618,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const std::vector<int64_t>&,
            int64_t)>();
 
-    auto work = op.call(
+    return op.call(
        tensor,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.device_ids,
        opts.timeout.count());
-    c10d::register_work(tensor, work);
-    return work;
   }
 
   bool hasBackends() {
diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
index 8ac81f4c396..3cb765a6589 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
@@ -5,7 +5,6 @@
 #include
 #include
 
-#include
 #include
 #include
 
@@ -575,9 +574,6 @@ bool ProcessGroupGloo::SendWork::wait(std::chrono::milliseconds timeout) {
 
   // Completes the Work object and throws the exception.
   finishAndThrow(exception);
-  c10d::unregister_work(
-      c10::intrusive_ptr<
-          ProcessGroupGloo::SendWork>::unsafe_reclaim_from_nonowning(this));
   return sendCompleted;
 }
 
@@ -625,9 +621,6 @@ bool ProcessGroupGloo::RecvWork::wait(std::chrono::milliseconds timeout) {
 
   // Completes the Work object and throws the exception.
   finishAndThrow(exception);
-  c10d::unregister_work(
-      c10::intrusive_ptr<
-          ProcessGroupGloo::RecvWork>::unsafe_reclaim_from_nonowning(this));
   return recvCompleted;
 }
 
diff --git a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp
index df6c3acda70..91e9f938f1d 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp
@@ -7,7 +7,6 @@
 
 #include
 #include
-#include
 
 #if defined(OPEN_MPI) && OPEN_MPI
 #include // Needed for CUDA-aware check
@@ -199,9 +198,6 @@ bool ProcessGroupMPI::AsyncWork::wait(std::chrono::milliseconds /* unused */) {
     populateException();
     std::rethrow_exception(exception_);
   }
-  c10d::unregister_work(
-      c10::intrusive_ptr<
-          ProcessGroupMPI::AsyncWork>::unsafe_reclaim_from_nonowning(this));
   // Always return true, because abort API is not implemented.
   return true;
 }
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index eb16d6e09c9..fbd69dd7fd9 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -720,9 +720,6 @@ void ProcessGroupNCCL::WorkNCCL::handleException(
 
 void ProcessGroupNCCL::WorkNCCL::synchronize() {
   synchronizeStream();
-  c10d::unregister_work(
-      c10::intrusive_ptr<
-          ProcessGroupNCCL::WorkNCCL>::unsafe_reclaim_from_nonowning(this));
 }
 
 void ProcessGroupNCCL::WorkNCCL::synchronizeStream() {
diff --git a/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp b/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp
index c1937aaf52a..dab6aa6d26e 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp
@@ -2,7 +2,6 @@
 #include
 #include
-#include
 #include
 #include
 #include
 
@@ -274,9 +273,6 @@ bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) {
     Work::recordFunctionEndCallback_();
     Work::recordFunctionEndCallback_ = nullptr;
   }
-  c10d::unregister_work(
-      c10::intrusive_ptr<
-          ProcessGroupUCC::WorkUCC>::unsafe_reclaim_from_nonowning(this));
   return true;
 }
 
diff --git a/torch/csrc/distributed/c10d/Work.cpp b/torch/csrc/distributed/c10d/Work.cpp
index af006e2d985..d7890566acb 100644
--- a/torch/csrc/distributed/c10d/Work.cpp
+++ b/torch/csrc/distributed/c10d/Work.cpp
@@ -1,5 +1,4 @@
 #include
-#include
 #include
 #include
 
@@ -71,10 +70,7 @@ std::vector<at::Tensor> Work::result() {
   TORCH_CHECK(false, "result() not implemented.");
 }
 
-void Work::synchronize() {
-  c10d::unregister_work(
-      c10::intrusive_ptr<Work>::unsafe_reclaim_from_nonowning(this));
-}
+void Work::synchronize() {}
 
 bool Work::wait(std::chrono::milliseconds timeout) {
   std::unique_lock lock(mutex_);
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index f613bf02455..67cf3b581b1 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -933,10 +933,6 @@ This class does not support ``__members__`` property.)");
       py::arg("tensor"),
      py::arg("work"));
 
-  module.def("_get_work_registry_size", []() {
-    return ::c10d::get_work_registry_size();
-  });
-
   // Remove a group from the native registry
   module.def(
      "_unregister_process_group",