[distributed] Make DDP work with python process group (#79176)

This PR enables python process group usage with DDP by doing the following:

- Surface PG::Work::getFuture() as an overridable method.
- Use Work::getFuture() to retrieve values from a PG.
- Add _create_work_from_future python method that creates a Work object that wraps a Future.

To test these changes, we use both strategies to run DDP with a Python-based PG.

The reason for offering two methods is that both have shortcomings.

The wrapper method is harder to troubleshoot as there's no visibility of how the future is used.

The subclass method has memory management issues as can be noticed in the test suite by having to keep Work instances alive by storing them in PG fields.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79176
Approved by: https://github.com/rohan-varma
This commit is contained in:
Rodrigo Kumpera 2022-06-28 17:14:19 +00:00 committed by PyTorch MergeBot
parent 3cc48184a5
commit f6a45f7984
8 changed files with 297 additions and 13 deletions

View file

@ -0,0 +1,154 @@
# Owner(s): ["oncall: distributed"]
import os
import torch
import torch.distributed as dist
from torch.testing._internal.common_utils import (
run_tests,
)
from torch.futures import Future
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import test_c10d_common
import weakref
from torch._C._distributed_c10d import _create_work_from_future
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
)
def create_work(result):
    """Return a Work object wrapping an already-fulfilled Future holding ``result``."""
    completed = Future()
    completed.set_result(result)
    return _create_work_from_future(completed)
class MyWork(dist._Work):
    """A ``dist._Work`` subclass backed by an already-completed Future.

    Bumps counters on the owning process group every time ``wait()`` or
    ``get_future()`` is called, so tests can verify DDP actually exercised
    the Work object. Only a weak reference to the PG is held, avoiding a
    reference cycle between the PG and the Work objects it keeps alive.
    """

    def __init__(self, result, pg):
        super().__init__()
        self.result_ = result
        fut = torch.futures.Future()
        fut.set_result(result)
        self.future_ = fut
        self.pg_ = weakref.ref(pg)

    def wait(self, timeout):
        owner = self.pg_()
        owner.wait_count += 1
        return True

    def get_future(self):
        owner = self.pg_()
        owner.get_future_count += 1
        return self.future_
class LonelyRankProcessGroup(dist.ProcessGroup):
    """
    This PG only supports world_size of 1.

    Collectives either return a Future-wrapping Work (``use_wrapper=True``)
    or a ``MyWork`` subclass instance; the latter are stashed in
    ``self._work`` to keep them alive (see the Work lifetime issues
    mentioned in the commit message).
    """

    def __init__(self, rank, world, use_wrapper):
        super(LonelyRankProcessGroup, self).__init__(rank, world)
        assert rank == 0
        assert world == 1

        self._rank = rank
        self._world = world
        self.wait_count = 0
        self.get_future_count = 0
        self.use_wrapper = use_wrapper
        self._work = []

    def _make_work(self, data):
        # Build the Work object for a finished collective. Subclass-based
        # Work objects must be retained by the PG (see class docstring).
        if self.use_wrapper:
            return create_work(data)
        work = MyWork(data, self)
        self._work.append(work)
        return work

    def broadcast(self, tensor_list, opts):
        # world_size == 1: broadcast is a no-op on the data itself.
        return self._make_work(tensor_list)

    def allgather(self, output_tensors, input_tensor, opts):
        for dst, src in zip(output_tensors[0], input_tensor):
            dst.copy_(src)
        return self._make_work(output_tensors)

    def allreduce(self, tensors, opts):
        # Single rank: the reduction result is the input itself.
        return self._make_work(tensors)

    def size(self):
        return self._world

    def getBackendName(self):
        return "lonely-pg"

    def __repr__(self):
        return f"PLG w:{self._world} r:{self._rank}"
# We cannot use parametrize as some tests are defined on the base class and use _get_process_group
class AbstractDDPSingleRank(test_c10d_common.CommonDistributedDataParallelTest):
    """Shared DDP tests driven by a single-rank python ProcessGroup.

    Concrete subclasses provide the ``use_wrapper`` property which selects
    whether the PG returns Future-wrapping Work objects or MyWork
    (dist._Work subclass) instances.
    """

    def setUp(self):
        super(AbstractDDPSingleRank, self).setUp()
        self._spawn_processes()

    @property
    def world_size(self):
        # LonelyRankProcessGroup only supports a single rank.
        return 1

    def tearDown(self):
        super(AbstractDDPSingleRank, self).tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            # File may never have been created, or was already cleaned up.
            pass

    def _get_process_group(self):
        return LonelyRankProcessGroup(self.rank, self.world_size, self.use_wrapper)

    def test_ddp_invoke_work_object(self):
        # Run one forward/backward through DDP with the custom PG, then
        # compare gradients against a plain (non-DDP) backward pass.
        pg = self._get_process_group()
        torch.manual_seed(123)
        model = nn.Sequential(
            nn.Linear(2, 2),
            nn.ReLU()
        )
        # Keep a handle on the underlying module so we can inspect/zero grads.
        wrapped_model = model
        input_tensor = torch.rand(2)
        model = DDP(model, process_group=pg)
        model(input_tensor).sum().backward()
        ddp_grad = wrapped_model[0].bias.grad.clone()

        wrapped_model.zero_grad()
        wrapped_model(input_tensor).sum().backward()
        self.assertEqual(wrapped_model[0].bias.grad, ddp_grad)

        if not self.use_wrapper:
            # Only the MyWork subclass path tracks call counts on the PG.
            self.assertTrue(pg.wait_count > 0)
            self.assertTrue(pg.get_future_count > 0)

    def test_ddp_with_pypg(self):
        # Exercise the base-class DDP test with the python PG (CPU only).
        pg = self._get_process_group()
        self._test_ddp_with_process_group(pg, [torch.device("cpu")], device_ids=None)

    def test_ddp_with_pypg_with_grad_views(self):
        # Same as above but with gradient_as_bucket_view enabled.
        pg = self._get_process_group()
        self._test_ddp_with_process_group(pg, [torch.device("cpu")], device_ids=None, gradient_as_bucket_view=True)
class TestDDPWithWorkSubclass(AbstractDDPSingleRank, MultiProcessTestCase):
    """Runs the shared DDP tests using the MyWork (dist._Work subclass) path."""

    @property
    def use_wrapper(self):
        return False
class TestDDPWithWorkWrapper(AbstractDDPSingleRank, MultiProcessTestCase):
    """Runs the shared DDP tests using the _create_work_from_future wrapper path."""

    @property
    def use_wrapper(self):
        return True
# Entry point: dispatch to the common distributed test runner.
if __name__ == '__main__':
    run_tests()

View file

@ -191,4 +191,56 @@ void ProcessGroup::init() {
fmt::format("c10d.process_group_{}", getBackendName()));
}
class FutureWrappingWork : public ProcessGroup::Work {
public:
FutureWrappingWork(c10::intrusive_ptr<c10::ivalue::Future> fut)
: Work(), _fut(fut) {}
~FutureWrappingWork() {}
bool isCompleted() override {
return _fut->completed();
}
bool isSuccess() const override {
return _fut->hasValue();
}
std::exception_ptr exception() const override {
return _fut->exception_ptr();
}
int sourceRank() const override {
TORCH_CHECK(false, "FutureWrappingWork::sourceRank() not implemented");
}
std::vector<at::Tensor> result() override {
return _fut->value().toPyObjectHolder()->extractTensors();
}
bool wait(std::chrono::milliseconds timeout) override {
// FIXME
TORCH_CHECK(
timeout == kNoTimeout,
"FutureWrappingWork::wait() with finite timeout not implemented");
_fut->wait();
return true;
}
void abort() override {
TORCH_CHECK(false, "FutureWrappingWork::abort() not implemented");
}
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
return _fut;
}
private:
c10::intrusive_ptr<c10::ivalue::Future> _fut;
};
// Wraps an already-created future in a Work object so python-defined
// process groups can report collective completion without subclassing Work.
// `future` is moved into the Work to avoid a redundant refcount round-trip.
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroup::Work::create_from_future(
    c10::intrusive_ptr<c10::ivalue::Future> future) {
  return c10::make_intrusive<FutureWrappingWork>(std::move(future));
}
} // namespace c10d

View file

@ -149,6 +149,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
OpType retrieveOpType();
// Builds a Work object whose completion tracks the given future; the
// concrete implementation (FutureWrappingWork) lives in ProcessGroup.cpp
// and does not support sourceRank(), abort(), or finite wait() timeouts.
static c10::intrusive_ptr<Work> create_from_future(c10::intrusive_ptr<c10::ivalue::Future>);
protected:
// Completes the work object and optionally sets the exception in a
// thread-safe manner. Notifies all waiting condition variables as well.

View file

@ -2,6 +2,7 @@
#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/jit/python/pybind_utils.h>
namespace c10d {
@ -22,6 +23,22 @@ class PyProcessGroup : public ProcessGroup {
wait, /* Name of function in C++ */
timeout);
}
// Dispatches to a python `get_future` override when one exists, otherwise
// falls back to the C++ base implementation.
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
  // We cannot use PYBIND11_OVERRIDE because:
  // 1. We have to >MANUALLY< unwrap the PyFutureWrapper and
  // 2. The python name is get_future
  pybind11::gil_scoped_acquire gil;
  auto override = pybind11::get_override(static_cast<const ProcessGroup::Work *>(this), "get_future");
  if (override) {
    // Call into python, then unwrap the PythonFutureWrapper to obtain the
    // underlying ivalue::Future that C++ callers expect.
    py::object o = override();
    auto futWrapper = o.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>();
    return futWrapper->fut;
  }
  // No python override registered: use the default C++ future.
  return Work::getFuture();
}
};
using ProcessGroup::ProcessGroup;

View file

@ -99,5 +99,26 @@ std::vector<at::Tensor> GradBucket::getGradients() const {
}
return per_parameter_tensors;
}
namespace detail {
// Extracts the single tensor a communication hook is expected to produce.
// Accepts results from python hooks (a PyObject holding tensors), a plain
// Tensor, or a TensorList; for the list forms only the first element is used.
at::Tensor parseCppCommHookResult(const c10::IValue& result) {
  if (result.isPyObject()) {
    std::vector<at::Tensor> tensors =
        result.toPyObjectHolder()->extractTensors();
    // Guard against indexing into an empty extraction result (was UB).
    TORCH_INTERNAL_ASSERT(
        !tensors.empty(), "expected the hook result to contain at least one tensor");
    return tensors[0];
  }
  TORCH_INTERNAL_ASSERT(
      result.isTensor() || result.isTensorList(),
      "expected the hook result is either a Tensor or a TensorList found ",
      result.tagKind());
  if (result.isTensor()) {
    return result.toTensor();
  }
  const auto tensors = result.toTensorVector();
  // Guard against indexing into an empty TensorList (was UB).
  TORCH_INTERNAL_ASSERT(
      !tensors.empty(), "expected the hook result TensorList to be non-empty");
  return tensors[0];
}
} // namespace detail
} // namespace c10d

View file

@ -106,18 +106,7 @@ class TORCH_API CommHookInterface {
namespace detail {
// This helper function is called both by CppCommHookInterface below and inside
// reducer.
inline at::Tensor parseCppCommHookResult(
const c10::IValue& result) {
TORCH_INTERNAL_ASSERT(
result.isTensor() || result.isTensorList(),
"expected the hook result is either a Tensor or a TensorList");
if (result.isTensor()) {
return result.toTensor();
}
return result.toTensorVector()[0];
}
at::Tensor parseCppCommHookResult(const c10::IValue& result);
} // namespace detail
// This CppCommHook interface only requires implementing runHook method that

View file

@ -1777,6 +1777,36 @@ Example::
module.attr("_DEFAULT_PG_TIMEOUT") = py::cast(kProcessGroupDefaultTimeout);
module.attr("_DEFAULT_NO_TIMEOUT") = py::cast(kNoTimeout);
// Expose Work::create_from_future to python so custom python process
// groups can return Work objects built from torch.futures.Future.
module.def(
    "_create_work_from_future",
    [](std::shared_ptr<jit::PythonFutureWrapper> future) {
      // Unwrap the python-side future wrapper to reach the ivalue::Future.
      return ::c10d::ProcessGroup::Work::create_from_future(future->fut);
    },
    py::arg("future"),
    R"(
        Arguments:
            future(torch.futures.Future): The future to wrap.
        Returns:
            A ``ProcessGroup::Work`` object which is associated with the completion of
            the ``torch.futures.Future``.
        This is the preferred way of constructing Work objects when writing a custom ProcessGroup
        in python.
        Example::
            >>> class SingleRankProcessGroup(torch.distributed.ProcessGroup):
            >>>     def broadcast(self, tensor_list, opts):
            >>>         fut = torch.futures.Future()
            >>>         fut.set_result(tensor_list)
            >>>         return torch._C._distributed_c10d._create_work_from_future(fut)

        .. warning ::
            This API is experimental and subject to change.
            The returned Work object has multiple limitations:
            - synchronize() does nothing. Use ``torch.futures.Future`` based synchronization.
            - wait() ignores the timeout argument.
            - sourceRank() raises.
            - abort() raises.
            The provided Future object result must be a Tensor or a list of Tensors.
       )");
Py_RETURN_TRUE;
}

View file

@ -69,6 +69,22 @@ class CpuTimer : public Timer {
C10_REGISTER_TYPED_CLASS(TimerRegistry, c10::kCPU, CpuTimer);
// Pulls the tensor payload out of a collective-result IValue. Results from
// python process groups arrive as a PyObject holder; C++ results are either
// a single Tensor or a TensorList.
std::vector<at::Tensor> extractTensors(const c10::IValue& result) {
  if (result.isPyObject()) {
    return result.toPyObjectHolder()->extractTensors();
  }
  TORCH_INTERNAL_ASSERT(
      result.isTensor() || result.isTensorList(),
      "expected the hook result is either a Tensor or a TensorList found ",
      result.tagKind());
  if (result.isTensorList()) {
    return result.toTensorVector();
  }
  return {result.toTensor()};
}
} // namespace
Reducer::Reducer(
@ -494,7 +510,10 @@ void Reducer::set_divide_factor() {
auto& workHandle = forwardPassWorkHandle_.workHandle;
if (workHandle && !forwardPassWorkHandle_.useStaticWorldSize) {
workHandle->wait();
auto results = workHandle->result();
// PyProcessGroup::PyWork doesn't expose value, so fetch it from the
// future
auto results = extractTensors(workHandle->getFuture()->value());
// Guard against the results being empty
TORCH_INTERNAL_ASSERT(results.size() > 0);
at::Tensor& res = results.front();