Remove signed/unsigned compiler warnings, add additional pipeline test case (#4314)

* Avoid signed/unsigned warning on loops

* Report sizes when distributed world configuration is inconsistent

* Add DistributedRunContextTest for pipeline stage configuration
This commit is contained in:
Tim Harris 2020-06-24 11:36:18 +01:00 committed by GitHub
parent 44f06ec480
commit 5c6a27408a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 12 deletions

View file

@ -24,13 +24,21 @@ DistributedRunContext::DistributedRunContext(int32_t world_rank,
"data_parallel_size(" + std::to_string(data_parallel_size) + ") and horizontal_parallel_size(" +
std::to_string(horizontal_parallel_size) + ") MUST range from 0 ~ world_size(" + std::to_string(world_size) + ")");
ORT_ENFORCE(world_size % horizontal_parallel_size == 0, "world size is not divisible by horizontal model parallel size.");
ORT_ENFORCE(world_size % data_parallel_size == 0, "world size is not divisible by data parallel size.");
ORT_ENFORCE(world_size % horizontal_parallel_size == 0,
"world_size(" + std::to_string(world_size) + ") is not divisible by "
"horizontal_parallel_size(" + std::to_string(horizontal_parallel_size) + ").");
ORT_ENFORCE(world_size % data_parallel_size == 0,
"world_size(" + std::to_string(world_size) + ") is not divisible by "
"data_parallel_size(" + std::to_string(data_parallel_size) + ").");
// Be noted: this check and subsequent logic should be updated when we introduce pipeline group
// depending how to split the pipeline groups.
ORT_ENFORCE(data_parallel_size * horizontal_parallel_size * pipeline_stage_size == world_size,
"total worker number != data_parallel_size * horizontal_parallel_size * pipeline_stage_size");
"data_parallel_size(" + std::to_string(data_parallel_size) + ") "
"* horizontal_parallel_size(" + std::to_string(horizontal_parallel_size) + ") "
"* pipeline_stage_size(" + std::to_string(pipeline_stage_size) + ") "
"!= world_size(" + std::to_string(world_size) + ").");
params_.world_rank = world_rank;
params_.world_size = world_size;

View file

@ -17,8 +17,13 @@ namespace test {
class DistributedRunTestContext : public DistributedRunContext {
public:
DistributedRunTestContext(DistributedRunConfig config)
: DistributedRunContext(config.world_rank, config.world_size, config.local_rank,
config.local_size, config.data_parallel_size, config.horizontal_parallel_size) {
: DistributedRunContext(config.world_rank,
config.world_size,
config.local_rank,
config.local_size,
config.data_parallel_size,
config.horizontal_parallel_size,
config.pipeline_stage_size) {
}
};
@ -431,7 +436,7 @@ TEST(DistributedRunContextTest, FailTest1) {
DistributedRunConfig config = {63, 64, 3, 4, 16, 5};
DistributedRunTestContext ctx(config);
} catch (const std::exception& ex) {
auto ret = std::string(ex.what()).find("world size is not divisible by horizontal model parallel size");
auto ret = std::string(ex.what()).find("world_size(64) is not divisible by horizontal_parallel_size(5)");
ASSERT_TRUE(ret != std::string::npos);
}
}
@ -441,7 +446,17 @@ TEST(DistributedRunContextTest, FailTest2) {
DistributedRunConfig config = {63, 64, 3, 4, 8, 4};
DistributedRunTestContext ctx(config);
} catch (const std::exception& ex) {
auto ret = std::string(ex.what()).find("total worker number != data_parallel_size * horizontal_parallel_size");
auto ret = std::string(ex.what()).find("data_parallel_size(8) * horizontal_parallel_size(4) * pipeline_stage_size(1) != world_size(64)");
ASSERT_TRUE(ret != std::string::npos);
}
}
TEST(DistributedRunContextTest, FailTest3) {
try {
DistributedRunConfig config = {63, 64, 3, 4, 1, 4, 4};
DistributedRunTestContext ctx(config);
} catch (const std::exception& ex) {
auto ret = std::string(ex.what()).find("data_parallel_size(1) * horizontal_parallel_size(4) * pipeline_stage_size(4) != world_size(64)");
ASSERT_TRUE(ret != std::string::npos);
}
}

View file

@ -82,7 +82,7 @@ Status MergeGraph(Graph& graph, Graph& graph_to_merge, int rank, std::vector<Nod
Status MergeGraphsOnAllWorkers(std::vector<Graph*>& graphs, Graph& combine_graph) {
auto total_rank = graphs.size();
std::vector<std::vector<Node*>> megatronGs(total_rank, std::vector<Node*>());
for (auto i = 0; i < total_rank; i++) {
for (auto i = 0u; i < total_rank; i++) {
auto merge_ret = horizontal_parallel_test_utils::MergeGraph(combine_graph, *graphs[i], i, megatronGs[i]);
ORT_ENFORCE(merge_ret.IsOK());
ORT_ENFORCE(megatronGs[i].size() == megatronGs[0].size());
@ -90,14 +90,14 @@ Status MergeGraphsOnAllWorkers(std::vector<Graph*>& graphs, Graph& combine_graph
std::vector<onnxruntime::NodeIndex> nodes_to_remove;
// Merge megatron g at the same index for different ranks
for (auto g_index = 0; g_index < megatronGs[0].size(); g_index++) {
for (auto g_index = 0u; g_index < megatronGs[0].size(); g_index++) {
// Merge the "g_index"th MegatronG on each rank into one Sum node.
std::vector<NodeArg*> input_defs{};
auto type_info = *megatronGs[0][g_index]->MutableOutputDefs()[0]->TypeAsProto();
auto& input_arg = combine_graph.GetOrCreateNodeArg("sum_" + std::to_string(g_index), &type_info);
std::vector<NodeArg*> output_defs{&input_arg};
for (auto rank_index = 0; rank_index < total_rank; rank_index++) {
for (auto rank_index = 0u; rank_index < total_rank; rank_index++) {
input_defs.push_back(megatronGs[rank_index][g_index]->MutableInputDefs()[0]);
}
auto& sum_node = combine_graph.AddNode(combine_graph.GenerateNodeName("Sum_For_MegatronG"),
@ -107,7 +107,7 @@ Status MergeGraphsOnAllWorkers(std::vector<Graph*>& graphs, Graph& combine_graph
output_defs);
sum_node.SetExecutionProviderType(megatronGs[0][g_index]->GetExecutionProviderType());
for (auto rank_index = 0; rank_index < total_rank; rank_index++) {
for (auto rank_index = 0u; rank_index < total_rank; rank_index++) {
graph_utils::ReplaceDownstreamNodeInput(combine_graph, *megatronGs[rank_index][g_index], 0, sum_node, 0);
nodes_to_remove.push_back(megatronGs[rank_index][g_index]->Index());
}
@ -134,7 +134,7 @@ void VerifyOutputs(const std::vector<float>& expected, const std::vector<float>&
bool use_threshold_compare, float atol, float rtol, float threshold) {
auto size = expected.size();
ORT_ENFORCE(size == actual.size());
for (auto i = 0; i < size; ++i) {
for (auto i = 0u; i < size; ++i) {
const auto expected_value = expected[i], actual_value = actual[i];
if (std::isnan(expected_value)) {
ASSERT_TRUE(std::isnan(actual_value)) << "value mismatch at index " << i << "; expected is NaN, actual is not NaN";