diff --git a/orttraining/orttraining/core/framework/distributed_run_context.cc b/orttraining/orttraining/core/framework/distributed_run_context.cc index c7c1201bae..05b010d62c 100644 --- a/orttraining/orttraining/core/framework/distributed_run_context.cc +++ b/orttraining/orttraining/core/framework/distributed_run_context.cc @@ -24,13 +24,21 @@ DistributedRunContext::DistributedRunContext(int32_t world_rank, "data_parallel_size(" + std::to_string(data_parallel_size) + ") and horizontal_parallel_size(" + std::to_string(horizontal_parallel_size) + ") MUST range from 0 ~ world_size(" + std::to_string(world_size) + ")"); - ORT_ENFORCE(world_size % horizontal_parallel_size == 0, "world size is not divisible by horizontal model parallel size."); - ORT_ENFORCE(world_size % data_parallel_size == 0, "world size is not divisible by data parallel size."); + ORT_ENFORCE(world_size % horizontal_parallel_size == 0, + "world_size(" + std::to_string(world_size) + ") is not divisible by " + "horizontal_parallel_size(" + std::to_string(horizontal_parallel_size) + ")."); + + ORT_ENFORCE(world_size % data_parallel_size == 0, + "world_size(" + std::to_string(world_size) + ") is not divisible by " + "data_parallel_size(" + std::to_string(data_parallel_size) + ")."); // Be noted: this check and subsequent logic should be updated when we introduce pipeline group // depending how to split the pipeline groups. ORT_ENFORCE(data_parallel_size * horizontal_parallel_size * pipeline_stage_size == world_size, - "total worker number != data_parallel_size * horizontal_parallel_size * pipeline_stage_size"); + "data_parallel_size(" + std::to_string(data_parallel_size) + ") " + "* horizontal_parallel_size(" + std::to_string(horizontal_parallel_size) + ") " + "* pipeline_stage_size(" + std::to_string(pipeline_stage_size) + ") " + "!= world_size(" + std::to_string(world_size) + ")."); params_.world_rank = world_rank; params_.world_size = world_size; diff --git a/orttraining/orttraining/test/framework/distributed_run_context_test.cc b/orttraining/orttraining/test/framework/distributed_run_context_test.cc index 5fd7439cae..edda0908eb 100644 --- a/orttraining/orttraining/test/framework/distributed_run_context_test.cc +++ b/orttraining/orttraining/test/framework/distributed_run_context_test.cc @@ -17,8 +17,13 @@ namespace test { class DistributedRunTestContext : public DistributedRunContext { public: DistributedRunTestContext(DistributedRunConfig config) - : DistributedRunContext(config.world_rank, config.world_size, config.local_rank, - config.local_size, config.data_parallel_size, config.horizontal_parallel_size) { + : DistributedRunContext(config.world_rank, + config.world_size, + config.local_rank, + config.local_size, + config.data_parallel_size, + config.horizontal_parallel_size, + config.pipeline_stage_size) { } }; @@ -431,7 +436,7 @@ TEST(DistributedRunContextTest, FailTest1) { DistributedRunConfig config = {63, 64, 3, 4, 16, 5}; DistributedRunTestContext ctx(config); } catch (const std::exception& ex) { - auto ret = std::string(ex.what()).find("world size is not divisible by horizontal model parallel size"); + auto ret = std::string(ex.what()).find("world_size(64) is not divisible by horizontal_parallel_size(5)"); ASSERT_TRUE(ret != std::string::npos); } } @@ -441,7 +446,17 @@ TEST(DistributedRunContextTest, FailTest2) { DistributedRunConfig config = {63, 64, 3, 4, 8, 4}; DistributedRunTestContext ctx(config); } catch (const std::exception& ex) { - auto ret = std::string(ex.what()).find("total worker number != data_parallel_size * horizontal_parallel_size"); + auto ret = std::string(ex.what()).find("data_parallel_size(8) * horizontal_parallel_size(4) * pipeline_stage_size(1) != world_size(64)"); + ASSERT_TRUE(ret != std::string::npos); + } +} + +TEST(DistributedRunContextTest, FailTest3) { + try { + DistributedRunConfig config = {63, 64, 3, 4, 1, 4, 4}; + DistributedRunTestContext ctx(config); + } catch (const std::exception& ex) { + auto ret = std::string(ex.what()).find("data_parallel_size(1) * horizontal_parallel_size(4) * pipeline_stage_size(4) != world_size(64)"); ASSERT_TRUE(ret != std::string::npos); } } diff --git a/orttraining/orttraining/test/optimizer/horizontal_parallel_test_utils.cc b/orttraining/orttraining/test/optimizer/horizontal_parallel_test_utils.cc index 4ff99cbfef..f13dfedd7c 100644 --- a/orttraining/orttraining/test/optimizer/horizontal_parallel_test_utils.cc +++ b/orttraining/orttraining/test/optimizer/horizontal_parallel_test_utils.cc @@ -82,7 +82,7 @@ Status MergeGraph(Graph& graph, Graph& graph_to_merge, int rank, std::vector& graphs, Graph& combine_graph) { auto total_rank = graphs.size(); std::vector> megatronGs(total_rank, std::vector()); - for (auto i = 0; i < total_rank; i++) { + for (auto i = 0u; i < total_rank; i++) { auto merge_ret = horizontal_parallel_test_utils::MergeGraph(combine_graph, *graphs[i], i, megatronGs[i]); ORT_ENFORCE(merge_ret.IsOK()); ORT_ENFORCE(megatronGs[i].size() == megatronGs[0].size()); @@ -90,14 +90,14 @@ Status MergeGraphsOnAllWorkers(std::vector& graphs, Graph& combine_graph std::vector nodes_to_remove; // Merge megatron g at the same index for different ranks - for (auto g_index = 0; g_index < megatronGs[0].size(); g_index++) { + for (auto g_index = 0u; g_index < megatronGs[0].size(); g_index++) { // Merge the "g_index"th MegatronG on each rank into one Sum node. std::vector input_defs{}; auto type_info = *megatronGs[0][g_index]->MutableOutputDefs()[0]->TypeAsProto(); auto& input_arg = combine_graph.GetOrCreateNodeArg("sum_" + std::to_string(g_index), &type_info); std::vector output_defs{&input_arg}; - for (auto rank_index = 0; rank_index < total_rank; rank_index++) { + for (auto rank_index = 0u; rank_index < total_rank; rank_index++) { input_defs.push_back(megatronGs[rank_index][g_index]->MutableInputDefs()[0]); } auto& sum_node = combine_graph.AddNode(combine_graph.GenerateNodeName("Sum_For_MegatronG"), @@ -107,7 +107,7 @@ Status MergeGraphsOnAllWorkers(std::vector& graphs, Graph& combine_graph output_defs); sum_node.SetExecutionProviderType(megatronGs[0][g_index]->GetExecutionProviderType()); - for (auto rank_index = 0; rank_index < total_rank; rank_index++) { + for (auto rank_index = 0u; rank_index < total_rank; rank_index++) { graph_utils::ReplaceDownstreamNodeInput(combine_graph, *megatronGs[rank_index][g_index], 0, sum_node, 0); nodes_to_remove.push_back(megatronGs[rank_index][g_index]->Index()); } @@ -134,7 +134,7 @@ void VerifyOutputs(const std::vector& expected, const std::vector& bool use_threshold_compare, float atol, float rtol, float threshold) { auto size = expected.size(); ORT_ENFORCE(size == actual.size()); - for (auto i = 0; i < size; ++i) { + for (auto i = 0u; i < size; ++i) { const auto expected_value = expected[i], actual_value = actual[i]; if (std::isnan(expected_value)) { ASSERT_TRUE(std::isnan(actual_value)) << "value mismatch at index " << i << "; expected is NaN, actual is not NaN";