mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
ProcessGroupUCC tests (#83285)
- [x] Direct dependency on UCX is completely removed; the UCC active set API is always enabled
- [x] Remove `TORCH_UCC_PROFILING_ENABLE`; profiling is always enabled
- [x] Fixes profiling of `recv` and `all_gather`
- [x] Use the NCCL TL of UCC on CUDA, as the UCP TL is not well supported on CUDA

Most tests are passing, but there are a few skipped tests:
- `scatter` and `gather` are not supported by the UCP TL of UCC on CPU tensors
- A few tests are flaky in PyTorch's CI environment
- Profiler-related failures, some of which will be fixed by @Fuzzkatt in https://github.com/pytorch/pytorch/pull/84368

After this PR is merged, I will continue to work on these skipped failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83285
Approved by: https://github.com/vtlam, https://github.com/malfet, https://github.com/kwen2501
This commit is contained in:
parent
2765243cd5
commit
08c4f8c7a7
10 changed files with 64 additions and 494 deletions
|
|
@ -36,7 +36,7 @@ function install_ucc() {
|
|||
git submodule update --init --recursive
|
||||
|
||||
./autogen.sh
|
||||
./configure --prefix=$UCC_HOME --with-ucx=$UCX_HOME --with-nccl=no --with-cuda=$with_cuda
|
||||
./configure --prefix=$UCC_HOME --with-ucx=$UCX_HOME --with-cuda=$with_cuda
|
||||
time make -j
|
||||
sudo make install
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ if NO_MULTIPROCESSING_SPAWN:
|
|||
|
||||
BACKEND = os.environ["BACKEND"]
|
||||
|
||||
if BACKEND == "gloo" or BACKEND == "nccl":
|
||||
if BACKEND == "gloo" or BACKEND == "nccl" or BACKEND == "ucc":
|
||||
class TestDistBackendWithSpawn(TestDistBackend, DistributedTest._DistTestBase):
|
||||
|
||||
def setUp(self):
|
||||
|
|
|
|||
|
|
@ -299,6 +299,14 @@ if dist.is_available():
|
|||
"WORLD_SIZE": "2" if torch.cuda.device_count() == 2 else "3",
|
||||
"TEST_REPORT_SOURCE_OVERRIDE": "dist-gloo",
|
||||
}
|
||||
if dist.is_ucc_available():
|
||||
DISTRIBUTED_TESTS_CONFIG["ucc"] = {
|
||||
"WORLD_SIZE": "2" if torch.cuda.device_count() == 2 else "3",
|
||||
"TEST_REPORT_SOURCE_OVERRIDE": "dist-ucc",
|
||||
"UCX_TLS": "tcp",
|
||||
"UCC_TLS": "nccl,ucp",
|
||||
"UCC_TL_UCP_TUNE": "cuda:0", # don't use UCP TL on CUDA as it is not well supported
|
||||
}
|
||||
|
||||
# https://stackoverflow.com/questions/2549939/get-signal-names-from-numbers-in-python
|
||||
SIGNALS_TO_NAMES_DICT = {
|
||||
|
|
|
|||
|
|
@ -444,6 +444,11 @@ target_compile_options(torch_python PRIVATE ${TORCH_PYTHON_COMPILE_OPTIONS})
|
|||
|
||||
target_include_directories(torch_python PUBLIC ${TORCH_PYTHON_INCLUDE_DIRECTORIES})
|
||||
|
||||
if(USE_UCC)
|
||||
target_link_libraries(torch_python PRIVATE __caffe2_ucc)
|
||||
target_compile_definitions(torch_python PRIVATE USE_UCC)
|
||||
endif()
|
||||
|
||||
if(BUILD_ONEDNN_GRAPH)
|
||||
target_compile_definitions(torch_python PRIVATE "-DBUILD_ONEDNN_GRAPH")
|
||||
target_compile_definitions(torch_cpu PRIVATE "-DBUILD_ONEDNN_GRAPH")
|
||||
|
|
|
|||
|
|
@ -13,18 +13,6 @@ namespace c10d {
|
|||
namespace {
|
||||
constexpr int64_t kBusyWaitMillis = 10;
|
||||
|
||||
const std::map<c10::DeviceType, ucs_memory_type_t> ucs_mtype_map = {
|
||||
{c10::kCPU, UCS_MEMORY_TYPE_HOST},
|
||||
{c10::kCUDA, UCS_MEMORY_TYPE_CUDA},
|
||||
};
|
||||
|
||||
ucs_memory_type_t to_ucs_memType(c10::DeviceType _c10_type) {
|
||||
if (ucs_mtype_map.find(_c10_type) != ucs_mtype_map.end())
|
||||
return ucs_mtype_map.at(_c10_type);
|
||||
else
|
||||
return UCS_MEMORY_TYPE_UNKNOWN;
|
||||
}
|
||||
|
||||
const std::map<c10::DeviceType, ucc_memory_type_t> ucc_mtype_map = {
|
||||
{c10::kCPU, UCC_MEMORY_TYPE_HOST},
|
||||
{c10::kCUDA, UCC_MEMORY_TYPE_CUDA},
|
||||
|
|
@ -97,7 +85,6 @@ ucc_reduction_op_t to_ucc_reduceOp(
|
|||
struct torch_ucc_config_t {
|
||||
c10::once_flag flag;
|
||||
std::array<bool, 32> blocking_wait;
|
||||
bool enable_profiling;
|
||||
bool enable_comms_logger;
|
||||
bool use_future;
|
||||
// Sharing UCC communicator among multiple PGs to save resource.
|
||||
|
|
@ -163,7 +150,6 @@ std::vector<OpType> parse_blocking_wait(std::string op_list_string) {
|
|||
void read_config() {
|
||||
// default configuration
|
||||
torch_ucc_config.blocking_wait.fill(false);
|
||||
torch_ucc_config.enable_profiling = false;
|
||||
torch_ucc_config.use_future = true;
|
||||
torch_ucc_config.shared_comm = false;
|
||||
torch_ucc_config.use_allgatherv = false;
|
||||
|
|
@ -186,8 +172,6 @@ void read_config() {
|
|||
|
||||
torch_ucc_config.use_future =
|
||||
std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_FUTURE"));
|
||||
torch_ucc_config.enable_profiling =
|
||||
std::stoi(torch_ucc_envs_map.at("TORCH_UCC_PROFILING_ENABLE"));
|
||||
torch_ucc_config.shared_comm =
|
||||
std::stoi(torch_ucc_envs_map.at("TORCH_UCC_SHARED_COMM"));
|
||||
torch_ucc_config.use_allgatherv =
|
||||
|
|
@ -294,6 +278,14 @@ c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupUCC::WorkUCC::getFuture() {
|
|||
return future_;
|
||||
}
|
||||
|
||||
// Returns the stored sourceRank_ when this work item is a RECV or
// RECVANYSOURCE operation; for every other op type the call is forwarded to
// the base-class ProcessGroup::Work::sourceRank().
int ProcessGroupUCC::WorkUCC::sourceRank() const {
|
||||
if (opType_ != OpType::RECV && opType_ != OpType::RECVANYSOURCE) {
|
||||
// Throw an error
// NOTE(review): the throw is presumably raised by the base-class
// implementation called below; nothing is thrown directly here — confirm.
|
||||
return ProcessGroup::Work::sourceRank();
|
||||
}
|
||||
// Receive-type operation: report the rank recorded on this work object.
return sourceRank_;
|
||||
}
|
||||
|
||||
// Returns, by value, the tensor vector that outputs_ points to.
// NOTE(review): outputs_ is dereferenced without a null check; assumes it has
// been set before result() is called — confirm against the callers.
std::vector<at::Tensor> ProcessGroupUCC::WorkUCC::result() {
|
||||
return *outputs_;
|
||||
}
|
||||
|
|
@ -327,7 +319,6 @@ Comm::Comm(
|
|||
bool is_health_check)
|
||||
: logger(logger_),
|
||||
oob(oob_),
|
||||
ucx_comm(oob->size, logger),
|
||||
ucc_comm(oob, logger),
|
||||
finalize_phase(
|
||||
is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_FINALIZE),
|
||||
|
|
@ -364,7 +355,7 @@ std::shared_ptr<Comm> Comm::get_comm(
|
|||
static uint32_t comm_id;
|
||||
|
||||
std::lock_guard<std::mutex> lock(m);
|
||||
id = (comm_id % TORCH_UCX_MAX_COMM);
|
||||
id = comm_id;
|
||||
|
||||
std::string group_id = "group_id";
|
||||
if (is_health_check) {
|
||||
|
|
@ -419,126 +410,6 @@ std::shared_ptr<Comm> Comm::get_comm(
|
|||
}
|
||||
}
|
||||
|
||||
void Comm::ucx_connect_eps(
|
||||
std::vector<ucp_ep_h>& eps,
|
||||
std::shared_ptr<torch_ucc_oob_coll_info_t> oob) {
|
||||
ucp_address_t* local_addr;
|
||||
size_t local_addr_len;
|
||||
std::vector<uint8_t> peer_addr;
|
||||
|
||||
TORCH_UCX_CHECK(
|
||||
ucp_worker_get_address(ucx_comm.worker, &local_addr, &local_addr_len),
|
||||
"failed to get worker address");
|
||||
|
||||
std::vector<uint8_t> val = std::vector<uint8_t>(
|
||||
reinterpret_cast<uint8_t*>(local_addr),
|
||||
reinterpret_cast<uint8_t*>(local_addr) + local_addr_len);
|
||||
oob->store->set(oob->getKey("wa" + std::to_string(oob->rank)), val);
|
||||
ucp_worker_release_address(ucx_comm.worker, local_addr);
|
||||
eps.resize(oob->size);
|
||||
for (int i = 0; i < oob->size; i++) {
|
||||
peer_addr = oob->store->get(oob->getKey("wa" + std::to_string(i)));
|
||||
ucp_ep_params_t ep_params;
|
||||
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
|
||||
ep_params.address = reinterpret_cast<ucp_address_t*>(peer_addr.data());
|
||||
TORCH_UCX_CHECK(
|
||||
ucp_ep_create(ucx_comm.worker, &ep_params, &(eps[i])),
|
||||
c10::str("failed to create endpoint with rank ", i));
|
||||
}
|
||||
}
|
||||
|
||||
void Comm::ucx_disconnect_eps(
|
||||
std::vector<ucp_ep_h>& eps,
|
||||
std::shared_ptr<torch_ucc_oob_coll_info_t> oob) {
|
||||
ucs_status_t st;
|
||||
|
||||
for (ucp_ep_h& ep : eps) {
|
||||
ucs_status_ptr_t close_req = ucp_ep_close_nb(ep, UCP_EP_CLOSE_MODE_FLUSH);
|
||||
if (UCS_PTR_IS_ERR(close_req)) {
|
||||
TORCH_UCC_LOG_ERROR(
|
||||
finalize_phase, "failed to close endpoint, ignore and continue...");
|
||||
return;
|
||||
}
|
||||
if (UCS_PTR_IS_PTR(close_req)) {
|
||||
do {
|
||||
ucp_worker_progress(ucx_comm.worker);
|
||||
st = ucp_request_check_status(close_req);
|
||||
} while (st != UCS_OK);
|
||||
ucp_request_free(close_req);
|
||||
}
|
||||
}
|
||||
if (!eps.size()) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
auto sz = (size_t)oob->store->add(oob->getKey("epclosed"), 1);
|
||||
while (sz != eps.size()) {
|
||||
ucp_worker_progress(ucx_comm.worker);
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(kBusyWaitMillis));
|
||||
sz = (size_t)oob->store->add(oob->getKey("epclosed"), 0);
|
||||
}
|
||||
} catch (std::exception& ex) {
|
||||
LOG(ERROR) << "(disconnect_eps) Caught error in Store Operation .. "
|
||||
<< "[" << ex.what() << "]";
|
||||
}
|
||||
}
|
||||
|
||||
ucc_coll_req_h Comm::send_nb(
|
||||
ucp_ep_h ep,
|
||||
void* data,
|
||||
ucs_memory_type_t mtype,
|
||||
size_t size,
|
||||
ucp_tag_t ucp_tag) {
|
||||
ucs_status_ptr_t st;
|
||||
ucp_request_param_t params;
|
||||
params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
|
||||
UCP_OP_ATTR_FIELD_DATATYPE | UCP_OP_ATTR_FIELD_MEMORY_TYPE;
|
||||
params.datatype = ucp_dt_make_contig(size);
|
||||
params.memory_type = mtype;
|
||||
params.cb.send = [](void* request, ucs_status_t status, void* user_data) {
|
||||
static_cast<ucc_coll_req_h>(request)->status = UCC_OK;
|
||||
};
|
||||
st = ucp_tag_send_nbx(ep, data, 1, ucp_tag, &params);
|
||||
if (UCS_PTR_IS_ERR(st)) {
|
||||
TORCH_UCC_LOG_ERROR(
|
||||
TORCH_UCC_COLL_POST,
|
||||
c10::str(
|
||||
"failed to send message: ", ucs_status_string(UCS_PTR_STATUS(st))));
|
||||
throw std::runtime_error(ucs_status_string(UCS_PTR_STATUS(st)));
|
||||
}
|
||||
return reinterpret_cast<ucc_coll_req_h>(st);
|
||||
}
|
||||
|
||||
ucc_coll_req_h Comm::recv_nb(
|
||||
void* data,
|
||||
ucs_memory_type_t mtype,
|
||||
size_t size,
|
||||
ucp_tag_t ucp_tag,
|
||||
ucp_tag_t ucp_tag_mask) {
|
||||
ucs_status_ptr_t st;
|
||||
ucp_request_param_t params;
|
||||
params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
|
||||
UCP_OP_ATTR_FIELD_DATATYPE | UCP_OP_ATTR_FIELD_MEMORY_TYPE;
|
||||
params.datatype = ucp_dt_make_contig(size);
|
||||
params.cb.recv = [](void* request,
|
||||
ucs_status_t status,
|
||||
const ucp_tag_recv_info_t* info,
|
||||
void* user_data) {
|
||||
static_cast<ucc_coll_req_h>(request)->status = UCC_OK;
|
||||
};
|
||||
params.memory_type = mtype;
|
||||
st = ucp_tag_recv_nbx(
|
||||
ucx_comm.worker, data, 1, ucp_tag, ucp_tag_mask, &params);
|
||||
if (UCS_PTR_IS_ERR(st)) {
|
||||
TORCH_UCC_LOG_ERROR(
|
||||
TORCH_UCC_COLL_POST,
|
||||
c10::str(
|
||||
"failed to recv message: ", ucs_status_string(UCS_PTR_STATUS(st))));
|
||||
throw std::runtime_error(ucs_status_string(UCS_PTR_STATUS(st)));
|
||||
}
|
||||
return reinterpret_cast<ucc_coll_req_h>(st);
|
||||
}
|
||||
|
||||
void Comm::ucc_create_team(
|
||||
ucc_team_h& team,
|
||||
std::shared_ptr<torch_ucc_oob_coll_info_t> oob) {
|
||||
|
|
@ -582,34 +453,6 @@ void Comm::ucc_destroy_team(ucc_team_h& team) {
|
|||
lock.unlock();
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> Comm::enqueue_p2p(
|
||||
OpType opType,
|
||||
ucc_coll_req_h request,
|
||||
const char* prof_title) {
|
||||
auto work =
|
||||
c10::make_intrusive<ProcessGroupUCC::WorkUCC>(opType, prof_title, logger);
|
||||
if (torch_ucc_config.use_future) {
|
||||
work->future_ = c10::make_intrusive<at::ivalue::Future>(
|
||||
c10::ListType::create(c10::TensorType::get()));
|
||||
}
|
||||
if (request == nullptr) {
|
||||
// p2p request completed immediately, so don't save it to the progress queue
|
||||
// and mark future completed immediately
|
||||
if (torch_ucc_config.use_future) {
|
||||
work->future_->markCompleted(c10::IValue(std::vector<at::Tensor>()));
|
||||
}
|
||||
return work;
|
||||
}
|
||||
auto entry =
|
||||
std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucx_comm, request);
|
||||
work->entry_ = entry;
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
progress_queue.push_back(entry);
|
||||
lock.unlock();
|
||||
queue_produce_cv.notify_one();
|
||||
return work;
|
||||
}
|
||||
|
||||
void Comm::enqueue_collective(
|
||||
std::unique_ptr<ProcessGroupUCC::WorkData> data,
|
||||
c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
|
||||
|
|
@ -688,7 +531,6 @@ void Comm::progress_loop() {
|
|||
try {
|
||||
while (work->request_->status > 0) {
|
||||
ucc_comm.progress();
|
||||
ucx_comm.progress();
|
||||
}
|
||||
if (work->request_->status < 0) {
|
||||
eptr = std::make_exception_ptr(
|
||||
|
|
@ -724,7 +566,7 @@ ProcessGroupUCC::ProcessGroupUCC(
|
|||
comm = nullptr;
|
||||
cuda_ee = nullptr;
|
||||
static uint32_t id = 0;
|
||||
uint32_t pg_id = (id++ % TORCH_UCX_MAX_COMM);
|
||||
uint32_t pg_id = id++;
|
||||
|
||||
logger = c10::make_intrusive<ProcessGroupUCCLogger>(
|
||||
c10::str("[Rank ", rank_, "]", "[ProcessGroupUCC-", pg_id, "]"),
|
||||
|
|
@ -769,20 +611,10 @@ ProcessGroupUCC::~ProcessGroupUCC() {
|
|||
comm->ucc_destroy_team(team);
|
||||
TORCH_UCC_LOG_INFO(
|
||||
TORCH_UCC_FINALIZE, "Successfully destroyed UCC library");
|
||||
comm->ucx_disconnect_eps(eps, oob);
|
||||
TORCH_UCC_LOG_INFO(
|
||||
TORCH_UCC_FINALIZE, "Successfully destroyed UCX library");
|
||||
try {
|
||||
if (cuda_ee) {
|
||||
ucc_ee_destroy(cuda_ee);
|
||||
}
|
||||
if ((size_t)oob->store->add(oob->getKey("ucc_pg_closed"), 1) ==
|
||||
eps.size()) {
|
||||
std::vector<uint8_t> val = {1};
|
||||
oob->store->set(oob->getKey("ucc_pg_finished"), val);
|
||||
} else {
|
||||
oob->store->wait({oob->getKey("ucc_pg_finished")});
|
||||
}
|
||||
} catch (std::exception& ex) {
|
||||
TORCH_UCC_LOG_INFO(
|
||||
TORCH_UCC_FINALIZE,
|
||||
|
|
@ -817,7 +649,6 @@ void ProcessGroupUCC::runHealthCheck() {
|
|||
struct HealthCheckData {
|
||||
std::mutex healthCheckMutex;
|
||||
std::condition_variable healthCheckCv;
|
||||
bool ucxHealthCheckSuccess = false;
|
||||
bool uccHealthCheckSuccess = false;
|
||||
std::exception_ptr healthCheckException;
|
||||
} healthCheckData;
|
||||
|
|
@ -837,8 +668,6 @@ void ProcessGroupUCC::runHealthCheck() {
|
|||
oob->rank = this->oob->rank;
|
||||
oob->size = this->oob->size;
|
||||
oob->store = this->oob->store;
|
||||
|
||||
std::vector<ucp_ep_h> eps;
|
||||
ucc_team_h team = nullptr;
|
||||
uint32_t comm_id;
|
||||
#ifdef USE_CUDA
|
||||
|
|
@ -847,19 +676,6 @@ void ProcessGroupUCC::runHealthCheck() {
|
|||
}
|
||||
#endif
|
||||
auto comm = Comm::get_comm(comm_id, device, oob, logger, true);
|
||||
comm->ucx_connect_eps(eps, oob);
|
||||
comm->ucx_disconnect_eps(eps, oob);
|
||||
TORCH_UCC_LOG_INFO(
|
||||
TORCH_UCC_HEALTH_CHECK,
|
||||
c10::str(
|
||||
"UCX library health check succeed for device ",
|
||||
c10::DeviceTypeName(device.type())));
|
||||
// Mark ucx health check as complete.
|
||||
if (is_last_device) {
|
||||
std::lock_guard<std::mutex> lk(healthCheckData.healthCheckMutex);
|
||||
healthCheckData.ucxHealthCheckSuccess = true;
|
||||
}
|
||||
|
||||
comm->ucc_create_team(team, oob);
|
||||
comm->ucc_destroy_team(team);
|
||||
TORCH_UCC_LOG_INFO(
|
||||
|
|
@ -898,18 +714,13 @@ void ProcessGroupUCC::runHealthCheck() {
|
|||
" msec for UCC health check to complete."));
|
||||
std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex);
|
||||
healthCheckData.healthCheckCv.wait_for(lock, timeout_, [&healthCheckData]() {
|
||||
return healthCheckData.ucxHealthCheckSuccess &&
|
||||
healthCheckData.uccHealthCheckSuccess;
|
||||
return healthCheckData.uccHealthCheckSuccess;
|
||||
});
|
||||
|
||||
if (healthCheckData.healthCheckException) {
|
||||
std::rethrow_exception(healthCheckData.healthCheckException);
|
||||
}
|
||||
// If there is no exception, the likely culprit is a timeout/hang
|
||||
TORCH_CHECK(
|
||||
healthCheckData.ucxHealthCheckSuccess,
|
||||
"ProcessGroupUCC: Health check failure: Failed to initialize UCX on rank ",
|
||||
rank_);
|
||||
TORCH_CHECK(
|
||||
healthCheckData.uccHealthCheckSuccess,
|
||||
"ProcessGroupUCC: Health check failure: Failed to initialize UCC on rank ",
|
||||
|
|
@ -948,8 +759,12 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::collective_post(
|
|||
std::vector<at::Tensor>& outputTensors,
|
||||
const char* prof_title) {
|
||||
set_timeout(coll);
|
||||
auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
|
||||
opType, torch_ucc_config.enable_profiling ? prof_title : nullptr, logger);
|
||||
auto work =
|
||||
c10::make_intrusive<ProcessGroupUCC::WorkUCC>(opType, prof_title, logger);
|
||||
|
||||
if (opType == OpType::RECV) {
|
||||
work->sourceRank_ = coll.root;
|
||||
}
|
||||
|
||||
RECORD_COMMS_TRACE(
|
||||
logger->trace_generator,
|
||||
|
|
@ -1106,7 +921,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::allgather(
|
|||
tensor.device(),
|
||||
inputTensors,
|
||||
outputTensors[0],
|
||||
"ucc:allgather");
|
||||
"ucc:all_gather");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1689,7 +1504,6 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::send(
|
|||
auto& tensor = tensors[0];
|
||||
initComm(tensor.device());
|
||||
|
||||
#ifdef USE_ACTIVE_SETS
|
||||
WorkData* data = new WorkData();
|
||||
ucc_coll_args_t coll;
|
||||
coll.tag = tag;
|
||||
|
|
@ -1717,28 +1531,6 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::send(
|
|||
tensors,
|
||||
tensors,
|
||||
"ucc:send");
|
||||
#else
|
||||
ucp_tag_t ucp_tag;
|
||||
TORCH_UCX_MAKE_SEND_TAG(ucp_tag, tag, rank_, comm_id);
|
||||
ucc_coll_req_h request = comm->send_nb(
|
||||
eps[dstRank],
|
||||
tensor.data_ptr(),
|
||||
to_ucs_memType(tensor.device().type()),
|
||||
tensor.numel() * tensor.element_size(),
|
||||
ucp_tag);
|
||||
|
||||
auto work = comm->enqueue_p2p(OpType::SEND, request, "ucc:send");
|
||||
// TODO: record src, dst ranks and tag
|
||||
RECORD_COMMS_TRACE(
|
||||
logger->trace_generator,
|
||||
work,
|
||||
OpType::SEND,
|
||||
this->getRank(),
|
||||
this->getSize(),
|
||||
tensors,
|
||||
tensors);
|
||||
return work;
|
||||
#endif
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::recv(
|
||||
|
|
@ -1749,7 +1541,6 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::recv(
|
|||
auto& tensor = tensors[0];
|
||||
initComm(tensor.device());
|
||||
|
||||
#ifdef USE_ACTIVE_SETS
|
||||
WorkData* data = new WorkData();
|
||||
ucc_coll_args_t coll;
|
||||
coll.tag = tag;
|
||||
|
|
@ -1777,58 +1568,6 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::recv(
|
|||
tensors,
|
||||
tensors,
|
||||
"ucc:recv");
|
||||
#else
|
||||
ucp_tag_t ucp_tag, ucp_tag_mask;
|
||||
TORCH_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, srcRank, comm_id);
|
||||
ucc_coll_req_h request = comm->recv_nb(
|
||||
tensor.data_ptr(),
|
||||
to_ucs_memType(tensor.device().type()),
|
||||
tensor.numel() * tensor.element_size(),
|
||||
ucp_tag,
|
||||
ucp_tag_mask);
|
||||
|
||||
auto work = comm->enqueue_p2p(OpType::RECV, request, "ucc:recv");
|
||||
// TODO: record src, dst ranks and tag
|
||||
RECORD_COMMS_TRACE(
|
||||
logger->trace_generator,
|
||||
work,
|
||||
OpType::RECV,
|
||||
this->getRank(),
|
||||
this->getSize(),
|
||||
tensors,
|
||||
tensors);
|
||||
return work;
|
||||
#endif
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::recvAnysource(
|
||||
std::vector<at::Tensor>& tensors,
|
||||
int tag) {
|
||||
check_tensor(tensors);
|
||||
auto& tensor = tensors[0];
|
||||
initComm(tensor.device());
|
||||
|
||||
ucp_tag_t ucp_tag, ucp_tag_mask;
|
||||
TORCH_UCX_MAKE_RECV_TAG(
|
||||
ucp_tag, ucp_tag_mask, tag, TORCH_UCX_ANY_SOURCE, comm_id);
|
||||
ucc_coll_req_h request = comm->recv_nb(
|
||||
tensor.data_ptr(),
|
||||
to_ucs_memType(tensor.device().type()),
|
||||
tensor.numel() * tensor.element_size(),
|
||||
ucp_tag,
|
||||
ucp_tag_mask);
|
||||
|
||||
auto work = comm->enqueue_p2p(OpType::RECVANYSOURCE, request, "ucc:recv");
|
||||
// TODO: record dst rank and tag
|
||||
RECORD_COMMS_TRACE(
|
||||
logger->trace_generator,
|
||||
work,
|
||||
OpType::RECVANYSOURCE,
|
||||
this->getRank(),
|
||||
this->getSize(),
|
||||
tensors,
|
||||
tensors);
|
||||
return work;
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup> ProcessGroupUCC::createProcessGroupUCC(
|
||||
|
|
@ -1847,7 +1586,6 @@ void ProcessGroupUCC::initComm(c10::Device dev) {
|
|||
}
|
||||
#endif
|
||||
comm = Comm::get_comm(comm_id, dev, oob, logger);
|
||||
comm->ucx_connect_eps(eps, oob);
|
||||
TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCX library");
|
||||
comm->ucc_create_team(team, oob);
|
||||
TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCC library");
|
||||
|
|
|
|||
|
|
@ -24,47 +24,6 @@ namespace c10d {
|
|||
|
||||
#define TORCH_UCC_DEVICE_NOT_SET -2
|
||||
|
||||
#define TORCH_UCX_MAKE_P2P_TAG(_tag, _rank, _comm) \
|
||||
((((uint64_t)(_tag)) << TORCH_UCX_TAG_BITS_OFFSET) | \
|
||||
(((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \
|
||||
(((uint64_t)(_comm)) << TORCH_UCX_COMM_BITS_OFFSET))
|
||||
|
||||
#define TORCH_UCX_MAKE_OOB_TAG(_tag, _rank, _comm) \
|
||||
((((uint64_t)(_tag)) << TORCH_UCX_OOB_BITS_OFFSET) | \
|
||||
(((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \
|
||||
(((uint64_t)(_rank)) << TORCH_UCX_COMM_BITS_OFFSET))
|
||||
|
||||
#define TORCH_UCX_MAKE_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \
|
||||
do { \
|
||||
(_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \
|
||||
} while (0)
|
||||
|
||||
#define TORCH_UCX_ANY_SOURCE (TORCH_UCX_MAX_RANK - 1)
|
||||
#define TORCH_UCX_ANY_SOURCE_MASK (~TORCH_UCX_RANK_MASK)
|
||||
#define TORCH_UCX_SPECIFIC_SOURCE_MASK ((uint64_t)-1)
|
||||
|
||||
#define TORCH_UCX_MAKE_RECV_TAG(_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \
|
||||
do { \
|
||||
(_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \
|
||||
if ((_rank) == TORCH_UCX_ANY_SOURCE) { \
|
||||
(_ucp_tag_mask) = TORCH_UCX_ANY_SOURCE_MASK; \
|
||||
} else { \
|
||||
(_ucp_tag_mask) = TORCH_UCX_SPECIFIC_SOURCE_MASK; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define TORCH_UCX_MAKE_OOB_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \
|
||||
do { \
|
||||
(_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \
|
||||
} while (0)
|
||||
|
||||
#define TORCH_UCX_MAKE_OOB_RECV_TAG( \
|
||||
_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \
|
||||
do { \
|
||||
(_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \
|
||||
(_ucp_tag_mask) = (uint64_t)-1; \
|
||||
} while (0)
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#define SAVE_TENSORS(_TENSORS, _DATA) \
|
||||
do { \
|
||||
|
|
@ -84,8 +43,6 @@ namespace c10d {
|
|||
|
||||
constexpr const char* UCC_BACKEND_NAME = "ucc";
|
||||
|
||||
enum torch_ucx_tag_type_t { TORCH_UCX_P2P_TAG, TORCH_UCX_OOB_TAG };
|
||||
|
||||
struct event_pool_t {
|
||||
#ifdef USE_CUDA
|
||||
std::queue<std::unique_ptr<at::cuda::CUDAEvent>> event_pool;
|
||||
|
|
@ -173,10 +130,12 @@ class TORCH_API ProcessGroupUCC : public ProcessGroup {
|
|||
bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
|
||||
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
|
||||
std::vector<at::Tensor> result() override;
|
||||
int sourceRank() const override;
|
||||
#ifdef USE_CUDA
|
||||
std::unique_ptr<at::cuda::CUDAEvent> fence = nullptr;
|
||||
event_pool_t* ep = nullptr;
|
||||
#endif
|
||||
int sourceRank_;
|
||||
protected:
|
||||
std::shared_ptr<ProgressEntry> entry_;
|
||||
c10::intrusive_ptr<ProcessGroupUCCLogger> logger_;
|
||||
|
|
@ -293,10 +252,6 @@ class TORCH_API ProcessGroupUCC : public ProcessGroup {
|
|||
int srcRank,
|
||||
int tag) override;
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
|
||||
std::vector<at::Tensor>& tensors,
|
||||
int tag) override;
|
||||
|
||||
static c10::intrusive_ptr<ProcessGroup> createProcessGroupUCC(
|
||||
const c10::intrusive_ptr<::c10d::Store>& store,
|
||||
int rank,
|
||||
|
|
@ -308,7 +263,6 @@ class TORCH_API ProcessGroupUCC : public ProcessGroup {
|
|||
std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
|
||||
std::shared_ptr<Comm> comm = {nullptr};
|
||||
uint32_t comm_id;
|
||||
std::vector<ucp_ep_h> eps;
|
||||
ucc_team_h team{nullptr};
|
||||
ucc_ee_h cuda_ee{nullptr};
|
||||
#ifdef USE_CUDA
|
||||
|
|
@ -321,7 +275,6 @@ class TORCH_API ProcessGroupUCC : public ProcessGroup {
|
|||
class Comm {
|
||||
c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
|
||||
std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
|
||||
CommUCX ucx_comm;
|
||||
CommUCC ucc_comm;
|
||||
std::mutex mutex;
|
||||
std::thread progress_thread;
|
||||
|
|
@ -342,16 +295,6 @@ class Comm {
|
|||
|
||||
~Comm();
|
||||
|
||||
// Connects UCX end points.
|
||||
void ucx_connect_eps(
|
||||
std::vector<ucp_ep_h>& eps,
|
||||
std::shared_ptr<torch_ucc_oob_coll_info_t> oob);
|
||||
|
||||
// Disconnects UCX end points.
|
||||
void ucx_disconnect_eps(
|
||||
std::vector<ucp_ep_h>& eps,
|
||||
std::shared_ptr<torch_ucc_oob_coll_info_t> oob);
|
||||
|
||||
void ucc_create_team(
|
||||
ucc_team_h& team,
|
||||
std::shared_ptr<torch_ucc_oob_coll_info_t> oob);
|
||||
|
|
@ -386,20 +329,6 @@ class Comm {
|
|||
bool is_health_check = false);
|
||||
|
||||
void progress_loop();
|
||||
|
||||
ucc_coll_req_h send_nb(
|
||||
ucp_ep_h ep,
|
||||
void* data,
|
||||
ucs_memory_type_t mtype,
|
||||
size_t size,
|
||||
ucp_tag_t ucp_tag);
|
||||
|
||||
ucc_coll_req_h recv_nb(
|
||||
void* data,
|
||||
ucs_memory_type_t mtype,
|
||||
size_t size,
|
||||
ucp_tag_t ucp_tag,
|
||||
ucp_tag_t ucp_tag_mask);
|
||||
};
|
||||
|
||||
} // namespace c10d
|
||||
|
|
|
|||
|
|
@ -17,75 +17,6 @@ constexpr char kAllGatherDone[] = "ag_done";
|
|||
constexpr char kAllGatherFree[] = "ag_free";
|
||||
} // namespace
|
||||
|
||||
CommUCX::CommUCX(
|
||||
int comm_size,
|
||||
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger)
|
||||
: CommBase(logger) {
|
||||
ucp_params_t params;
|
||||
ucp_config_t* config;
|
||||
ucs_status_t st;
|
||||
ucp_worker_params_t worker_params;
|
||||
ucp_lib_attr_t ucp_attr;
|
||||
|
||||
ucp_attr.field_mask = UCP_LIB_ATTR_FIELD_MAX_THREAD_LEVEL;
|
||||
TORCH_UCX_CHECK(
|
||||
ucp_lib_query(&ucp_attr), "failed to query UCP lib attributes");
|
||||
TORCH_CHECK(
|
||||
ucp_attr.max_thread_level == UCS_THREAD_MODE_MULTI,
|
||||
"ucx library wasn't initialized with multithreading support, "
|
||||
"please check ucx build options");
|
||||
TORCH_UCX_CHECK(
|
||||
ucp_config_read("TORCH", nullptr, &config), "failed to read UCP config");
|
||||
|
||||
memset(&params, 0, sizeof(ucp_params_t));
|
||||
params.field_mask = UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_REQUEST_SIZE |
|
||||
UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | UCP_PARAM_FIELD_TAG_SENDER_MASK |
|
||||
UCP_PARAM_FIELD_REQUEST_INIT | UCP_PARAM_FIELD_REQUEST_CLEANUP;
|
||||
params.request_size = sizeof(ucc_coll_req_t);
|
||||
params.features = UCP_FEATURE_TAG;
|
||||
params.estimated_num_eps = comm_size;
|
||||
params.tag_sender_mask = TORCH_UCX_RANK_MASK;
|
||||
params.request_init = [](void* request) {
|
||||
static_cast<ucc_coll_req_h>(request)->status = UCC_INPROGRESS;
|
||||
};
|
||||
params.request_cleanup = [](void*) {};
|
||||
TORCH_UCX_CHECK(
|
||||
ucp_init(&params, config, &context), "failed to init UCP context");
|
||||
ucp_config_release(config);
|
||||
|
||||
memset(&worker_params, 0, sizeof(ucp_worker_params_t));
|
||||
worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
|
||||
worker_params.thread_mode = UCS_THREAD_MODE_MULTI;
|
||||
st = ucp_worker_create(context, &worker_params, &worker);
|
||||
if (st != UCS_OK) {
|
||||
TORCH_UCC_LOG_ERROR(
|
||||
TORCH_UCC_INIT,
|
||||
c10::str("UCX failed to create UCP worker:", ucs_status_string(st)));
|
||||
ucp_cleanup(context);
|
||||
throw std::runtime_error(ucs_status_string(st));
|
||||
}
|
||||
}
|
||||
|
||||
void CommUCX::progress() {
|
||||
ucp_worker_progress(worker);
|
||||
}
|
||||
|
||||
void CommUCX::free_request(ucc_coll_req_h request) {
|
||||
request->status = UCC_INPROGRESS;
|
||||
ucp_request_free(request);
|
||||
}
|
||||
|
||||
CommUCX::~CommUCX() {
|
||||
if (worker != nullptr) {
|
||||
ucp_worker_destroy(worker);
|
||||
}
|
||||
if (context != nullptr) {
|
||||
ucp_cleanup(context);
|
||||
}
|
||||
worker = nullptr;
|
||||
context = nullptr;
|
||||
}
|
||||
|
||||
ucc_status_t oob_allgather(
|
||||
void* sbuf,
|
||||
void* rbuf,
|
||||
|
|
|
|||
|
|
@ -5,28 +5,6 @@
|
|||
#include <c10d/ProcessGroup.hpp>
|
||||
#include <c10d/Store.hpp>
|
||||
#include <ucc/api/ucc.h>
|
||||
#include <ucp/api/ucp.h>
|
||||
|
||||
#define TORCH_UCX_COMM_BITS 15
|
||||
#define TORCH_UCX_RANK_BITS 16
|
||||
#define TORCH_UCX_TAG_BITS 32
|
||||
#define TORCH_UCX_OOB_BITS 1
|
||||
|
||||
#define TORCH_UCX_COMM_BITS_OFFSET 0
|
||||
#define TORCH_UCX_RANK_BITS_OFFSET TORCH_UCX_COMM_BITS
|
||||
#define TORCH_UCX_TAG_BITS_OFFSET (TORCH_UCX_COMM_BITS + TORCH_UCX_RANK_BITS)
|
||||
#define TORCH_UCX_OOB_BITS_OFFSET \
|
||||
(TORCH_UCX_COMM_BITS + TORCH_UCX_RANK_BITS + TORCH_UCX_TAG_BITS)
|
||||
|
||||
#define TORCH_UCX_MAX_COMM ((((uint64_t)1) << TORCH_UCX_COMM_BITS) - 1)
|
||||
#define TORCH_UCX_MAX_RANK ((((uint64_t)1) << TORCH_UCX_RANK_BITS) - 1)
|
||||
#define TORCH_UCX_MAX_TAG ((((uint64_t)1) << TORCH_UCX_TAG_BITS) - 1)
|
||||
#define TORCH_UCX_MAX_OOB ((((uint64_t)1) << TORCH_UCX_OOB_BITS) - 1)
|
||||
|
||||
#define TORCH_UCX_COMM_MASK (TORCH_UCX_MAX_COMM << TORCH_UCX_COMM_BITS_OFFSET)
|
||||
#define TORCH_UCX_RANK_MASK (TORCH_UCX_MAX_RANK << TORCH_UCX_RANK_BITS_OFFSET)
|
||||
#define TORCH_UCX_TAG_MASK (TORCH_UCX_MAX_TAG << TORCH_UCX_TAG_BITS_OFFSET)
|
||||
#define TORCH_UCX_OOB_MASK (TORCH_UCX_MAX_OOB << TORCH_UCX_OOB_BITS_OFFSET)
|
||||
|
||||
namespace c10d {
|
||||
|
||||
|
|
@ -53,29 +31,6 @@ namespace c10d {
|
|||
} \
|
||||
} while (0)
|
||||
|
||||
// Macro to throw on a non-successful UCX return value.
|
||||
#define TORCH_UCX_CHECK(_cmd, _error_msg) \
|
||||
do { \
|
||||
ucs_status_t result = _cmd; \
|
||||
if (result != UCS_OK) { \
|
||||
std::string err = c10::str( \
|
||||
"[", \
|
||||
std::string(__FILE__), \
|
||||
":", \
|
||||
std::to_string(__LINE__), \
|
||||
"] ", \
|
||||
logger->getLogPrefix(), \
|
||||
_error_msg, \
|
||||
", error code ", \
|
||||
result, \
|
||||
": ", \
|
||||
ucs_status_string(result), \
|
||||
", system error code ", \
|
||||
errno); \
|
||||
TORCH_CHECK(false, err); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Macros to print logs with unified format
|
||||
#define TORCH_UCC_LOG_ERROR(_phase, _msg) \
|
||||
LOG(ERROR) << logger->getLogPrefix(_phase) << "[ERROR] " << _msg;
|
||||
|
|
@ -148,21 +103,6 @@ class CommBase {
|
|||
virtual ~CommBase() {}
|
||||
c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
|
||||
};
|
||||
|
||||
class CommUCX : public CommBase {
|
||||
public:
|
||||
ucp_context_h context{nullptr};
|
||||
ucp_worker_h worker{nullptr};
|
||||
|
||||
public:
|
||||
void progress() override;
|
||||
void free_request(ucc_coll_req_h request) override;
|
||||
CommUCX(
|
||||
int comm_size,
|
||||
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger);
|
||||
~CommUCX();
|
||||
};
|
||||
|
||||
class CommUCC : public CommBase {
|
||||
public:
|
||||
ucc_lib_h lib{nullptr};
|
||||
|
|
|
|||
|
|
@ -73,17 +73,17 @@ TEST_SKIPS = {
|
|||
class DistTestCases:
|
||||
# Backends that do not support a specific collective
|
||||
skip_collective = {}
|
||||
skip_collective["allgather_coalesced"] = {"nccl", "mpi"}
|
||||
skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc"}
|
||||
skip_collective["reduce"] = set()
|
||||
skip_collective["sendrecv anysource"] = {"nccl"}
|
||||
skip_collective["cpu barrier"] = {"nccl"}
|
||||
skip_collective["sendrecv anysource"] = {"nccl", "ucc"}
|
||||
skip_collective["cpu barrier"] = {"nccl", "ucc"}
|
||||
|
||||
# Sets showing that something is implemented
|
||||
backend_feature = {}
|
||||
backend_feature["gpu"] = {"nccl", "gloo"}
|
||||
backend_feature["cuda"] = {"nccl", "gloo"}
|
||||
backend_feature["ddp"] = {"nccl", "gloo"}
|
||||
backend_feature["subgroup"] = {"nccl", "gloo"}
|
||||
backend_feature["gpu"] = {"nccl", "gloo"} # TODO(ucc): add sequence number support to ucc and enable it here
|
||||
backend_feature["cuda"] = {"nccl", "gloo", "ucc"}
|
||||
backend_feature["ddp"] = {"nccl", "gloo", "ucc"}
|
||||
backend_feature["subgroup"] = {"nccl", "gloo", "ucc"}
|
||||
backend_feature["plugin"] = set()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -143,6 +143,7 @@ PROFILING_SUPPORTED_BACKENDS = [
|
|||
dist.Backend.NCCL,
|
||||
dist.Backend.GLOO,
|
||||
dist.Backend.MPI,
|
||||
# TODO(ucc): dist.Backend.UCC,
|
||||
]
|
||||
|
||||
# Allowlist of distributed backends where profiling is supported with use_cuda=True
|
||||
|
|
@ -150,6 +151,7 @@ CUDA_PROFILING_SUPPORTED_BACKENDS = [
|
|||
dist.Backend.GLOO,
|
||||
dist.Backend.MPI,
|
||||
dist.Backend.NCCL,
|
||||
# TODO(ucc): dist.Backend.UCC,
|
||||
]
|
||||
|
||||
# Allowlist of distributed backends where profiling is supported for p2p ops
|
||||
|
|
@ -157,6 +159,7 @@ SEND_RECV_PROFILING_SUPPORTED_BACKENDS = [
|
|||
dist.Backend.MPI,
|
||||
dist.Backend.GLOO,
|
||||
dist.Backend.NCCL,
|
||||
# TODO(ucc): dist.Backend.UCC,
|
||||
]
|
||||
|
||||
# Dummy NamedTuple data structures to test DDP support for NamedTuple types.
|
||||
|
|
@ -395,6 +398,8 @@ def require_backends_available(backends):
|
|||
return dist.is_nccl_available()
|
||||
if backend == dist.Backend.MPI:
|
||||
return dist.is_mpi_available()
|
||||
if backend == dist.Backend.UCC:
|
||||
return dist.is_ucc_available()
|
||||
if backend in DistTestCases.backend_feature["plugin"]:
|
||||
return True
|
||||
return False
|
||||
|
|
@ -2913,6 +2918,7 @@ class DistributedTest:
|
|||
self._barrier()
|
||||
|
||||
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
|
||||
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
|
||||
def test_scatter_checks(self):
|
||||
group, group_id, rank = self._init_global_test()
|
||||
one = torch.ones([1])
|
||||
|
|
@ -2936,6 +2942,7 @@ class DistributedTest:
|
|||
self.assertEqual(output, one * rank)
|
||||
|
||||
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
|
||||
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
|
||||
def test_scatter(self):
|
||||
group, group_id, rank = self._init_global_test()
|
||||
self._test_scatter_helper(group, group_id, rank)
|
||||
|
|
@ -2948,6 +2955,7 @@ class DistributedTest:
|
|||
self._test_scatter_helper(group, group_id, rank, True, rank_to_GPU)
|
||||
|
||||
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
|
||||
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
|
||||
def test_scatter_complex(self):
|
||||
group, group_id, rank = self._init_global_test()
|
||||
self._test_scatter_helper(group, group_id, rank, dtype=torch.cfloat)
|
||||
|
|
@ -2960,12 +2968,14 @@ class DistributedTest:
|
|||
self._test_scatter_helper(group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat)
|
||||
|
||||
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
|
||||
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
|
||||
@skip_if_small_worldsize
|
||||
def test_scatter_group(self):
|
||||
group, group_id, rank = self._init_group_test()
|
||||
self._test_scatter_helper(group, group_id, rank)
|
||||
|
||||
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
|
||||
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
|
||||
def test_scatter_full_group(self):
|
||||
group, group_id, rank = self._init_full_group_test()
|
||||
self._test_scatter_helper(group, group_id, rank)
|
||||
|
|
@ -2999,6 +3009,7 @@ class DistributedTest:
|
|||
self._barrier()
|
||||
|
||||
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
|
||||
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
|
||||
def test_gather_checks(self):
|
||||
group, group_id, rank = self._init_global_test()
|
||||
one = torch.ones([1])
|
||||
|
|
@ -3022,6 +3033,7 @@ class DistributedTest:
|
|||
dist.gather(one * rank)
|
||||
|
||||
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
|
||||
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
|
||||
def test_gather(self):
|
||||
group, group_id, rank = self._init_global_test()
|
||||
self._test_gather_helper(group, group_id, rank)
|
||||
|
|
@ -3034,12 +3046,14 @@ class DistributedTest:
|
|||
self._test_gather_helper(group, group_id, rank, True, rank_to_GPU)
|
||||
|
||||
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
|
||||
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
|
||||
@skip_if_small_worldsize
|
||||
def test_gather_group(self):
|
||||
group, group_id, rank = self._init_group_test()
|
||||
self._test_gather_helper(group, group_id, rank)
|
||||
|
||||
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
|
||||
@sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
|
||||
def test_gather_full_group(self):
|
||||
group, group_id, rank = self._init_full_group_test()
|
||||
self._test_gather_helper(group, group_id, rank)
|
||||
|
|
@ -3611,6 +3625,7 @@ class DistributedTest:
|
|||
|
||||
@skip_if_no_gpu
|
||||
@sandcastle_skip_if(BACKEND == "mpi", "MPI doesn't supports GPU barrier")
|
||||
@sandcastle_skip_if(BACKEND == "ucc", "flaky on PyTorch CI with timeout")
|
||||
def test_barrier_cuda(self):
|
||||
group, group_id, rank = self._init_global_test()
|
||||
rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
|
||||
|
|
@ -4336,7 +4351,7 @@ class DistributedTest:
|
|||
dist.barrier()
|
||||
|
||||
@sandcastle_skip_if(
|
||||
BACKEND == "nccl",
|
||||
BACKEND == "nccl" or BACKEND == "ucc",
|
||||
"Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259"
|
||||
)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
|
|
@ -4363,7 +4378,7 @@ class DistributedTest:
|
|||
)
|
||||
|
||||
@sandcastle_skip_if(
|
||||
BACKEND == "nccl",
|
||||
BACKEND == "nccl" or BACKEND == "ucc",
|
||||
"Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259"
|
||||
)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
|
|
@ -4383,7 +4398,7 @@ class DistributedTest:
|
|||
)
|
||||
|
||||
@sandcastle_skip_if(
|
||||
BACKEND == "nccl",
|
||||
BACKEND == "nccl" or BACKEND == "ucc",
|
||||
"Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259"
|
||||
)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
|
|
@ -9018,6 +9033,10 @@ class DistributedTest:
|
|||
BACKEND not in DistTestCases.backend_feature["cuda"],
|
||||
f"The {BACKEND} backend does not support DDP communication hook on CUDA devices"
|
||||
)
|
||||
@sandcastle_skip_if(
|
||||
BACKEND == "ucc",
|
||||
"flaky on PyTorch CI: No such file or directory: '/tmp/checkpoint.pt'"
|
||||
)
|
||||
@skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
|
||||
def test_ddp_hook_pickling_powerSGD(self):
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue