pytorch/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp
Rohan Varma f5341bd5e6 Enhance ProcessGroupWrapper with additional checks + refactor (#60237)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60237

Closes https://github.com/pytorch/pytorch/issues/58711

This diff refactors the collective consistency checking in `ProcessGroupWrapper` as described in the above issue. In particular, we no longer run separate verification checks (`all_gather`s) for shapes, op type, etc. Instead, we implement a function `serialize_fingerprint` that packs all of this data into a single tensor, and we verify only that tensor.
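
As an illustrative sketch of the layout (the concrete values are whatever `OpType`, `at::ScalarType`, and `at::DeviceType` encode to), an `allreduce` over a single float32 CPU tensor of shape `[2, 3]` serializes to the 1-D `int64` tensor

    [op_type, 2, 3, dtype, device_type]

and a single `all_gather` plus an elementwise comparison of this tensor replaces the previous per-field checks.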

This has the benefit of being much more extensible: the developer does not need to add separate `all_gather` calls in order to verify additional data in the future. We could also provide a mechanism that allows data needing verification to be "registered" in the `CollectiveFingerPrint` struct, making it even easier to add new fields; we can consider doing this if there are significant additions to `ProcessGroupWrapper`.
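
Purely as a sketch of what such a registration hook could look like (hypothetical; `register_field` and its signature do not exist in this diff):

    // Hypothetical API, for illustration only.
    fingerprint.register_field([](const std::vector<at::Tensor>& inputs,
                                  std::vector<int64_t>& out) {
      for (const auto& t : inputs) {
        out.push_back(static_cast<int64_t>(t.layout()));
      }
    });

Each registered callback would append its `int64_t` encoding to the serialized fingerprint, so verification would remain a single `all_gather`.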

We also now check tensor `dtype`s and device types for consistency. Tests are refactored and added accordingly.
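
For example, rank 0 calling `allreduce` on `at::ones({2}, at::kFloat)` while rank 1 passes `at::ones({2}, at::kDouble)` is now reported as an inconsistent collective with a descriptive error, rather than failing in a backend-specific way.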
ghstack-source-id: 132520261

Test Plan: CI

Reviewed By: cbalioglu

Differential Revision: D28597287

fbshipit-source-id: b09f14f628df9e2457623ba81fc13fd4e214f3c9
2021-06-28 10:24:11 -07:00


#include <c10d/ProcessGroupWrapper.hpp>

#ifdef USE_C10D_GLOO
#include <c10/core/Allocator.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/irange.h>
#include <c10d/ProcessGroup.hpp>
#include <c10d/ProcessGroupGloo.hpp>

#include <stdexcept>

namespace c10d {

namespace {
// A container for information about a particular collective, including optype
// and input tensors (if applicable).
struct CollectiveFingerPrint {
  // Current collective's operation type.
  OpType op_type_;
  // Ref to input tensors, if given, of the collective. If given, shapes will
  // be checked across processes to ensure valid input into the collective.
  const std::vector<at::Tensor>& input_tensors_;
  // Input tensor data types.
  std::vector<int8_t> tensor_dtypes_;
  // Input tensor device types.
  std::vector<int8_t> tensor_device_types_;

  explicit CollectiveFingerPrint(
      OpType op_type,
      const std::vector<at::Tensor>& input_tensors)
      : op_type_(op_type), input_tensors_(input_tensors) {
    tensor_dtypes_.reserve(input_tensors.size());
    tensor_device_types_.reserve(input_tensors.size());
    for (const at::Tensor& t : input_tensors_) {
      tensor_dtypes_.push_back(static_cast<int8_t>(t.dtype().toScalarType()));
      tensor_device_types_.push_back(static_cast<int8_t>(t.device().type()));
    }
  }
  // Logs collective information in case of a failure.
  friend std::ostream& operator<<(
      std::ostream& output,
      const CollectiveFingerPrint& collective_fingerprint);

  at::Tensor serialize_fingerprint() {
    auto data = std::make_unique<std::vector<int64_t>>();
    // OpType
    data->push_back(static_cast<int64_t>(op_type_));
    // Shapes
    for (const auto& tensor : input_tensors_) {
      auto sizes = tensor.sizes().vec();
      for (const auto& s : sizes) {
        data->push_back(s);
      }
    }
    // Tensor dtypes
    for (const auto& type : tensor_dtypes_) {
      data->push_back(type);
    }
    // Device types
    for (const auto& d : tensor_device_types_) {
      data->push_back(d);
    }
    // Serialize data into tensor
    int64_t data_size = data->size();
    // Need to release here and get the ptr due to C++ parameter evaluation
    // order.
    auto d = data.release();
    at::Tensor serialized_tensor =
        at::for_blob(d->data(), {data_size})
            .context(
                d,
                [](void* ctx) {
                  delete static_cast<std::vector<int64_t>*>(ctx);
                })
            .options(at::TensorOptions().dtype(at::kLong))
            .make_tensor();
    return serialized_tensor;
  }
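
  // Example (illustrative): for an ALLREDUCE over a single {2, 3} float CPU
  // tensor, serialize_fingerprint() above yields the 1-D int64 tensor
  //   [op_type, 2, 3, dtype, device_type].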
  // Verifies that the given tensors match across all ranks in pg by
  // allgathering them and comparing each gathered copy against this rank's
  // reference. Throws on the first mismatch.
  void verify_tensors(
      std::vector<at::Tensor>& tensors_to_verify,
      c10::intrusive_ptr<ProcessGroup>& pg) {
    // Create output tensor data structure to pass into allgather.
    std::vector<std::vector<at::Tensor>> output_tensors;
    output_tensors.reserve(tensors_to_verify.size());
    for (auto& tensor_shape : tensors_to_verify) {
      std::vector<at::Tensor> outputs;
      outputs.reserve(pg->getSize());
      for (int i = 0; i < pg->getSize(); ++i) {
        outputs.emplace_back(at::zeros_like(tensor_shape));
      }
      output_tensors.emplace_back(outputs);
    }
    // Allgather the tensors to verify across the group.
    pg->allgather(output_tensors, tensors_to_verify)->wait();
    // Verify equivalence.
    for (const auto i : c10::irange(output_tensors.size())) {
      const std::vector<at::Tensor> gathered_tensors = output_tensors[i];
      const at::Tensor reference_tensor = tensors_to_verify[i];
      for (const auto& rank_tensor : gathered_tensors) {
        if (!rank_tensor.equal(reference_tensor)) {
          std::stringstream ss;
          ss << "Detected mismatch between collectives on ranks. Rank "
             << pg->getRank()
             << " is running inconsistent collective: " << *this;
          TORCH_CHECK(false, ss.str());
        }
      }
    }
  }

  // Executes and verifies the collective fingerprint.
  void verify(c10::intrusive_ptr<ProcessGroup> pg) {
    at::Tensor serialized_tensor = serialize_fingerprint();
    std::vector<at::Tensor> inp{serialized_tensor};
    // First verify tensor shapes. This is needed because if e.g. tensor dim
    // does not match across processes, directly verifying tensors will result
    // in a crash during allgather, but we'd actually like to report a
    // description of the inconsistency. Since the input is just a 1D tensor,
    // the shape will be a single int k_i, and we need to make sure k_i is
    // consistent across the whole world.
    std::vector<at::Tensor> sp = c10d::getTensorShapes(inp);
    verify_tensors(sp, pg);
    // Now verify consistency for the actual tensor.
    verify_tensors(inp, pg);
  }
};

std::ostream& operator<<(
    std::ostream& output,
    const CollectiveFingerPrint& collective_fingerprint) {
  std::string collectiveInfo;
  if (!collective_fingerprint.input_tensors_.empty()) {
    // Convert dtype and device type info to strings.
    std::vector<std::string> dtype_strs;
    std::vector<std::string> device_type_strs;
    for (const auto& tensor_dtype : collective_fingerprint.tensor_dtypes_) {
      dtype_strs.push_back(
          c10::toString(static_cast<at::ScalarType>(tensor_dtype)));
    }
    for (const auto& tensor_device_type :
         collective_fingerprint.tensor_device_types_) {
      device_type_strs.push_back(
          c10::toString(static_cast<at::DeviceType>(tensor_device_type)));
    }
    collectiveInfo = c10::str(
        "CollectiveFingerPrint(",
        "OpType=",
        opTypeToString(collective_fingerprint.op_type_),
        ", TensorShape=",
        (collective_fingerprint.input_tensors_)[0].sizes(),
        ", TensorDtypes=",
        dtype_strs,
        ", TensorDeviceTypes=",
        device_type_strs,
        ")");
  } else {
    collectiveInfo = c10::str(
        "CollectiveFingerPrint(",
        "OpType=",
        opTypeToString(collective_fingerprint.op_type_),
        ")");
  }
  return output << collectiveInfo;
}
} // namespace

ProcessGroupWrapper::ProcessGroupWrapper(
    c10::intrusive_ptr<ProcessGroup> pg,
    c10::intrusive_ptr<ProcessGroupGloo> glooPg)
    : ProcessGroup(pg->getRank(), pg->getSize()), pg_(pg), glooPg_(glooPg) {
  // Set the sequence number for the underlying process group.
  pg_->setSequenceNumberForGroup();
}
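
// Usage sketch (illustrative; realPg, glooPg, and tensors are placeholder
// names, not identifiers in this file):
//   auto wrapper =
//       c10::make_intrusive<ProcessGroupWrapper>(realPg, glooPg);
//   wrapper->allreduce(tensors); // runs consistency checks, then delegates.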

const std::string ProcessGroupWrapper::getBackendName() const {
  return pg_->getBackendName();
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::broadcast(
    std::vector<at::Tensor>& data,
    const BroadcastOptions& opts) {
  runCollectiveChecks(OpType::BROADCAST, data);
  return pg_->broadcast(data, opts);
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::allreduce(
    std::vector<at::Tensor>& data,
    const AllreduceOptions& opts) {
  runCollectiveChecks(OpType::ALLREDUCE, data);
  return pg_->allreduce(data, opts);
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
  // NOTE: We don't enforce shape checking for allreduce_coalesced because
  // the implementation itself does not enforce it, and we have tests that
  // use inconsistent shapes; see the Python implementation in
  // distributed_c10d for details.
  runCollectiveChecks(OpType::ALLREDUCE_COALESCED, {});
  return pg_->allreduce_coalesced(tensors, opts);
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
  runCollectiveChecks(OpType::REDUCE, tensors);
  return pg_->reduce(tensors, opts);
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  runCollectiveChecks(OpType::ALLGATHER, inputTensors);
  return pg_->allgather(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::_allgather_base(
    at::Tensor& outputBuffer,
    at::Tensor& inputBuffer,
    const AllgatherOptions& opts) {
  std::vector<at::Tensor> inputTensors({inputBuffer});
  runCollectiveChecks(OpType::_ALLGATHER_BASE, inputTensors);
  return pg_->_allgather_base(outputBuffer, inputBuffer, opts);
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::allgather_coalesced(
    std::vector<std::vector<at::Tensor>>& outputTensorLists,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  // NOTE: We don't enforce shape checking for allgather_coalesced because
  // the implementation itself does not enforce it, and we have tests that
  // use inconsistent shapes; see the Python implementation in
  // distributed_c10d for details.
  runCollectiveChecks(OpType::ALLGATHER_COALESCED, {});
  return pg_->allgather_coalesced(outputTensorLists, inputTensors, opts);
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
  runCollectiveChecks(OpType::GATHER, inputTensors);
  return pg_->gather(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  runCollectiveChecks(OpType::SCATTER, outputTensors);
  return pg_->scatter(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  runCollectiveChecks(OpType::REDUCE_SCATTER, outputTensors);
  return pg_->reduce_scatter(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& opts) {
  // alltoall supports uneven split, so don't enforce shape checking.
  runCollectiveChecks(OpType::ALLTOALL_BASE, {});
  return pg_->alltoall_base(
      outputTensor, inputTensor, outputSplitSizes, inputSplitSizes, opts);
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& opts) {
  // alltoall supports uneven split, so don't enforce shape checking.
  runCollectiveChecks(OpType::ALLTOALL, {});
  return pg_->alltoall(outputTensors, inputTensors, opts);
}

void ProcessGroupWrapper::monitoredBarrier(
    const BarrierOptions& opts,
    bool waitAllRanks) {
  return pg_->monitoredBarrier(opts, waitAllRanks);
}

void ProcessGroupWrapper::setSequenceNumberForGroup() {
  // Set underlying pg's sequence number if it is not set.
  if (pg_->getSequenceNumberForGroup() == 0) {
    // Set the sequence number for the underlying process group.
    pg_->setSequenceNumberForGroup();
  }
}

uint64_t ProcessGroupWrapper::getSequenceNumberForGroup() {
  return pg_->getSequenceNumberForGroup();
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
  return pg_->send(tensors, dstRank, tag);
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  return pg_->recv(tensors, srcRank, tag);
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::recvAnysource(
    std::vector<at::Tensor>& tensors,
    int tag) {
  return pg_->recvAnysource(tensors, tag);
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::barrier(
    const BarrierOptions& opts) {
  runCollectiveChecks(OpType::BARRIER, {});
  return pg_->barrier(opts);
}

void ProcessGroupWrapper::runCollectiveChecks(
    OpType op_type,
    const std::vector<at::Tensor>& tensors) const {
  // First perform a monitored barrier to ensure all ranks can synchronize.
  c10d::BarrierOptions options;
  // TODO: we should use the wrapped pg_'s timeout here, but the C++
  // ProcessGroup API does not expose timeout.
  glooPg_->monitoredBarrier(options, /* waitAllRanks */ true);
  auto finger_print = CollectiveFingerPrint(op_type, tensors);
  // Will throw if an ill-formed collective is detected.
  finger_print.verify(glooPg_);
}
} // namespace c10d
#endif // USE_C10D_GLOO