pytorch/torch/csrc/distributed/c10d/PyProcessGroup.hpp
Rodrigo Kumpera f6a45f7984 [distributed] Make DDP work with python process group (#79176)
This PR enables python process group usage with DDP by doing the following:

- Make PG::Work::getFuture() overridable.
- Use Work::getFuture() to retrieve values from a PG.
- Add a `_create_work_from_future` Python method that creates a Work object wrapping a Future.

To test these changes, we use both strategies to run DDP with a Python-based PG.

We offer two methods because each has shortcomings.

The wrapper method is harder to troubleshoot because there is no visibility into how the future is used.

The subclass method has memory-management issues, as the test suite shows: Work instances must be kept alive by storing them in PG fields.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79176
Approved by: https://github.com/rohan-varma
2022-06-28 17:14:21 +00:00

138 lines
4.3 KiB
C++

#pragma once
#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/jit/python/pybind_utils.h>
namespace c10d {

// PyProcessGroup is a pybind11 trampoline class that allows a Python class
// to inherit from torch.distributed.ProcessGroup.
//
// Each virtual below forwards to the same-named method on the Python subclass
// when one is defined. PYBIND11_OVERRIDE falls back to the C++ ProcessGroup
// implementation when no Python override exists. getBackendName uses
// PYBIND11_OVERRIDE_PURE, so the Python subclass must provide it.
class PyProcessGroup : public ProcessGroup {
 public:
  // PyWork is a pybind11 trampoline class that allows a Python class to
  // inherit from torch.distributed.Work.
  class PyWork : public ProcessGroup::Work {
   public:
    PyWork() = default;

    // Forwards to the Python `wait` override if present; otherwise calls
    // ProcessGroup::Work::wait.
    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
      PYBIND11_OVERRIDE(
          bool, /* Return type */
          ProcessGroup::Work, /* Parent class */
          wait, /* Name of function in C++ */
          timeout);
    }

    // Returns the Future produced by the Python `get_future` override, or
    // defers to Work::getFuture() when the Python class does not define one.
    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
      // We cannot use PYBIND11_OVERRIDE because:
      // 1. We have to >MANUALLY< unwrap the PyFutureWrapper and
      // 2. The python name is get_future
      {
        // Hold the GIL only while touching Python objects. The returned
        // intrusive_ptr is copied out first; then `futWrapper`, `o` and
        // `override` are destroyed while `gil` is still alive, because locals
        // are destroyed in reverse declaration order.
        pybind11::gil_scoped_acquire gil;
        auto override = pybind11::get_override(
            static_cast<const ProcessGroup::Work*>(this), "get_future");
        if (override) {
          py::object o = override();
          auto futWrapper =
              o.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>();
          return futWrapper->fut;
        }
      }
      // No Python override: call the C++ base without holding the GIL,
      // matching the fallback behaviour of PYBIND11_OVERRIDE.
      return Work::getFuture();
    }
  };

  using ProcessGroup::ProcessGroup;

  // Pure in the trampoline: the Python subclass must define getBackendName.
  // The trailing comma is required by PYBIND11_OVERRIDE_PURE when the
  // function takes no arguments.
  const std::string getBackendName() const override {
    PYBIND11_OVERRIDE_PURE(
        std::string, /* Return type */
        ProcessGroup, /* Parent class */
        getBackendName, /* Name of function in C++ */
    );
  }

  // Forwards to the Python `allgather` override if present; otherwise calls
  // ProcessGroup::allgather.
  c10::intrusive_ptr<ProcessGroup::Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
        ProcessGroup, /* Parent class */
        allgather, /* Name of function in C++ */
        outputTensors,
        inputTensors,
        opts);
  }

  // Forwards to the Python `allreduce` override if present; otherwise calls
  // ProcessGroup::allreduce.
  c10::intrusive_ptr<ProcessGroup::Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
        ProcessGroup, /* Parent class */
        allreduce, /* Name of function in C++ */
        tensors,
        opts);
  }

  // Forwards to the Python `barrier` override if present; otherwise calls
  // ProcessGroup::barrier. `override` makes the compiler verify that this
  // signature still matches the base virtual, so the Python override cannot
  // be silently bypassed by a hiding (non-overriding) declaration.
  c10::intrusive_ptr<ProcessGroup::Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
        ProcessGroup, /* Parent class */
        barrier, /* Name of function in C++ */
        opts);
  }

  // Forwards to the Python `broadcast` override if present; otherwise calls
  // ProcessGroup::broadcast.
  c10::intrusive_ptr<ProcessGroup::Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
        ProcessGroup, /* Parent class */
        broadcast, /* Name of function in C++ */
        tensors,
        opts);
  }

  // Forwards to the Python `reduce_scatter` override if present; otherwise
  // calls ProcessGroup::reduce_scatter.
  c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
        ProcessGroup, /* Parent class */
        reduce_scatter, /* Name of function in C++ */
        outputTensors,
        inputTensors,
        opts);
  }

  // Forwards to the Python `send` override if present; otherwise calls
  // ProcessGroup::send.
  c10::intrusive_ptr<ProcessGroup::Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
        ProcessGroup, /* Parent class */
        send, /* Name of function in C++ */
        tensors,
        dstRank,
        tag);
  }

  // Forwards to the Python `recv` override if present; otherwise calls
  // ProcessGroup::recv.
  c10::intrusive_ptr<ProcessGroup::Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
        ProcessGroup, /* Parent class */
        recv, /* Name of function in C++ */
        tensors,
        srcRank,
        tag);
  }
};

} // namespace c10d