mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
Remove signed/unsigned compiler warnings, add additional pipeline test case (#4314)
* Avoid signed/unsigned warning on loops
* Report sizes when distributed world configuration is inconsistent
* Add DistributedRunContextTest for pipeline stage configuration
This commit is contained in:
parent
44f06ec480
commit
5c6a27408a
3 changed files with 35 additions and 12 deletions
|
|
@ -24,13 +24,21 @@ DistributedRunContext::DistributedRunContext(int32_t world_rank,
|
|||
"data_parallel_size(" + std::to_string(data_parallel_size) + ") and horizontal_parallel_size(" +
|
||||
std::to_string(horizontal_parallel_size) + ") MUST range from 0 ~ world_size(" + std::to_string(world_size) + ")");
|
||||
|
||||
ORT_ENFORCE(world_size % horizontal_parallel_size == 0, "world size is not divisible by horizontal model parallel size.");
|
||||
ORT_ENFORCE(world_size % data_parallel_size == 0, "world size is not divisible by data parallel size.");
|
||||
ORT_ENFORCE(world_size % horizontal_parallel_size == 0,
|
||||
"world_size(" + std::to_string(world_size) + ") is not divisible by "
|
||||
"horizontal_parallel_size(" + std::to_string(horizontal_parallel_size) + ").");
|
||||
|
||||
ORT_ENFORCE(world_size % data_parallel_size == 0,
|
||||
"world_size(" + std::to_string(world_size) + ") is not divisible by "
|
||||
"data_parallel_size(" + std::to_string(data_parallel_size) + ").");
|
||||
|
||||
// Be noted: this check and subsequent logic should be updated when we introduce pipeline group
|
||||
// depending how to split the pipeline groups.
|
||||
ORT_ENFORCE(data_parallel_size * horizontal_parallel_size * pipeline_stage_size == world_size,
|
||||
"total worker number != data_parallel_size * horizontal_parallel_size * pipeline_stage_size");
|
||||
"data_parallel_size(" + std::to_string(data_parallel_size) + ") "
|
||||
"* horizontal_parallel_size(" + std::to_string(horizontal_parallel_size) + ") "
|
||||
"* pipeline_stage_size(" + std::to_string(pipeline_stage_size) + ") "
|
||||
"!= world_size(" + std::to_string(world_size) + ").");
|
||||
|
||||
params_.world_rank = world_rank;
|
||||
params_.world_size = world_size;
|
||||
|
|
|
|||
|
|
@ -17,8 +17,13 @@ namespace test {
|
|||
class DistributedRunTestContext : public DistributedRunContext {
|
||||
public:
|
||||
DistributedRunTestContext(DistributedRunConfig config)
|
||||
: DistributedRunContext(config.world_rank, config.world_size, config.local_rank,
|
||||
config.local_size, config.data_parallel_size, config.horizontal_parallel_size) {
|
||||
: DistributedRunContext(config.world_rank,
|
||||
config.world_size,
|
||||
config.local_rank,
|
||||
config.local_size,
|
||||
config.data_parallel_size,
|
||||
config.horizontal_parallel_size,
|
||||
config.pipeline_stage_size) {
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -431,7 +436,7 @@ TEST(DistributedRunContextTest, FailTest1) {
|
|||
DistributedRunConfig config = {63, 64, 3, 4, 16, 5};
|
||||
DistributedRunTestContext ctx(config);
|
||||
} catch (const std::exception& ex) {
|
||||
auto ret = std::string(ex.what()).find("world size is not divisible by horizontal model parallel size");
|
||||
auto ret = std::string(ex.what()).find("world_size(64) is not divisible by horizontal_parallel_size(5)");
|
||||
ASSERT_TRUE(ret != std::string::npos);
|
||||
}
|
||||
}
|
||||
|
|
@ -441,7 +446,17 @@ TEST(DistributedRunContextTest, FailTest2) {
|
|||
DistributedRunConfig config = {63, 64, 3, 4, 8, 4};
|
||||
DistributedRunTestContext ctx(config);
|
||||
} catch (const std::exception& ex) {
|
||||
auto ret = std::string(ex.what()).find("total worker number != data_parallel_size * horizontal_parallel_size");
|
||||
auto ret = std::string(ex.what()).find("data_parallel_size(8) * horizontal_parallel_size(4) * pipeline_stage_size(1) != world_size(64)");
|
||||
ASSERT_TRUE(ret != std::string::npos);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(DistributedRunContextTest, FailTest3) {
|
||||
try {
|
||||
DistributedRunConfig config = {63, 64, 3, 4, 1, 4, 4};
|
||||
DistributedRunTestContext ctx(config);
|
||||
} catch (const std::exception& ex) {
|
||||
auto ret = std::string(ex.what()).find("data_parallel_size(1) * horizontal_parallel_size(4) * pipeline_stage_size(4) != world_size(64)");
|
||||
ASSERT_TRUE(ret != std::string::npos);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ Status MergeGraph(Graph& graph, Graph& graph_to_merge, int rank, std::vector<Nod
|
|||
Status MergeGraphsOnAllWorkers(std::vector<Graph*>& graphs, Graph& combine_graph) {
|
||||
auto total_rank = graphs.size();
|
||||
std::vector<std::vector<Node*>> megatronGs(total_rank, std::vector<Node*>());
|
||||
for (auto i = 0; i < total_rank; i++) {
|
||||
for (auto i = 0u; i < total_rank; i++) {
|
||||
auto merge_ret = horizontal_parallel_test_utils::MergeGraph(combine_graph, *graphs[i], i, megatronGs[i]);
|
||||
ORT_ENFORCE(merge_ret.IsOK());
|
||||
ORT_ENFORCE(megatronGs[i].size() == megatronGs[0].size());
|
||||
|
|
@ -90,14 +90,14 @@ Status MergeGraphsOnAllWorkers(std::vector<Graph*>& graphs, Graph& combine_graph
|
|||
|
||||
std::vector<onnxruntime::NodeIndex> nodes_to_remove;
|
||||
// Merge megatron g at the same index for different ranks
|
||||
for (auto g_index = 0; g_index < megatronGs[0].size(); g_index++) {
|
||||
for (auto g_index = 0u; g_index < megatronGs[0].size(); g_index++) {
|
||||
// Merge the "g_index"th MegatronG on each rank into one Sum node.
|
||||
std::vector<NodeArg*> input_defs{};
|
||||
auto type_info = *megatronGs[0][g_index]->MutableOutputDefs()[0]->TypeAsProto();
|
||||
auto& input_arg = combine_graph.GetOrCreateNodeArg("sum_" + std::to_string(g_index), &type_info);
|
||||
std::vector<NodeArg*> output_defs{&input_arg};
|
||||
|
||||
for (auto rank_index = 0; rank_index < total_rank; rank_index++) {
|
||||
for (auto rank_index = 0u; rank_index < total_rank; rank_index++) {
|
||||
input_defs.push_back(megatronGs[rank_index][g_index]->MutableInputDefs()[0]);
|
||||
}
|
||||
auto& sum_node = combine_graph.AddNode(combine_graph.GenerateNodeName("Sum_For_MegatronG"),
|
||||
|
|
@ -107,7 +107,7 @@ Status MergeGraphsOnAllWorkers(std::vector<Graph*>& graphs, Graph& combine_graph
|
|||
output_defs);
|
||||
sum_node.SetExecutionProviderType(megatronGs[0][g_index]->GetExecutionProviderType());
|
||||
|
||||
for (auto rank_index = 0; rank_index < total_rank; rank_index++) {
|
||||
for (auto rank_index = 0u; rank_index < total_rank; rank_index++) {
|
||||
graph_utils::ReplaceDownstreamNodeInput(combine_graph, *megatronGs[rank_index][g_index], 0, sum_node, 0);
|
||||
nodes_to_remove.push_back(megatronGs[rank_index][g_index]->Index());
|
||||
}
|
||||
|
|
@ -134,7 +134,7 @@ void VerifyOutputs(const std::vector<float>& expected, const std::vector<float>&
|
|||
bool use_threshold_compare, float atol, float rtol, float threshold) {
|
||||
auto size = expected.size();
|
||||
ORT_ENFORCE(size == actual.size());
|
||||
for (auto i = 0; i < size; ++i) {
|
||||
for (auto i = 0u; i < size; ++i) {
|
||||
const auto expected_value = expected[i], actual_value = actual[i];
|
||||
if (std::isnan(expected_value)) {
|
||||
ASSERT_TRUE(std::isnan(actual_value)) << "value mismatch at index " << i << "; expected is NaN, actual is not NaN";
|
||||
|
|
|
|||
Loading…
Reference in a new issue