diff --git a/test/distributed/test_c10d_pypg.py b/test/distributed/test_c10d_pypg.py
new file mode 100644
index 00000000000..9c9e0c4422d
--- /dev/null
+++ b/test/distributed/test_c10d_pypg.py
@@ -0,0 +1,154 @@
+# Owner(s): ["oncall: distributed"]
+
+import os
+
+import torch
+import torch.distributed as dist
+from torch.testing._internal.common_utils import (
+    run_tests,
+)
+from torch.futures import Future
+import torch.nn as nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+import test_c10d_common
+import weakref
+from torch._C._distributed_c10d import _create_work_from_future
+from torch.testing._internal.common_distributed import (
+    MultiProcessTestCase,
+)
+
+def create_work(result):
+    future = Future()
+    future.set_result(result)
+    return _create_work_from_future(future)
+
+class MyWork(dist._Work):
+    def __init__(self, result, pg):
+        super().__init__()
+        self.result_ = result
+        self.future_ = torch.futures.Future()
+        self.future_.set_result(result)
+        self.pg_ = weakref.ref(pg)
+
+    def wait(self, timeout):
+        self.pg_().wait_count += 1
+        return True
+
+    def get_future(self):
+        self.pg_().get_future_count += 1
+        return self.future_
+
+class LonelyRankProcessGroup(dist.ProcessGroup):
+    """
+    This PG only supports world_size of 1
+    """
+    def __init__(self, rank, world, use_wrapper):
+        super(LonelyRankProcessGroup, self).__init__(rank, world)
+        assert rank == 0
+        assert world == 1
+
+        self._rank = rank
+        self._world = world
+        self.wait_count = 0
+        self.get_future_count = 0
+        self.use_wrapper = use_wrapper
+        self._work = []
+
+    def broadcast(self, tensor_list, opts):
+        if self.use_wrapper:
+            return create_work(tensor_list)
+        res = MyWork(tensor_list, self)
+        self._work.append(res)
+        return res
+
+    def allgather(self, output_tensors, input_tensor, opts):
+        for o, i in zip(output_tensors[0], input_tensor):
+            o.copy_(i)
+        if self.use_wrapper:
+            return create_work(output_tensors)
+
+        res = MyWork(output_tensors, self)
+        self._work.append(res)
+
+        return res
+
+    def allreduce(self, tensors, opts):
+        if self.use_wrapper:
+            return create_work(tensors)
+        res = MyWork(tensors, self)
+        self._work.append(res)
+        return res
+
+    def size(self):
+        return self._world
+
+    def getBackendName(self):
+        return "lonely-pg"
+
+    def __repr__(self):
+        return f"PLG w:{self._world} r:{self._rank}"
+
+# We cannot use parametrize as some tests are defined on the base class and use _get_process_group
+class AbstractDDPSingleRank(test_c10d_common.CommonDistributedDataParallelTest):
+    def setUp(self):
+        super(AbstractDDPSingleRank, self).setUp()
+        self._spawn_processes()
+
+    @property
+    def world_size(self):
+        return 1
+
+    def tearDown(self):
+        super(AbstractDDPSingleRank, self).tearDown()
+        try:
+            os.remove(self.file_name)
+        except OSError:
+            pass
+
+    def _get_process_group(self):
+        return LonelyRankProcessGroup(self.rank, self.world_size, self.use_wrapper)
+
+    def test_ddp_invoke_work_object(self):
+        pg = self._get_process_group()
+
+        torch.manual_seed(123)
+        model = nn.Sequential(
+            nn.Linear(2, 2),
+            nn.ReLU()
+        )
+        wrapped_model = model
+        input_tensor = torch.rand(2)
+        model = DDP(model, process_group=pg)
+        model(input_tensor).sum().backward()
+
+        ddp_grad = wrapped_model[0].bias.grad.clone()
+
+        wrapped_model.zero_grad()
+        wrapped_model(input_tensor).sum().backward()
+        self.assertEqual(wrapped_model[0].bias.grad, ddp_grad)
+        if not self.use_wrapper:
+            self.assertTrue(pg.wait_count > 0)
+            self.assertTrue(pg.get_future_count > 0)
+
+    def test_ddp_with_pypg(self):
+        pg = self._get_process_group()
+
+        self._test_ddp_with_process_group(pg, [torch.device("cpu")], device_ids=None)
+
+    def test_ddp_with_pypg_with_grad_views(self):
+        pg = self._get_process_group()
+
+        self._test_ddp_with_process_group(pg, [torch.device("cpu")], device_ids=None, gradient_as_bucket_view=True)
+
+class TestDDPWithWorkSubclass(AbstractDDPSingleRank, MultiProcessTestCase):
+    @property
+    def use_wrapper(self):
+        return False
+
+class TestDDPWithWorkWrapper(AbstractDDPSingleRank, MultiProcessTestCase):
+    @property
+    def use_wrapper(self):
+        return True
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp
index f31a7d8c815..fde76d9f503 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp
@@ -191,4 +191,56 @@ void ProcessGroup::init() {
       fmt::format("c10d.process_group_{}", getBackendName()));
 }
 
+class FutureWrappingWork : public ProcessGroup::Work {
+ public:
+  FutureWrappingWork(c10::intrusive_ptr<c10::ivalue::Future> fut)
+      : Work(), _fut(fut) {}
+
+  ~FutureWrappingWork() {}
+
+  bool isCompleted() override {
+    return _fut->completed();
+  }
+
+  bool isSuccess() const override {
+    return _fut->hasValue();
+  }
+
+  std::exception_ptr exception() const override {
+    return _fut->exception_ptr();
+  }
+
+  int sourceRank() const override {
+    TORCH_CHECK(false, "FutureWrappingWork::sourceRank() not implemented");
+  }
+
+  std::vector<at::Tensor> result() override {
+    return _fut->value().toPyObjectHolder()->extractTensors();
+  }
+
+  bool wait(std::chrono::milliseconds timeout) override {
+    // FIXME
+    TORCH_CHECK(
+        timeout == kNoTimeout,
+        "FutureWrappingWork::wait() with finite timeout not implemented");
+    _fut->wait();
+    return true;
+  }
+
+  void abort() override {
+    TORCH_CHECK(false, "FutureWrappingWork::abort() not implemented");
+  }
+
+  c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
+    return _fut;
+  }
+
+ private:
+  c10::intrusive_ptr<c10::ivalue::Future> _fut;
+};
+
+c10::intrusive_ptr<ProcessGroup::Work> ProcessGroup::Work::create_from_future(
+    c10::intrusive_ptr<c10::ivalue::Future> future) {
+  return c10::make_intrusive<FutureWrappingWork>(future);
+}
 } // namespace c10d
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp
index b6787092c88..f81358a87be 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -149,6 +149,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 
     OpType retrieveOpType();
 
+    static c10::intrusive_ptr<Work> create_from_future(c10::intrusive_ptr<c10::ivalue::Future>);
+
    protected:
     // Completes the work object and optionally sets the exception in a
     // thread-safe manner. Notifies all waiting condition variables as well.
diff --git a/torch/csrc/distributed/c10d/PyProcessGroup.hpp b/torch/csrc/distributed/c10d/PyProcessGroup.hpp
index 761a124dff4..22612aee12d 100644
--- a/torch/csrc/distributed/c10d/PyProcessGroup.hpp
+++ b/torch/csrc/distributed/c10d/PyProcessGroup.hpp
@@ -2,6 +2,7 @@
 
 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
 #include <torch/csrc/utils/pybind.h>
+#include <torch/csrc/jit/python/pybind_utils.h>
 
 namespace c10d {
 
@@ -22,6 +23,22 @@ class PyProcessGroup : public ProcessGroup {
           wait, /* Name of function in C++ */
           timeout);
     }
+
+    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
+      // We cannot use PYBIND11_OVERRIDE because:
+      // 1. We have to >MANUALLY< unwrap the PyFutureWrapper and
+      // 2. The python name is get_future
+      pybind11::gil_scoped_acquire gil;
+      auto override =
+          pybind11::get_override(static_cast<const PyWork*>(this), "get_future");
+
+      if (override) {
+        py::object o = override();
+        auto futWrapper =
+            o.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>();
+        return futWrapper->fut;
+      }
+
+      return Work::getFuture();
+    }
   };
 
   using ProcessGroup::ProcessGroup;
diff --git a/torch/csrc/distributed/c10d/comm.cpp b/torch/csrc/distributed/c10d/comm.cpp
index 8c72fabd63a..d4c26d99bb0 100644
--- a/torch/csrc/distributed/c10d/comm.cpp
+++ b/torch/csrc/distributed/c10d/comm.cpp
@@ -99,5 +99,26 @@ std::vector<at::Tensor> GradBucket::getGradients() const {
   }
   return per_parameter_tensors;
 }
+namespace detail {
+
+at::Tensor parseCppCommHookResult(const c10::IValue& result) {
+  if (result.isPyObject()) {
+    std::vector<at::Tensor> tensors =
+        result.toPyObjectHolder()->extractTensors();
+    return tensors[0];
+  }
+  TORCH_INTERNAL_ASSERT(
+      result.isTensor() || result.isTensorList(),
+      "expected the hook result is either a Tensor or a TensorList found ",
+      result.tagKind());
+
+  if (result.isTensor()) {
+    return result.toTensor();
+  }
+
+  return result.toTensorVector()[0];
+}
+
+} // namespace detail
 
 } // namespace c10d
diff --git a/torch/csrc/distributed/c10d/comm.hpp b/torch/csrc/distributed/c10d/comm.hpp
index 43ffdb07d78..7e97d6297e1 100644
--- a/torch/csrc/distributed/c10d/comm.hpp
+++ b/torch/csrc/distributed/c10d/comm.hpp
@@ -106,18 +106,7 @@ class TORCH_API CommHookInterface {
 namespace detail {
 // This helper function is called both by CppCommHookInterface below and inside
 // reducer.
-inline at::Tensor parseCppCommHookResult(
-    const c10::IValue& result) {
-  TORCH_INTERNAL_ASSERT(
-      result.isTensor() || result.isTensorList(),
-      "expected the hook result is either a Tensor or a TensorList");
-
-  if (result.isTensor()) {
-    return result.toTensor();
-  }
-
-  return result.toTensorVector()[0];
-}
+at::Tensor parseCppCommHookResult(const c10::IValue& result);
 } // namespace detail
 
 // This CppCommHook interface only requires implementing runHook method that
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 3e234c6f62a..0e5b64d49df 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -1777,6 +1777,36 @@ Example::
   module.attr("_DEFAULT_PG_TIMEOUT") = py::cast(kProcessGroupDefaultTimeout);
   module.attr("_DEFAULT_NO_TIMEOUT") = py::cast(kNoTimeout);
 
+  module.def(
+      "_create_work_from_future",
+      [](std::shared_ptr<jit::PythonFutureWrapper> future) {
+        return ::c10d::ProcessGroup::Work::create_from_future(future->fut);
+      },
+      py::arg("future"),
+      R"(
+        Arguments:
+            future(Future): The future to wrap.
+        Returns:
+            A ``ProcessGroup::Work`` object which is associated with the completion of
+            the ``torch.futures.Future``.
+        This is the preferred way of constructing Work objects when writing a custom ProcessGroup
+        in python.
+        Example::
+            >>> class SingleRankProcessGroup(torch.distributed.ProcessGroup):
+            >>>     def broadcast(self, tensor_list, opts):
+            >>>         fut = torch.futures.Future()
+            >>>         fut.set_result(tensor_list)
+            >>>         return torch._C._distributed_c10d._create_work_from_future(fut)
+        .. warning ::
+            This API is experimental and subject to change.
+            The returned Work object has multiple limitations:
+            - synchronize() does nothing. Use ``torch.futures.Future`` based synchronization.
+            - wait() ignores the timeout argument.
+            - sourceRank() raises.
+            - abort() raises.
+            The provided Future object result must be a Tensor or a list of Tensors.
+        )");
+
   Py_RETURN_TRUE;
 }
diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
index c49fb2a045f..7572ccccb7a 100644
--- a/torch/csrc/distributed/c10d/reducer.cpp
+++ b/torch/csrc/distributed/c10d/reducer.cpp
@@ -69,6 +69,22 @@ class CpuTimer : public Timer {
 
 C10_REGISTER_TYPED_CLASS(TimerRegistry, c10::kCPU, CpuTimer);
 
+std::vector<at::Tensor> extractTensors(const c10::IValue& result) {
+  if (result.isPyObject()) {
+    return result.toPyObjectHolder()->extractTensors();
+  }
+  TORCH_INTERNAL_ASSERT(
+      result.isTensor() || result.isTensorList(),
+      "expected the hook result is either a Tensor or a TensorList found ",
+      result.tagKind());
+
+  if (result.isTensor()) {
+    return {result.toTensor()};
+  }
+
+  return result.toTensorVector();
+}
+
 } // namespace
 
 Reducer::Reducer(
@@ -494,7 +510,10 @@ void Reducer::set_divide_factor() {
     auto& workHandle = forwardPassWorkHandle_.workHandle;
     if (workHandle && !forwardPassWorkHandle_.useStaticWorldSize) {
       workHandle->wait();
-      auto results = workHandle->result();
+      // PyProcessGroup::PyWork doesn't expose value, so fetch it from the
+      // future
+      auto results = extractTensors(workHandle->getFuture()->value());
+
+      // Guard against the results being empty
       TORCH_INTERNAL_ASSERT(results.size() > 0);
       at::Tensor& res = results.front();