mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[distributed] Make DDP work with python process group (#79176)
This PR enables python process group usage with DDP by doing the following: - Surface PG::Work::getFuture() as overridable() - Use Work::getFuture() to retrieve values from a PG. - Add _create_work_from_future python method that creates a Work object that wraps a Future. To test this changes we use both strategies to run DDP with a python based PG. The reason for offering two methods is that both have short-comings. The wrapper method is harder to troubleshoot as there's no visibility of how the future is used. The subclass method has memory management issues as can be noticed in the test suite by having to keep Work instances alive by storing them in PG fields. Pull Request resolved: https://github.com/pytorch/pytorch/pull/79176 Approved by: https://github.com/rohan-varma
This commit is contained in:
parent
3cc48184a5
commit
f6a45f7984
8 changed files with 297 additions and 13 deletions
154
test/distributed/test_c10d_pypg.py
Normal file
154
test/distributed/test_c10d_pypg.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
)
|
||||
from torch.futures import Future
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import test_c10d_common
|
||||
import weakref
|
||||
from torch._C._distributed_c10d import _create_work_from_future
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcessTestCase,
|
||||
)
|
||||
|
||||
def create_work(result):
|
||||
future = Future()
|
||||
future.set_result(result)
|
||||
return _create_work_from_future(future)
|
||||
|
||||
class MyWork(dist._Work):
|
||||
def __init__(self, result, pg):
|
||||
super().__init__()
|
||||
self.result_ = result
|
||||
self.future_ = torch.futures.Future()
|
||||
self.future_.set_result(result)
|
||||
self.pg_ = weakref.ref(pg)
|
||||
|
||||
def wait(self, timeout):
|
||||
self.pg_().wait_count += 1
|
||||
return True
|
||||
|
||||
def get_future(self):
|
||||
self.pg_().get_future_count += 1
|
||||
return self.future_
|
||||
|
||||
class LonelyRankProcessGroup(dist.ProcessGroup):
|
||||
"""
|
||||
This PG only supports world_size of 1
|
||||
"""
|
||||
def __init__(self, rank, world, use_wrapper):
|
||||
super(LonelyRankProcessGroup, self).__init__(rank, world)
|
||||
assert rank == 0
|
||||
assert world == 1
|
||||
|
||||
self._rank = rank
|
||||
self._world = world
|
||||
self.wait_count = 0
|
||||
self.get_future_count = 0
|
||||
self.use_wrapper = use_wrapper
|
||||
self._work = []
|
||||
|
||||
def broadcast(self, tensor_list, opts):
|
||||
if self.use_wrapper:
|
||||
return create_work(tensor_list)
|
||||
res = MyWork(tensor_list, self)
|
||||
self._work.append(res)
|
||||
return res
|
||||
|
||||
def allgather(self, output_tensors, input_tensor, opts):
|
||||
for o, i in zip(output_tensors[0], input_tensor):
|
||||
o.copy_(i)
|
||||
if self.use_wrapper:
|
||||
return create_work(output_tensors)
|
||||
|
||||
res = MyWork(output_tensors, self)
|
||||
self._work.append(res)
|
||||
|
||||
return res
|
||||
|
||||
def allreduce(self, tensors, opts):
|
||||
if self.use_wrapper:
|
||||
return create_work(tensors)
|
||||
res = MyWork(tensors, self)
|
||||
self._work.append(res)
|
||||
return res
|
||||
|
||||
def size(self):
|
||||
return self._world
|
||||
|
||||
def getBackendName(self):
|
||||
return "lonely-pg"
|
||||
|
||||
def __repr__(self):
|
||||
return f"PLG w:{self._world} r:{self._rank}"
|
||||
|
||||
# We cannot use parametrize as some tests are defined on the base class and use _get_process_group
|
||||
class AbstractDDPSingleRank(test_c10d_common.CommonDistributedDataParallelTest):
|
||||
def setUp(self):
|
||||
super(AbstractDDPSingleRank, self).setUp()
|
||||
self._spawn_processes()
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return 1
|
||||
|
||||
def tearDown(self):
|
||||
super(AbstractDDPSingleRank, self).tearDown()
|
||||
try:
|
||||
os.remove(self.file_name)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _get_process_group(self):
|
||||
return LonelyRankProcessGroup(self.rank, self.world_size, self.use_wrapper)
|
||||
|
||||
def test_ddp_invoke_work_object(self):
|
||||
pg = self._get_process_group()
|
||||
|
||||
torch.manual_seed(123)
|
||||
model = nn.Sequential(
|
||||
nn.Linear(2, 2),
|
||||
nn.ReLU()
|
||||
)
|
||||
wrapped_model = model
|
||||
input_tensor = torch.rand(2)
|
||||
model = DDP(model, process_group=pg)
|
||||
model(input_tensor).sum().backward()
|
||||
|
||||
ddp_grad = wrapped_model[0].bias.grad.clone()
|
||||
|
||||
wrapped_model.zero_grad()
|
||||
wrapped_model(input_tensor).sum().backward()
|
||||
self.assertEqual(wrapped_model[0].bias.grad, ddp_grad)
|
||||
if not self.use_wrapper:
|
||||
self.assertTrue(pg.wait_count > 0)
|
||||
self.assertTrue(pg.get_future_count > 0)
|
||||
|
||||
def test_ddp_with_pypg(self):
|
||||
pg = self._get_process_group()
|
||||
|
||||
self._test_ddp_with_process_group(pg, [torch.device("cpu")], device_ids=None)
|
||||
|
||||
def test_ddp_with_pypg_with_grad_views(self):
|
||||
pg = self._get_process_group()
|
||||
|
||||
self._test_ddp_with_process_group(pg, [torch.device("cpu")], device_ids=None, gradient_as_bucket_view=True)
|
||||
|
||||
class TestDDPWithWorkSubclass(AbstractDDPSingleRank, MultiProcessTestCase):
|
||||
@property
|
||||
def use_wrapper(self):
|
||||
return False
|
||||
|
||||
class TestDDPWithWorkWrapper(AbstractDDPSingleRank, MultiProcessTestCase):
|
||||
@property
|
||||
def use_wrapper(self):
|
||||
return True
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
@ -191,4 +191,56 @@ void ProcessGroup::init() {
|
|||
fmt::format("c10d.process_group_{}", getBackendName()));
|
||||
}
|
||||
|
||||
class FutureWrappingWork : public ProcessGroup::Work {
|
||||
public:
|
||||
FutureWrappingWork(c10::intrusive_ptr<c10::ivalue::Future> fut)
|
||||
: Work(), _fut(fut) {}
|
||||
|
||||
~FutureWrappingWork() {}
|
||||
|
||||
bool isCompleted() override {
|
||||
return _fut->completed();
|
||||
}
|
||||
|
||||
bool isSuccess() const override {
|
||||
return _fut->hasValue();
|
||||
}
|
||||
|
||||
std::exception_ptr exception() const override {
|
||||
return _fut->exception_ptr();
|
||||
}
|
||||
|
||||
int sourceRank() const override {
|
||||
TORCH_CHECK(false, "FutureWrappingWork::sourceRank() not implemented");
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> result() override {
|
||||
return _fut->value().toPyObjectHolder()->extractTensors();
|
||||
}
|
||||
|
||||
bool wait(std::chrono::milliseconds timeout) override {
|
||||
// FIXME
|
||||
TORCH_CHECK(
|
||||
timeout == kNoTimeout,
|
||||
"FutureWrappingWork::wait() with finite timeout not implemented");
|
||||
_fut->wait();
|
||||
return true;
|
||||
}
|
||||
|
||||
void abort() override {
|
||||
TORCH_CHECK(false, "FutureWrappingWork::abort() not implemented");
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
|
||||
return _fut;
|
||||
}
|
||||
|
||||
private:
|
||||
c10::intrusive_ptr<c10::ivalue::Future> _fut;
|
||||
};
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroup::Work::create_from_future(
|
||||
c10::intrusive_ptr<c10::ivalue::Future> future) {
|
||||
return c10::make_intrusive<FutureWrappingWork>(future);
|
||||
}
|
||||
} // namespace c10d
|
||||
|
|
|
|||
|
|
@ -149,6 +149,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
|||
|
||||
OpType retrieveOpType();
|
||||
|
||||
static c10::intrusive_ptr<Work> create_from_future(c10::intrusive_ptr<c10::ivalue::Future>);
|
||||
|
||||
protected:
|
||||
// Completes the work object and optionally sets the exception in a
|
||||
// thread-safe manner. Notifies all waiting condition variables as well.
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include <c10d/ProcessGroup.hpp>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
|
||||
namespace c10d {
|
||||
|
||||
|
|
@ -22,6 +23,22 @@ class PyProcessGroup : public ProcessGroup {
|
|||
wait, /* Name of function in C++ */
|
||||
timeout);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
|
||||
// We cannot use PYBIND11_OVERRIDE because:
|
||||
// 1. We have to >MANUALLY< unwrap the PyFutureWrapper and
|
||||
// 2. The python name is get_future
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
auto override = pybind11::get_override(static_cast<const ProcessGroup::Work *>(this), "get_future");
|
||||
|
||||
if (override) {
|
||||
py::object o = override();
|
||||
auto futWrapper = o.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>();
|
||||
return futWrapper->fut;
|
||||
}
|
||||
|
||||
return Work::getFuture();
|
||||
}
|
||||
};
|
||||
|
||||
using ProcessGroup::ProcessGroup;
|
||||
|
|
|
|||
|
|
@ -99,5 +99,26 @@ std::vector<at::Tensor> GradBucket::getGradients() const {
|
|||
}
|
||||
return per_parameter_tensors;
|
||||
}
|
||||
namespace detail {
|
||||
|
||||
at::Tensor parseCppCommHookResult(const c10::IValue& result) {
|
||||
if (result.isPyObject()) {
|
||||
std::vector<at::Tensor> tensors =
|
||||
result.toPyObjectHolder()->extractTensors();
|
||||
return tensors[0];
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
result.isTensor() || result.isTensorList(),
|
||||
"expected the hook result is either a Tensor or a TensorList found ",
|
||||
result.tagKind());
|
||||
|
||||
if (result.isTensor()) {
|
||||
return result.toTensor();
|
||||
}
|
||||
|
||||
return result.toTensorVector()[0];
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace c10d
|
||||
|
|
|
|||
|
|
@ -106,18 +106,7 @@ class TORCH_API CommHookInterface {
|
|||
namespace detail {
|
||||
// This helper function is called both by CppCommHookInterface below and inside
|
||||
// reducer.
|
||||
inline at::Tensor parseCppCommHookResult(
|
||||
const c10::IValue& result) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
result.isTensor() || result.isTensorList(),
|
||||
"expected the hook result is either a Tensor or a TensorList");
|
||||
|
||||
if (result.isTensor()) {
|
||||
return result.toTensor();
|
||||
}
|
||||
|
||||
return result.toTensorVector()[0];
|
||||
}
|
||||
at::Tensor parseCppCommHookResult(const c10::IValue& result);
|
||||
} // namespace detail
|
||||
|
||||
// This CppCommHook interface only requires implementing runHook method that
|
||||
|
|
|
|||
|
|
@ -1777,6 +1777,36 @@ Example::
|
|||
module.attr("_DEFAULT_PG_TIMEOUT") = py::cast(kProcessGroupDefaultTimeout);
|
||||
module.attr("_DEFAULT_NO_TIMEOUT") = py::cast(kNoTimeout);
|
||||
|
||||
module.def(
|
||||
"_create_work_from_future",
|
||||
[](std::shared_ptr<jit::PythonFutureWrapper> future) {
|
||||
return ::c10d::ProcessGroup::Work::create_from_future(future->fut);
|
||||
},
|
||||
py::arg("future"),
|
||||
R"(
|
||||
Arguments:
|
||||
future(str): The future to wrap.
|
||||
Returns:
|
||||
A ``ProcessGroup::Work`` object which is associated with the completion of
|
||||
the ``torch.futures.Future``.
|
||||
This is the prefered way of constructing Work objects when writing a custom ProcessGroup
|
||||
in python.
|
||||
Example::
|
||||
>>> class SingleRankProcessGroup(torch.distributed.ProcessGroup):
|
||||
>>> def broadcast(self, tensor_list, opts):
|
||||
>>> fut = torch.futures.Future()
|
||||
>>> fut.set_result(tensor_list)
|
||||
>>> return torch._C._distributed_c10d._create_work_from_future(fut)
|
||||
.. warning ::
|
||||
This API is experimental and subject to change.
|
||||
The returned Work object has multiple limitations:
|
||||
- synchronize() does nothing. Use ``torch.futures.Future`` based synchronization.
|
||||
- wait() ignored timeout argument.
|
||||
- sourceRank() raises.
|
||||
- abort() raises.
|
||||
The provided Future object result must be a Tensor or a list of Tensors.
|
||||
)");
|
||||
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -69,6 +69,22 @@ class CpuTimer : public Timer {
|
|||
|
||||
C10_REGISTER_TYPED_CLASS(TimerRegistry, c10::kCPU, CpuTimer);
|
||||
|
||||
std::vector<at::Tensor> extractTensors(const c10::IValue& result) {
|
||||
if (result.isPyObject()) {
|
||||
return result.toPyObjectHolder()->extractTensors();
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
result.isTensor() || result.isTensorList(),
|
||||
"expected the hook result is either a Tensor or a TensorList found ",
|
||||
result.tagKind());
|
||||
|
||||
if (result.isTensor()) {
|
||||
return {result.toTensor()};
|
||||
}
|
||||
|
||||
return result.toTensorVector();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Reducer::Reducer(
|
||||
|
|
@ -494,7 +510,10 @@ void Reducer::set_divide_factor() {
|
|||
auto& workHandle = forwardPassWorkHandle_.workHandle;
|
||||
if (workHandle && !forwardPassWorkHandle_.useStaticWorldSize) {
|
||||
workHandle->wait();
|
||||
auto results = workHandle->result();
|
||||
// PyProcessGroup::PyWork doesn't expose value, so fetch it from the
|
||||
// future
|
||||
auto results = extractTensors(workHandle->getFuture()->value());
|
||||
|
||||
// Guard against the results being empty
|
||||
TORCH_INTERNAL_ASSERT(results.size() > 0);
|
||||
at::Tensor& res = results.front();
|
||||
|
|
|
|||
Loading…
Reference in a new issue