ProcessGroupUCC tests (#83285)

- [x] Direct dependency on UCX is completely removed; the UCC active set API is always enabled
- [x] Remove `TORCH_UCC_PROFILING_ENABLE`; profiling is now always enabled
- [x] Fix profiling of `recv` and `all_gather`
- [x] Use the NCCL TL of UCC on CUDA, as the UCP TL is not well supported on CUDA (see the sketch below)
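
For context, the NCCL-vs-UCP choice is made at runtime through UCC/UCX environment variables rather than a build flag. Below is a minimal sketch of initializing the `ucc` backend with the same TL settings this PR adds to the CI test config (the `env://` rendezvous and its `MASTER_ADDR`/`MASTER_PORT`/`RANK`/`WORLD_SIZE` variables are standard `torch.distributed` usage, assumed here rather than taken from this diff):

```python
import os

import torch.distributed as dist

# TL selection mirrored from the test config this PR adds:
os.environ["UCX_TLS"] = "tcp"             # transports used by the UCX layer
os.environ["UCC_TLS"] = "nccl,ucp"        # make the NCCL TL available alongside UCP
os.environ["UCC_TL_UCP_TUNE"] = "cuda:0"  # don't use UCP TL on CUDA as it is not well supported

if dist.is_ucc_available():
    # Rendezvous settings (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) must already be set.
    dist.init_process_group(backend="ucc", init_method="env://")
```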

Most tests are passing, but a few tests are skipped:
- `scatter` and `gather` on CPU tensors are not supported by the UCP TL of UCC (skip pattern sketched below)
- A few tests are flaky in PyTorch's CI environment
- Profiler-related failures, some of which will be fixed by @Fuzzkatt in https://github.com/pytorch/pytorch/pull/84368
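
The skips reuse the decorator pattern already in `distributed_test.py`, visible in the hunks below. A minimal sketch (the import path and the `ExampleDistributedTest` class are illustrative assumptions, not code from this PR):

```python
import os

from torch.testing._internal.common_utils import sandcastle_skip_if

BACKEND = os.environ["BACKEND"]  # backend under test, e.g. "ucc"

class ExampleDistributedTest:
    # Same decorator the suite already uses for NCCL's CPU-tensor limitation.
    @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
    @sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
    def test_scatter(self):
        ...  # scatter on CPU tensors would fail under the UCP TL
```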

After this PR is merged, I will continue working on these skipped tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83285
Approved by: https://github.com/vtlam, https://github.com/malfet, https://github.com/kwen2501
Xiang Gao 2022-09-10 10:56:05 +00:00 committed by PyTorch MergeBot
parent 2765243cd5
commit 08c4f8c7a7
10 changed files with 64 additions and 494 deletions

View file

@@ -36,7 +36,7 @@ function install_ucc() {
git submodule update --init --recursive
./autogen.sh
-./configure --prefix=$UCC_HOME --with-ucx=$UCX_HOME --with-nccl=no --with-cuda=$with_cuda
+./configure --prefix=$UCC_HOME --with-ucx=$UCX_HOME --with-cuda=$with_cuda
time make -j
sudo make install

View file

@@ -27,7 +27,7 @@ if NO_MULTIPROCESSING_SPAWN:
BACKEND = os.environ["BACKEND"]
if BACKEND == "gloo" or BACKEND == "nccl":
if BACKEND == "gloo" or BACKEND == "nccl" or BACKEND == "ucc":
class TestDistBackendWithSpawn(TestDistBackend, DistributedTest._DistTestBase):
def setUp(self):

View file

@@ -299,6 +299,14 @@ if dist.is_available():
"WORLD_SIZE": "2" if torch.cuda.device_count() == 2 else "3",
"TEST_REPORT_SOURCE_OVERRIDE": "dist-gloo",
}
if dist.is_ucc_available():
DISTRIBUTED_TESTS_CONFIG["ucc"] = {
"WORLD_SIZE": "2" if torch.cuda.device_count() == 2 else "3",
"TEST_REPORT_SOURCE_OVERRIDE": "dist-ucc",
"UCX_TLS": "tcp",
"UCC_TLS": "nccl,ucp",
"UCC_TL_UCP_TUNE": "cuda:0", # don't use UCP TL on CUDA as it is not well supported
}
# https://stackoverflow.com/questions/2549939/get-signal-names-from-numbers-in-python
SIGNALS_TO_NAMES_DICT = {

View file

@@ -444,6 +444,11 @@ target_compile_options(torch_python PRIVATE ${TORCH_PYTHON_COMPILE_OPTIONS})
target_include_directories(torch_python PUBLIC ${TORCH_PYTHON_INCLUDE_DIRECTORIES})
if(USE_UCC)
target_link_libraries(torch_python PRIVATE __caffe2_ucc)
target_compile_definitions(torch_python PRIVATE USE_UCC)
endif()
if(BUILD_ONEDNN_GRAPH)
target_compile_definitions(torch_python PRIVATE "-DBUILD_ONEDNN_GRAPH")
target_compile_definitions(torch_cpu PRIVATE "-DBUILD_ONEDNN_GRAPH")

View file

@@ -13,18 +13,6 @@ namespace c10d {
namespace {
constexpr int64_t kBusyWaitMillis = 10;
const std::map<c10::DeviceType, ucs_memory_type_t> ucs_mtype_map = {
{c10::kCPU, UCS_MEMORY_TYPE_HOST},
{c10::kCUDA, UCS_MEMORY_TYPE_CUDA},
};
ucs_memory_type_t to_ucs_memType(c10::DeviceType _c10_type) {
if (ucs_mtype_map.find(_c10_type) != ucs_mtype_map.end())
return ucs_mtype_map.at(_c10_type);
else
return UCS_MEMORY_TYPE_UNKNOWN;
}
const std::map<c10::DeviceType, ucc_memory_type_t> ucc_mtype_map = {
{c10::kCPU, UCC_MEMORY_TYPE_HOST},
{c10::kCUDA, UCC_MEMORY_TYPE_CUDA},
@@ -97,7 +85,6 @@ ucc_reduction_op_t to_ucc_reduceOp(
struct torch_ucc_config_t {
c10::once_flag flag;
std::array<bool, 32> blocking_wait;
bool enable_profiling;
bool enable_comms_logger;
bool use_future;
// Sharing UCC communicator among multiple PGs to save resource.
@@ -163,7 +150,6 @@ std::vector<OpType> parse_blocking_wait(std::string op_list_string) {
void read_config() {
// default configuration
torch_ucc_config.blocking_wait.fill(false);
torch_ucc_config.enable_profiling = false;
torch_ucc_config.use_future = true;
torch_ucc_config.shared_comm = false;
torch_ucc_config.use_allgatherv = false;
@@ -186,8 +172,6 @@ void read_config() {
torch_ucc_config.use_future =
std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_FUTURE"));
torch_ucc_config.enable_profiling =
std::stoi(torch_ucc_envs_map.at("TORCH_UCC_PROFILING_ENABLE"));
torch_ucc_config.shared_comm =
std::stoi(torch_ucc_envs_map.at("TORCH_UCC_SHARED_COMM"));
torch_ucc_config.use_allgatherv =
@@ -294,6 +278,14 @@ c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupUCC::WorkUCC::getFuture() {
return future_;
}
int ProcessGroupUCC::WorkUCC::sourceRank() const {
if (opType_ != OpType::RECV && opType_ != OpType::RECVANYSOURCE) {
// Throw an error
return ProcessGroup::Work::sourceRank();
}
return sourceRank_;
}
std::vector<at::Tensor> ProcessGroupUCC::WorkUCC::result() {
return *outputs_;
}
@@ -327,7 +319,6 @@ Comm::Comm(
bool is_health_check)
: logger(logger_),
oob(oob_),
ucx_comm(oob->size, logger),
ucc_comm(oob, logger),
finalize_phase(
is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_FINALIZE),
@@ -364,7 +355,7 @@ std::shared_ptr<Comm> Comm::get_comm(
static uint32_t comm_id;
std::lock_guard<std::mutex> lock(m);
-id = (comm_id % TORCH_UCX_MAX_COMM);
+id = comm_id;
std::string group_id = "group_id";
if (is_health_check) {
@@ -419,126 +410,6 @@ std::shared_ptr<Comm> Comm::get_comm(
}
}
void Comm::ucx_connect_eps(
std::vector<ucp_ep_h>& eps,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob) {
ucp_address_t* local_addr;
size_t local_addr_len;
std::vector<uint8_t> peer_addr;
TORCH_UCX_CHECK(
ucp_worker_get_address(ucx_comm.worker, &local_addr, &local_addr_len),
"failed to get worker address");
std::vector<uint8_t> val = std::vector<uint8_t>(
reinterpret_cast<uint8_t*>(local_addr),
reinterpret_cast<uint8_t*>(local_addr) + local_addr_len);
oob->store->set(oob->getKey("wa" + std::to_string(oob->rank)), val);
ucp_worker_release_address(ucx_comm.worker, local_addr);
eps.resize(oob->size);
for (int i = 0; i < oob->size; i++) {
peer_addr = oob->store->get(oob->getKey("wa" + std::to_string(i)));
ucp_ep_params_t ep_params;
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
ep_params.address = reinterpret_cast<ucp_address_t*>(peer_addr.data());
TORCH_UCX_CHECK(
ucp_ep_create(ucx_comm.worker, &ep_params, &(eps[i])),
c10::str("failed to create endpoint with rank ", i));
}
}
void Comm::ucx_disconnect_eps(
std::vector<ucp_ep_h>& eps,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob) {
ucs_status_t st;
for (ucp_ep_h& ep : eps) {
ucs_status_ptr_t close_req = ucp_ep_close_nb(ep, UCP_EP_CLOSE_MODE_FLUSH);
if (UCS_PTR_IS_ERR(close_req)) {
TORCH_UCC_LOG_ERROR(
finalize_phase, "failed to close endpoint, ignore and continue...");
return;
}
if (UCS_PTR_IS_PTR(close_req)) {
do {
ucp_worker_progress(ucx_comm.worker);
st = ucp_request_check_status(close_req);
} while (st != UCS_OK);
ucp_request_free(close_req);
}
}
if (!eps.size()) {
return;
}
try {
auto sz = (size_t)oob->store->add(oob->getKey("epclosed"), 1);
while (sz != eps.size()) {
ucp_worker_progress(ucx_comm.worker);
std::this_thread::sleep_for(std::chrono::milliseconds(kBusyWaitMillis));
sz = (size_t)oob->store->add(oob->getKey("epclosed"), 0);
}
} catch (std::exception& ex) {
LOG(ERROR) << "(disconnect_eps) Caught error in Store Operation .. "
<< "[" << ex.what() << "]";
}
}
ucc_coll_req_h Comm::send_nb(
ucp_ep_h ep,
void* data,
ucs_memory_type_t mtype,
size_t size,
ucp_tag_t ucp_tag) {
ucs_status_ptr_t st;
ucp_request_param_t params;
params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
UCP_OP_ATTR_FIELD_DATATYPE | UCP_OP_ATTR_FIELD_MEMORY_TYPE;
params.datatype = ucp_dt_make_contig(size);
params.memory_type = mtype;
params.cb.send = [](void* request, ucs_status_t status, void* user_data) {
static_cast<ucc_coll_req_h>(request)->status = UCC_OK;
};
st = ucp_tag_send_nbx(ep, data, 1, ucp_tag, &params);
if (UCS_PTR_IS_ERR(st)) {
TORCH_UCC_LOG_ERROR(
TORCH_UCC_COLL_POST,
c10::str(
"failed to send message: ", ucs_status_string(UCS_PTR_STATUS(st))));
throw std::runtime_error(ucs_status_string(UCS_PTR_STATUS(st)));
}
return reinterpret_cast<ucc_coll_req_h>(st);
}
ucc_coll_req_h Comm::recv_nb(
void* data,
ucs_memory_type_t mtype,
size_t size,
ucp_tag_t ucp_tag,
ucp_tag_t ucp_tag_mask) {
ucs_status_ptr_t st;
ucp_request_param_t params;
params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
UCP_OP_ATTR_FIELD_DATATYPE | UCP_OP_ATTR_FIELD_MEMORY_TYPE;
params.datatype = ucp_dt_make_contig(size);
params.cb.recv = [](void* request,
ucs_status_t status,
const ucp_tag_recv_info_t* info,
void* user_data) {
static_cast<ucc_coll_req_h>(request)->status = UCC_OK;
};
params.memory_type = mtype;
st = ucp_tag_recv_nbx(
ucx_comm.worker, data, 1, ucp_tag, ucp_tag_mask, &params);
if (UCS_PTR_IS_ERR(st)) {
TORCH_UCC_LOG_ERROR(
TORCH_UCC_COLL_POST,
c10::str(
"failed to recv message: ", ucs_status_string(UCS_PTR_STATUS(st))));
throw std::runtime_error(ucs_status_string(UCS_PTR_STATUS(st)));
}
return reinterpret_cast<ucc_coll_req_h>(st);
}
void Comm::ucc_create_team(
ucc_team_h& team,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob) {
@@ -582,34 +453,6 @@ void Comm::ucc_destroy_team(ucc_team_h& team) {
lock.unlock();
}
c10::intrusive_ptr<ProcessGroup::Work> Comm::enqueue_p2p(
OpType opType,
ucc_coll_req_h request,
const char* prof_title) {
auto work =
c10::make_intrusive<ProcessGroupUCC::WorkUCC>(opType, prof_title, logger);
if (torch_ucc_config.use_future) {
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()));
}
if (request == nullptr) {
// p2p2 request completed immediately don't save it to progress queue
// and mark future completed immediately
if (torch_ucc_config.use_future) {
work->future_->markCompleted(c10::IValue(std::vector<at::Tensor>()));
}
return work;
}
auto entry =
std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucx_comm, request);
work->entry_ = entry;
std::unique_lock<std::mutex> lock(mutex);
progress_queue.push_back(entry);
lock.unlock();
queue_produce_cv.notify_one();
return work;
}
void Comm::enqueue_collective(
std::unique_ptr<ProcessGroupUCC::WorkData> data,
c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
@@ -688,7 +531,6 @@ void Comm::progress_loop() {
try {
while (work->request_->status > 0) {
ucc_comm.progress();
ucx_comm.progress();
}
if (work->request_->status < 0) {
eptr = std::make_exception_ptr(
@@ -724,7 +566,7 @@ ProcessGroupUCC::ProcessGroupUCC(
comm = nullptr;
cuda_ee = nullptr;
static uint32_t id = 0;
-uint32_t pg_id = (id++ % TORCH_UCX_MAX_COMM);
+uint32_t pg_id = id++;
logger = c10::make_intrusive<ProcessGroupUCCLogger>(
c10::str("[Rank ", rank_, "]", "[ProcessGroupUCC-", pg_id, "]"),
@@ -769,20 +611,10 @@ ProcessGroupUCC::~ProcessGroupUCC() {
comm->ucc_destroy_team(team);
TORCH_UCC_LOG_INFO(
TORCH_UCC_FINALIZE, "Successfully destroyed UCC library");
comm->ucx_disconnect_eps(eps, oob);
TORCH_UCC_LOG_INFO(
TORCH_UCC_FINALIZE, "Successfully destroyed UCX library");
try {
if (cuda_ee) {
ucc_ee_destroy(cuda_ee);
}
if ((size_t)oob->store->add(oob->getKey("ucc_pg_closed"), 1) ==
eps.size()) {
std::vector<uint8_t> val = {1};
oob->store->set(oob->getKey("ucc_pg_finished"), val);
} else {
oob->store->wait({oob->getKey("ucc_pg_finished")});
}
} catch (std::exception& ex) {
TORCH_UCC_LOG_INFO(
TORCH_UCC_FINALIZE,
@@ -817,7 +649,6 @@ void ProcessGroupUCC::runHealthCheck() {
struct HealthCheckData {
std::mutex healthCheckMutex;
std::condition_variable healthCheckCv;
bool ucxHealthCheckSuccess = false;
bool uccHealthCheckSuccess = false;
std::exception_ptr healthCheckException;
} healthCheckData;
@@ -837,8 +668,6 @@
oob->rank = this->oob->rank;
oob->size = this->oob->size;
oob->store = this->oob->store;
std::vector<ucp_ep_h> eps;
ucc_team_h team = nullptr;
uint32_t comm_id;
#ifdef USE_CUDA
@@ -847,19 +676,6 @@
}
#endif
auto comm = Comm::get_comm(comm_id, device, oob, logger, true);
comm->ucx_connect_eps(eps, oob);
comm->ucx_disconnect_eps(eps, oob);
TORCH_UCC_LOG_INFO(
TORCH_UCC_HEALTH_CHECK,
c10::str(
"UCX library health check succeed for device ",
c10::DeviceTypeName(device.type())));
// Mark ucx health check as complete.
if (is_last_device) {
std::lock_guard<std::mutex> lk(healthCheckData.healthCheckMutex);
healthCheckData.ucxHealthCheckSuccess = true;
}
comm->ucc_create_team(team, oob);
comm->ucc_destroy_team(team);
TORCH_UCC_LOG_INFO(
@@ -898,18 +714,13 @@
" msec for UCC health check to complete."));
std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex);
healthCheckData.healthCheckCv.wait_for(lock, timeout_, [&healthCheckData]() {
-return healthCheckData.ucxHealthCheckSuccess &&
-healthCheckData.uccHealthCheckSuccess;
+return healthCheckData.uccHealthCheckSuccess;
});
if (healthCheckData.healthCheckException) {
std::rethrow_exception(healthCheckData.healthCheckException);
}
// If there is no exception, the likely culprit is a timeout/hang
-TORCH_CHECK(
-healthCheckData.ucxHealthCheckSuccess,
-"ProcessGroupUCC: Health check failure: Failed to initialize UCX on rank ",
-rank_);
TORCH_CHECK(
healthCheckData.uccHealthCheckSuccess,
"ProcessGroupUCC: Health check failure: Failed to initialize UCC on rank ",
@@ -948,8 +759,12 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::collective_post(
std::vector<at::Tensor>& outputTensors,
const char* prof_title) {
set_timeout(coll);
-auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
-opType, torch_ucc_config.enable_profiling ? prof_title : nullptr, logger);
+auto work =
+c10::make_intrusive<ProcessGroupUCC::WorkUCC>(opType, prof_title, logger);
+if (opType == OpType::RECV) {
+work->sourceRank_ = coll.root;
+}
RECORD_COMMS_TRACE(
logger->trace_generator,
@@ -1106,7 +921,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::allgather(
tensor.device(),
inputTensors,
outputTensors[0],
"ucc:allgather");
"ucc:all_gather");
}
}
@@ -1689,7 +1504,6 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::send(
auto& tensor = tensors[0];
initComm(tensor.device());
#ifdef USE_ACTIVE_SETS
WorkData* data = new WorkData();
ucc_coll_args_t coll;
coll.tag = tag;
@@ -1717,28 +1531,6 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::send(
tensors,
tensors,
"ucc:send");
#else
ucp_tag_t ucp_tag;
TORCH_UCX_MAKE_SEND_TAG(ucp_tag, tag, rank_, comm_id);
ucc_coll_req_h request = comm->send_nb(
eps[dstRank],
tensor.data_ptr(),
to_ucs_memType(tensor.device().type()),
tensor.numel() * tensor.element_size(),
ucp_tag);
auto work = comm->enqueue_p2p(OpType::SEND, request, "ucc:send");
// TODO: record src, dst ranks and tag
RECORD_COMMS_TRACE(
logger->trace_generator,
work,
OpType::SEND,
this->getRank(),
this->getSize(),
tensors,
tensors);
return work;
#endif
}
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::recv(
@@ -1749,7 +1541,6 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::recv(
auto& tensor = tensors[0];
initComm(tensor.device());
#ifdef USE_ACTIVE_SETS
WorkData* data = new WorkData();
ucc_coll_args_t coll;
coll.tag = tag;
@@ -1777,58 +1568,6 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::recv(
tensors,
tensors,
"ucc:recv");
#else
ucp_tag_t ucp_tag, ucp_tag_mask;
TORCH_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, srcRank, comm_id);
ucc_coll_req_h request = comm->recv_nb(
tensor.data_ptr(),
to_ucs_memType(tensor.device().type()),
tensor.numel() * tensor.element_size(),
ucp_tag,
ucp_tag_mask);
auto work = comm->enqueue_p2p(OpType::RECV, request, "ucc:recv");
// TODO: record src, dst ranks and tag
RECORD_COMMS_TRACE(
logger->trace_generator,
work,
OpType::RECV,
this->getRank(),
this->getSize(),
tensors,
tensors);
return work;
#endif
}
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::recvAnysource(
std::vector<at::Tensor>& tensors,
int tag) {
check_tensor(tensors);
auto& tensor = tensors[0];
initComm(tensor.device());
ucp_tag_t ucp_tag, ucp_tag_mask;
TORCH_UCX_MAKE_RECV_TAG(
ucp_tag, ucp_tag_mask, tag, TORCH_UCX_ANY_SOURCE, comm_id);
ucc_coll_req_h request = comm->recv_nb(
tensor.data_ptr(),
to_ucs_memType(tensor.device().type()),
tensor.numel() * tensor.element_size(),
ucp_tag,
ucp_tag_mask);
auto work = comm->enqueue_p2p(OpType::RECVANYSOURCE, request, "ucc:recv");
// TODO: record dst rank and tag
RECORD_COMMS_TRACE(
logger->trace_generator,
work,
OpType::RECVANYSOURCE,
this->getRank(),
this->getSize(),
tensors,
tensors);
return work;
}
c10::intrusive_ptr<ProcessGroup> ProcessGroupUCC::createProcessGroupUCC(
@@ -1847,7 +1586,6 @@ void ProcessGroupUCC::initComm(c10::Device dev) {
}
#endif
comm = Comm::get_comm(comm_id, dev, oob, logger);
comm->ucx_connect_eps(eps, oob);
TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCX library");
comm->ucc_create_team(team, oob);
TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCC library");

View file

@@ -24,47 +24,6 @@ namespace c10d {
#define TORCH_UCC_DEVICE_NOT_SET -2
#define TORCH_UCX_MAKE_P2P_TAG(_tag, _rank, _comm) \
((((uint64_t)(_tag)) << TORCH_UCX_TAG_BITS_OFFSET) | \
(((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \
(((uint64_t)(_comm)) << TORCH_UCX_COMM_BITS_OFFSET))
#define TORCH_UCX_MAKE_OOB_TAG(_tag, _rank, _comm) \
((((uint64_t)(_tag)) << TORCH_UCX_OOB_BITS_OFFSET) | \
(((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \
(((uint64_t)(_rank)) << TORCH_UCX_COMM_BITS_OFFSET))
#define TORCH_UCX_MAKE_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \
} while (0)
#define TORCH_UCX_ANY_SOURCE (TORCH_UCX_MAX_RANK - 1)
#define TORCH_UCX_ANY_SOURCE_MASK (~TORCH_UCX_RANK_MASK)
#define TORCH_UCX_SPECIFIC_SOURCE_MASK ((uint64_t)-1)
#define TORCH_UCX_MAKE_RECV_TAG(_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \
if ((_rank) == TORCH_UCX_ANY_SOURCE) { \
(_ucp_tag_mask) = TORCH_UCX_ANY_SOURCE_MASK; \
} else { \
(_ucp_tag_mask) = TORCH_UCX_SPECIFIC_SOURCE_MASK; \
} \
} while (0)
#define TORCH_UCX_MAKE_OOB_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \
} while (0)
#define TORCH_UCX_MAKE_OOB_RECV_TAG( \
_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \
(_ucp_tag_mask) = (uint64_t)-1; \
} while (0)
#ifdef USE_CUDA
#define SAVE_TENSORS(_TENSORS, _DATA) \
do { \
@@ -84,8 +43,6 @@ namespace c10d {
constexpr const char* UCC_BACKEND_NAME = "ucc";
enum torch_ucx_tag_type_t { TORCH_UCX_P2P_TAG, TORCH_UCX_OOB_TAG };
struct event_pool_t {
#ifdef USE_CUDA
std::queue<std::unique_ptr<at::cuda::CUDAEvent>> event_pool;
@@ -173,10 +130,12 @@ class TORCH_API ProcessGroupUCC : public ProcessGroup {
bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
std::vector<at::Tensor> result() override;
int sourceRank() const override;
#ifdef USE_CUDA
std::unique_ptr<at::cuda::CUDAEvent> fence = nullptr;
event_pool_t* ep = nullptr;
#endif
int sourceRank_;
protected:
std::shared_ptr<ProgressEntry> entry_;
c10::intrusive_ptr<ProcessGroupUCCLogger> logger_;
@@ -293,10 +252,6 @@ class TORCH_API ProcessGroupUCC : public ProcessGroup {
int srcRank,
int tag) override;
c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
std::vector<at::Tensor>& tensors,
int tag) override;
static c10::intrusive_ptr<ProcessGroup> createProcessGroupUCC(
const c10::intrusive_ptr<::c10d::Store>& store,
int rank,
@@ -308,7 +263,6 @@ class TORCH_API ProcessGroupUCC : public ProcessGroup {
std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
std::shared_ptr<Comm> comm = {nullptr};
uint32_t comm_id;
std::vector<ucp_ep_h> eps;
ucc_team_h team{nullptr};
ucc_ee_h cuda_ee{nullptr};
#ifdef USE_CUDA
@@ -321,7 +275,6 @@ class TORCH_API ProcessGroupUCC : public ProcessGroup {
class Comm {
c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
CommUCX ucx_comm;
CommUCC ucc_comm;
std::mutex mutex;
std::thread progress_thread;
@@ -342,16 +295,6 @@ class Comm {
~Comm();
// Connects UCX end points.
void ucx_connect_eps(
std::vector<ucp_ep_h>& eps,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob);
// Disconnects UCX end points.
void ucx_disconnect_eps(
std::vector<ucp_ep_h>& eps,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob);
void ucc_create_team(
ucc_team_h& team,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob);
@@ -386,20 +329,6 @@ class Comm {
bool is_health_check = false);
void progress_loop();
ucc_coll_req_h send_nb(
ucp_ep_h ep,
void* data,
ucs_memory_type_t mtype,
size_t size,
ucp_tag_t ucp_tag);
ucc_coll_req_h recv_nb(
void* data,
ucs_memory_type_t mtype,
size_t size,
ucp_tag_t ucp_tag,
ucp_tag_t ucp_tag_mask);
};
} // namespace c10d

View file

@@ -17,75 +17,6 @@ constexpr char kAllGatherDone[] = "ag_done";
constexpr char kAllGatherFree[] = "ag_free";
} // namespace
CommUCX::CommUCX(
int comm_size,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger)
: CommBase(logger) {
ucp_params_t params;
ucp_config_t* config;
ucs_status_t st;
ucp_worker_params_t worker_params;
ucp_lib_attr_t ucp_attr;
ucp_attr.field_mask = UCP_LIB_ATTR_FIELD_MAX_THREAD_LEVEL;
TORCH_UCX_CHECK(
ucp_lib_query(&ucp_attr), "failed to query UCP lib attributes");
TORCH_CHECK(
ucp_attr.max_thread_level == UCS_THREAD_MODE_MULTI,
"ucx library wasn't initialized with multithreading support, "
"please check ucx build options");
TORCH_UCX_CHECK(
ucp_config_read("TORCH", nullptr, &config), "failed to read UCP config");
memset(&params, 0, sizeof(ucp_params_t));
params.field_mask = UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_REQUEST_SIZE |
UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | UCP_PARAM_FIELD_TAG_SENDER_MASK |
UCP_PARAM_FIELD_REQUEST_INIT | UCP_PARAM_FIELD_REQUEST_CLEANUP;
params.request_size = sizeof(ucc_coll_req_t);
params.features = UCP_FEATURE_TAG;
params.estimated_num_eps = comm_size;
params.tag_sender_mask = TORCH_UCX_RANK_MASK;
params.request_init = [](void* request) {
static_cast<ucc_coll_req_h>(request)->status = UCC_INPROGRESS;
};
params.request_cleanup = [](void*) {};
TORCH_UCX_CHECK(
ucp_init(&params, config, &context), "failed to init UCP context");
ucp_config_release(config);
memset(&worker_params, 0, sizeof(ucp_worker_params_t));
worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
worker_params.thread_mode = UCS_THREAD_MODE_MULTI;
st = ucp_worker_create(context, &worker_params, &worker);
if (st != UCS_OK) {
TORCH_UCC_LOG_ERROR(
TORCH_UCC_INIT,
c10::str("UCX failed to create UCP worker:", ucs_status_string(st)));
ucp_cleanup(context);
throw std::runtime_error(ucs_status_string(st));
}
}
void CommUCX::progress() {
ucp_worker_progress(worker);
}
void CommUCX::free_request(ucc_coll_req_h request) {
request->status = UCC_INPROGRESS;
ucp_request_free(request);
}
CommUCX::~CommUCX() {
if (worker != nullptr) {
ucp_worker_destroy(worker);
}
if (context != nullptr) {
ucp_cleanup(context);
}
worker = nullptr;
context = nullptr;
}
ucc_status_t oob_allgather(
void* sbuf,
void* rbuf,

View file

@@ -5,28 +5,6 @@
#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
#include <ucc/api/ucc.h>
#include <ucp/api/ucp.h>
#define TORCH_UCX_COMM_BITS 15
#define TORCH_UCX_RANK_BITS 16
#define TORCH_UCX_TAG_BITS 32
#define TORCH_UCX_OOB_BITS 1
#define TORCH_UCX_COMM_BITS_OFFSET 0
#define TORCH_UCX_RANK_BITS_OFFSET TORCH_UCX_COMM_BITS
#define TORCH_UCX_TAG_BITS_OFFSET (TORCH_UCX_COMM_BITS + TORCH_UCX_RANK_BITS)
#define TORCH_UCX_OOB_BITS_OFFSET \
(TORCH_UCX_COMM_BITS + TORCH_UCX_RANK_BITS + TORCH_UCX_TAG_BITS)
#define TORCH_UCX_MAX_COMM ((((uint64_t)1) << TORCH_UCX_COMM_BITS) - 1)
#define TORCH_UCX_MAX_RANK ((((uint64_t)1) << TORCH_UCX_RANK_BITS) - 1)
#define TORCH_UCX_MAX_TAG ((((uint64_t)1) << TORCH_UCX_TAG_BITS) - 1)
#define TORCH_UCX_MAX_OOB ((((uint64_t)1) << TORCH_UCX_OOB_BITS) - 1)
#define TORCH_UCX_COMM_MASK (TORCH_UCX_MAX_COMM << TORCH_UCX_COMM_BITS_OFFSET)
#define TORCH_UCX_RANK_MASK (TORCH_UCX_MAX_RANK << TORCH_UCX_RANK_BITS_OFFSET)
#define TORCH_UCX_TAG_MASK (TORCH_UCX_MAX_TAG << TORCH_UCX_TAG_BITS_OFFSET)
#define TORCH_UCX_OOB_MASK (TORCH_UCX_MAX_OOB << TORCH_UCX_OOB_BITS_OFFSET)
namespace c10d {
@@ -53,29 +31,6 @@ namespace c10d {
} \
} while (0)
// Macro to throw on a non-successful UCX return value.
#define TORCH_UCX_CHECK(_cmd, _error_msg) \
do { \
ucs_status_t result = _cmd; \
if (result != UCS_OK) { \
std::string err = c10::str( \
"[", \
std::string(__FILE__), \
":", \
std::to_string(__LINE__), \
"] ", \
logger->getLogPrefix(), \
_error_msg, \
", error code ", \
result, \
": ", \
ucs_status_string(result), \
", system error code ", \
errno); \
TORCH_CHECK(false, err); \
} \
} while (0)
// Macros to print logs with unified format
#define TORCH_UCC_LOG_ERROR(_phase, _msg) \
LOG(ERROR) << logger->getLogPrefix(_phase) << "[ERROR] " << _msg;
@@ -148,21 +103,6 @@ class CommBase {
virtual ~CommBase() {}
c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
};
class CommUCX : public CommBase {
public:
ucp_context_h context{nullptr};
ucp_worker_h worker{nullptr};
public:
void progress() override;
void free_request(ucc_coll_req_h request) override;
CommUCX(
int comm_size,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger);
~CommUCX();
};
class CommUCC : public CommBase {
public:
ucc_lib_h lib{nullptr};

View file

@@ -73,17 +73,17 @@ TEST_SKIPS = {
class DistTestCases:
# Backends that do not support a specific collective
skip_collective = {}
skip_collective["allgather_coalesced"] = {"nccl", "mpi"}
skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc"}
skip_collective["reduce"] = set()
skip_collective["sendrecv anysource"] = {"nccl"}
skip_collective["cpu barrier"] = {"nccl"}
skip_collective["sendrecv anysource"] = {"nccl", "ucc"}
skip_collective["cpu barrier"] = {"nccl", "ucc"}
# Sets showing that something is implemented
backend_feature = {}
backend_feature["gpu"] = {"nccl", "gloo"}
backend_feature["cuda"] = {"nccl", "gloo"}
backend_feature["ddp"] = {"nccl", "gloo"}
backend_feature["subgroup"] = {"nccl", "gloo"}
backend_feature["gpu"] = {"nccl", "gloo"} # TODO(ucc): add sequence number support to ucc and enable it here
backend_feature["cuda"] = {"nccl", "gloo", "ucc"}
backend_feature["ddp"] = {"nccl", "gloo", "ucc"}
backend_feature["subgroup"] = {"nccl", "gloo", "ucc"}
backend_feature["plugin"] = set()

View file

@@ -143,6 +143,7 @@ PROFILING_SUPPORTED_BACKENDS = [
dist.Backend.NCCL,
dist.Backend.GLOO,
dist.Backend.MPI,
# TODO(ucc): dist.Backend.UCC,
]
# Allowlist of distributed backends where profiling is supported with use_cuda=True
@@ -150,6 +151,7 @@ CUDA_PROFILING_SUPPORTED_BACKENDS = [
dist.Backend.GLOO,
dist.Backend.MPI,
dist.Backend.NCCL,
# TODO(ucc): dist.Backend.UCC,
]
# Allowlist of distributed backends where profiling is supported for p2p ops
@@ -157,6 +159,7 @@ SEND_RECV_PROFILING_SUPPORTED_BACKENDS = [
dist.Backend.MPI,
dist.Backend.GLOO,
dist.Backend.NCCL,
# TODO(ucc): dist.Backend.UCC,
]
# Dummy NamedTuple data structures to test DDP support for NamedTuple types.
@@ -395,6 +398,8 @@ def require_backends_available(backends):
return dist.is_nccl_available()
if backend == dist.Backend.MPI:
return dist.is_mpi_available()
if backend == dist.Backend.UCC:
return dist.is_ucc_available()
if backend in DistTestCases.backend_feature["plugin"]:
return True
return False
@@ -2913,6 +2918,7 @@ class DistributedTest:
self._barrier()
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
def test_scatter_checks(self):
group, group_id, rank = self._init_global_test()
one = torch.ones([1])
@@ -2936,6 +2942,7 @@
self.assertEqual(output, one * rank)
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
def test_scatter(self):
group, group_id, rank = self._init_global_test()
self._test_scatter_helper(group, group_id, rank)
@@ -2948,6 +2955,7 @@
self._test_scatter_helper(group, group_id, rank, True, rank_to_GPU)
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
def test_scatter_complex(self):
group, group_id, rank = self._init_global_test()
self._test_scatter_helper(group, group_id, rank, dtype=torch.cfloat)
@@ -2960,12 +2968,14 @@
self._test_scatter_helper(group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat)
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
@skip_if_small_worldsize
def test_scatter_group(self):
group, group_id, rank = self._init_group_test()
self._test_scatter_helper(group, group_id, rank)
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
def test_scatter_full_group(self):
group, group_id, rank = self._init_full_group_test()
self._test_scatter_helper(group, group_id, rank)
@@ -2999,6 +3009,7 @@
self._barrier()
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
def test_gather_checks(self):
group, group_id, rank = self._init_global_test()
one = torch.ones([1])
@@ -3022,6 +3033,7 @@
dist.gather(one * rank)
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
def test_gather(self):
group, group_id, rank = self._init_global_test()
self._test_gather_helper(group, group_id, rank)
@@ -3034,12 +3046,14 @@
self._test_gather_helper(group, group_id, rank, True, rank_to_GPU)
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
@skip_if_small_worldsize
def test_gather_group(self):
group, group_id, rank = self._init_group_test()
self._test_gather_helper(group, group_id, rank)
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
def test_gather_full_group(self):
group, group_id, rank = self._init_full_group_test()
self._test_gather_helper(group, group_id, rank)
@@ -3611,6 +3625,7 @@
@skip_if_no_gpu
@sandcastle_skip_if(BACKEND == "mpi", "MPI doesn't supports GPU barrier")
@sandcastle_skip_if(BACKEND == "ucc", "flaky on PyTorch CI with timeout")
def test_barrier_cuda(self):
group, group_id, rank = self._init_global_test()
rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
@@ -4336,7 +4351,7 @@
dist.barrier()
@sandcastle_skip_if(
BACKEND == "nccl",
BACKEND == "nccl" or BACKEND == "ucc",
"Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259"
)
@skip_if_lt_x_gpu(2)
@@ -4363,7 +4378,7 @@
)
@sandcastle_skip_if(
BACKEND == "nccl",
BACKEND == "nccl" or BACKEND == "ucc",
"Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259"
)
@skip_if_lt_x_gpu(2)
@@ -4383,7 +4398,7 @@
)
@sandcastle_skip_if(
BACKEND == "nccl",
BACKEND == "nccl" or BACKEND == "ucc",
"Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259"
)
@skip_if_lt_x_gpu(2)
@@ -9018,6 +9033,10 @@
BACKEND not in DistTestCases.backend_feature["cuda"],
f"The {BACKEND} backend does not support DDP communication hook on CUDA devices"
)
@sandcastle_skip_if(
BACKEND == "ucc",
"flaky on PyTorch CI: No such file or directory: '/tmp/checkpoint.pt'"
)
@skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
def test_ddp_hook_pickling_powerSGD(self):