From 5aec34500dd8d2f27c67bb0b1c8ad1e4e0006156 Mon Sep 17 00:00:00 2001 From: ashbhandare Date: Wed, 11 Nov 2020 16:21:36 -0800 Subject: [PATCH] Add megatron transforms for BART (#5521) * Large model export and run ORT Python support * Megatron change refine a bit workaround self attention issue use partitioned name for weights when megatron model parallel is enabled Fix Megatron Transformer Issue (cuased by the renaming) Add UTs for T5 model parallel Fix megatron seed issue fix log a bit checkkpointing changes + rebase Unintended reshape transform change t5 layer norm changes add t5 layer norm kernel use template for t5 layer norm template definition changes no build error add CPU cuda kernel first unit test other forward unit tests add T5LayerNormGrad Add c++ transform and test for T5 LN minor fix BART MLP Megatron tranform Add concat slice transform + test Cosmetic improvements in concat slice transform Constant folding bug fix + megatron attention transform for BART Undo unnecessary changes * Cleanup * Remove unnecessary changes * Cleanup megatron * Windows build * Add self attention test graph * Correcting transforms + cleanup * review comments * review comments * fix build and test failures * Fix CI * fix windows CI Co-authored-by: Peng Wang Co-authored-by: Aishwarya --- onnxruntime/test/optimizer/cse_test.cc | 6 +- .../bart_mlp_megatron_basic_test.onnx | Bin 0 -> 1716 bytes .../bart_mlp_megatron_basic_test.py | 105 +++ ...rt_self_attention_megatron_basic_test.onnx | Bin 0 -> 3274 bytes ...bart_self_attention_megatron_basic_test.py | 150 +++++ .../core/optimizer/graph_transformer_utils.cc | 6 +- .../core/optimizer/graph_transformer_utils.h | 3 +- .../core/optimizer/megatron_transformer.cc | 611 +++++++++++++++++- .../core/optimizer/megatron_transformer.h | 33 +- .../core/session/training_session.cc | 45 +- .../core/session/training_session.h | 11 +- .../python/orttraining_pybind_state.cc | 13 +- .../test/optimizer/graph_transform_test.cc | 212 ++---- 13 files 
changed, 979 insertions(+), 216 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.onnx create mode 100644 onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.py create mode 100644 onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.onnx create mode 100644 onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.py diff --git a/onnxruntime/test/optimizer/cse_test.cc b/onnxruntime/test/optimizer/cse_test.cc index bde7e4ff6e..32686b3683 100644 --- a/onnxruntime/test/optimizer/cse_test.cc +++ b/onnxruntime/test/optimizer/cse_test.cc @@ -95,8 +95,12 @@ TEST(CseTests, SimpleTestTraining) { .IsOK()); GraphTransformerManager graph_transformation_mgr(1); + // need to declare variables to avoid build error after making + // weights_to_train and updated_weight_names as non-const + std::unordered_set weights_to_train; + std::unordered_map updated_weight_names; auto transformers_to_register = onnxruntime::training::transformer_utils::GeneratePreTrainingTransformers( - TransformerLevel::Level1, {}, {}, CPUExecutionProvider(CPUExecutionProviderInfo())); + TransformerLevel::Level1, weights_to_train, {}, CPUExecutionProvider(CPUExecutionProviderInfo()), updated_weight_names); for (auto& entry : transformers_to_register) { ASSERT_TRUE( graph_transformation_mgr.Register(std::move(entry), TransformerLevel::Level1).IsOK()); diff --git a/onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.onnx b/onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..084cdea8dd3159c337b11e5692e3ccf5f811fdfd GIT binary patch literal 1716 zcmeHHKWGzS6i?dxxjq|uJp#p31Ex^hNDp&G)YALD8VkW-krqLaK)EE>_UQdHccFhA z>=F={EG}I{sDo1$mn<$pa7hQJEG}7Gvbe;-$(Mg;s-iTm-tc{o-+kZjz2AH9drXiU zHw|NNUfGj%OH&q~g+PZ}ijg&QN?vfTZ`h8U-4QaHiECWYT 
z3cHHBwe4`=lJ@vaqR5{M-cTMT7QUC`BR z-n7k42kwA-2>f~3QBCmaSP72FEUy-r?OJ(cnt=2wVdf91tmT!LCS= z?;EbnHRwY_=ro0n&QUp+1lsoRMTo8l@w5o>hlm+hgvbR@I97~j)SO~C zs_30(3A1{@vE!?N2TJ4@xsurZV zbV)*7YOYlft4ZZ#1M!wriTKG}dV(AsohNvnJP8Y=xptGJO9_H^i3ATxIVqD!afcue z@pukNB~m1o^c?S%NaNrI-aR7gm6v$`m{ePDkfTf%7TzLHg}k=jAy<`z%O8-hMvBdk z$XO??*-yyZAe)UDK$C_dQ Z7B+JQdhhu!2NOQb2;ysGe@-NW$f|TW4B4eg0Pg%n~&db-n=*Sd#WU&j_0~3yZFR# z<}>{M&)4AtXj<-k5ur{L7;ZTCLabYJxztnIU$Z%N8-zIyoN;|z8#jKbi|{?Xf!(QR z;-EJ(PH_2i_08``Eg79HGr(4`1dx{0>{i4p;Nr%lK`c-+pHP zco~1p^T#Xr_n{O^;E#i zi&?s3ghW7CR}l~C66ixd$fqgpBax=5FEb-g)Gc5`sGZBaT$n3w>3dKtgli;gQzN4B z(g#S)A|^AVh>{EH3Y{XSEy*{^{n9<1mL&O-G2YWkG1>>_16%@R(HXGl4D#q4vgiy} zmdo7ZRLuk;FatULOSMcPi6%p+Mb?p;ElZY(-N=eg73s6gLP}D6cS-S0W;}P-!S{?P zAV!44Xg9eEGc}L|N$A!mYgcY56>PghBxn=4axZM z&o3U5vGd_?pZ-C{p9~+q+#KD!vpst7;QDB-JxkU6f7-KD;s1j@M`>i&ET3Ap8dvpd zJXbfL3C|VGOYpBt8(ne2s-U(bB_5KmLz`M7Yt?LFCd%5*)9hQeFsphi7Cuj|WKpFx z;VCe&&_JyqXK z!^11Ow!==$JYcQ5O&F*MYr^m3VW^wLtyGS+dQI)6R28)=0?y7!_Dw%E%3ke1kUVXi zK>yXd9bc&wCDFQgqrQ8Q{fmL1E7eg48d8h=B=QsPz66<=p<*_fNmd9G0czdO> GeneratePreTrainingTransformers( TransformerLevel level, - const std::unordered_set& weights_to_train, + std::unordered_set& weights_to_train, const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration& config, const IExecutionProvider& execution_provider, + std::unordered_map& updated_weight_names, const std::vector& transformers_and_rules_to_enable) { std::vector> transformers; std::unique_ptr rule_transformer = nullptr; @@ -94,7 +95,6 @@ std::vector> GeneratePreTrainingTransformers( if (config.enable_gelu_approximation) { transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); } - transformers.emplace_back(onnxruntime::make_unique(execution_provider, compatible_eps, weights_to_train)); transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); auto horizontal_parallel_size = training::DistributedRunContext::GroupSize(training::WorkerGroupType::HorizontalParallel); @@ -102,7 +102,7 @@ 
std::vector> GeneratePreTrainingTransformers( LOGS_DEFAULT(WARNING) << horizontal_parallel_size << "-way horizontal model parallel is enabled"; transformers.emplace_back(onnxruntime::make_unique( training::DistributedRunContext::RankInGroup(training::WorkerGroupType::HorizontalParallel), - horizontal_parallel_size, compatible_eps)); + horizontal_parallel_size, updated_weight_names, weights_to_train, compatible_eps)); } transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.h b/orttraining/orttraining/core/optimizer/graph_transformer_utils.h index d6f932742e..f5c7bb601d 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.h +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.h @@ -17,9 +17,10 @@ namespace transformer_utils { /** Generates all pre-training transformers for this level. */ std::vector> GeneratePreTrainingTransformers( TransformerLevel level, - const std::unordered_set& weights_to_train, + std::unordered_set& weights_to_train, const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration& config, const IExecutionProvider& execution_provider, // required for constant folding + std::unordered_map& updated_weight_names, const std::vector& rules_and_transformers_to_enable = {}); /** Generates all predefined (both rule-based and non-rule-based) transformers for this level. 
diff --git a/orttraining/orttraining/core/optimizer/megatron_transformer.cc b/orttraining/orttraining/core/optimizer/megatron_transformer.cc index 385db12f57..851086baeb 100644 --- a/orttraining/orttraining/core/optimizer/megatron_transformer.cc +++ b/orttraining/orttraining/core/optimizer/megatron_transformer.cc @@ -32,6 +32,7 @@ const std::initializer_list opset_v1_13 = {1 const std::initializer_list opset_v1_11_13 = {1, 11, 13}; const std::initializer_list opset_v2_11_13 = {2, 11, 13}; const std::initializer_list opset_v5_13 = {5, 13}; +const std::initializer_list opset_v1_6_7_13 = {1, 6, 7, 13}; const std::initializer_list opset_v7_13 = {7, 13}; const std::initializer_list opset_v9 = {9}; const std::initializer_list opset_v9_13 = {9, 13}; @@ -42,11 +43,12 @@ const OpInfo reshape_info = OpInfo("Reshape", opset_v5_13); const OpInfo transpose_info = OpInfo("Transpose", opset_v1_13); const OpInfo matmul_info = OpInfo("MatMul", opset_v9_13); const OpInfo div_info = OpInfo("Div", opset_v7_13); -const OpInfo mul_info = OpInfo("Mul", opset_v7_13); +const OpInfo mul_info = OpInfo("Mul", opset_v1_6_7_13); const OpInfo sub_info = OpInfo("Sub", opset_v7_13); const OpInfo softmax_info = OpInfo("Softmax", opset_v1_11_13); const OpInfo trainable_dropout_info = OpInfo("TrainableDropout", opset_v9, kOnnxDomain); const OpInfo dropout_info = OpInfo("Dropout", opset_v12_13); +const OpInfo where_info = OpInfo("Where", opset_v9); struct NodeInfo { NodeInfo(const std::vector& op_infos, @@ -119,7 +121,6 @@ bool MegatronTransformer::PartitionWeightByColumn(const Graph& graph, const Node LOGS_DEFAULT(WARNING) << "PartitionWeightByColumn: " << input_arg.Name() << " is not an initializer"; return false; } - auto data_type = tensor_proto->data_type(); const ONNX_NAMESPACE::TensorShapeProto* shape = input_arg.Shape(); int rank = shape->dim_size(); @@ -147,8 +148,9 @@ bool MegatronTransformer::PartitionWeightByColumn(const Graph& graph, const Node auto initializer = 
onnxruntime::make_unique(*tensor_proto, graph.ModelPath()); const float* a_weight = initializer->data(); - initializer_partition.set_name("rank_" + std::to_string(horizontal_parallel_rank_) + - "_" + input_arg.Name() + "_partition"); + std::string new_initializer_name = input_arg.Name() + "_column_rank_" + std::to_string(horizontal_parallel_rank_); + + initializer_partition.set_name(new_initializer_name); initializer_partition.set_data_type(data_type); int64_t column_partition = column_count / horizontal_parallel_size_; @@ -209,12 +211,12 @@ bool MegatronTransformer::PartitionWeightByRow(const Graph& graph, const NodeArg << horizontal_parallel_size_ << ", not supported currently."; return false; } - auto initializer = onnxruntime::make_unique(*tensor_proto, graph.ModelPath()); const float* a_weight = initializer->data(); - initializer_partition.set_name("rank_" + std::to_string(horizontal_parallel_rank_) + - "_" + input_arg.Name() + "_partition"); + std::string new_initializer_name = input_arg.Name() + "_row_rank_" + std::to_string(horizontal_parallel_rank_); + + initializer_partition.set_name(new_initializer_name); initializer_partition.set_data_type(data_type); int64_t row_partition = row_count / horizontal_parallel_size_; @@ -231,13 +233,13 @@ bool MegatronTransformer::PartitionWeightByRow(const Graph& graph, const NodeArg const int64_t row_index_offset = horizontal_parallel_rank_ * row_partition; memcpy(result.data(), a_weight + row_index_offset * column_count, sizeof(float) * element_count); initializer_partition.set_raw_data(result.data(), element_count * sizeof(float)); - return true; } Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger, - std::vector& nodes_to_clear_shape) const { + std::vector& nodes_to_clear_shape, + int32_t& counter) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); for (auto node_index : 
node_topology_list) { @@ -309,12 +311,15 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph NodeArg& a_weight_partition_arg = graph_utils::AddInitializer(graph, a_weight_initializer_partition); graph_utils::ReplaceNodeInput(node, 1, a_weight_partition_arg); + updated_weight_names_.insert({a_weight_arg->Name(), a_weight_partition_arg.Name()}); NodeArg& a_bias_partition_arg = graph_utils::AddInitializer(graph, a_bias_initializer_partition); graph_utils::ReplaceNodeInput(add_node, 1, a_bias_partition_arg); + updated_weight_names_.insert({b_weight_arg->Name(), a_bias_partition_arg.Name()}); NodeArg& b_weight_partition_arg = graph_utils::AddInitializer(graph, b_weight_initializer_partition); graph_utils::ReplaceNodeInput(matmul2_node, 1, b_weight_partition_arg); + updated_weight_names_.insert({a_bias_arg->Name(), b_weight_partition_arg.Name()}); graph.RemoveInitializedTensor(a_weight_arg->Name()); graph.RemoveInitializedTensor(b_weight_arg->Name()); @@ -349,6 +354,151 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph mlp_g_node.SetExecutionProviderType(node.GetExecutionProviderType()); graph_utils::ReplaceDownstreamNodeInput(graph, matmul2_node, 0, mlp_g_node, 0); modified = true; + counter++; + } + + return Status::OK(); +} + +/* +DenseWeight -- Transpose \ + MatMul -- BiasGelu -- Dropout -- MatMul -- Add -- Dropout +*/ +Status MegatronTransformer::TransformBARTMLP(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger, + std::vector& nodes_to_clear_shape, + std::unordered_set& dropout_nodes_to_transform, int32_t& counter) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + for (auto node_index : node_topology_list) { + auto& node = *graph.GetNode(node_index); + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9, 13}) || 
+ !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || + node.GetOutputEdgesCount() != 1) { + continue; + } + Node* second_op = const_cast(graph.GetProducerNode(node.MutableInputDefs()[1]->Name())); + Node* first_op = const_cast(graph.GetProducerNode(node.MutableInputDefs()[0]->Name())); + if (node.GetInputEdgesCount() > 0) { + if (second_op == nullptr) { + break; + } + if (first_op != nullptr && first_op->OpType().compare("MegatronF") == 0) { + continue; + } + + if (second_op->OpType().compare("Transpose") != 0) { + continue; + } + } else { + continue; + } + // check if transpose is only 2-dim + if (!optimizer_utils::IsAttributeWithExpectedValues(*second_op, "perm", {1LL, 0LL})) { + continue; + } + ProviderType provider_type = node.GetExecutionProviderType(); + + Node* biasgelu_node_ptr = graph.GetNode(node.OutputNodesBegin()->Index()); + Node& biasgelu_node = *biasgelu_node_ptr; + if (!graph_utils::IsSupportedOptypeVersionAndDomain(biasgelu_node, "BiasGelu", {1}, kMSDomain) || + biasgelu_node.GetExecutionProviderType() != provider_type || + biasgelu_node.GetOutputEdgesCount() != 1) { + continue; + } + Node& dropout_node = *graph.GetNode(biasgelu_node.OutputNodesBegin()->Index()); + if (!IsExpectedOpAndProvider(dropout_node, dropout_info, provider_type)) { + continue; + } + Node& matmul2_node = *graph.GetNode(dropout_node.OutputNodesBegin()->Index()); + if (!IsExpectedOpAndProvider(matmul2_node, matmul_info, provider_type)) { + continue; + } + Node& add_node = *graph.GetNode(matmul2_node.OutputNodesBegin()->Index()); + if (!IsExpectedOpAndProvider(add_node, add_info, provider_type)) { + continue; + } + Node& dropout2_node = *graph.GetNode(add_node.OutputNodesBegin()->Index()); + if (!IsExpectedOpAndProvider(dropout2_node, dropout_info, provider_type)) { + continue; + } + Node* transpose_op_ptr = const_cast(graph.GetProducerNode(matmul2_node.MutableInputDefs()[1]->Name())); + if (transpose_op_ptr == nullptr || 
!IsExpectedOpAndProvider(*transpose_op_ptr, transpose_info, provider_type)) { + continue; + } + + nodes_to_clear_shape.insert(nodes_to_clear_shape.end(), {&node, second_op, &biasgelu_node, &dropout_node, + &matmul2_node, transpose_op_ptr}); + + auto dense_wi_weight_arg = second_op->MutableInputDefs()[0]; + ONNX_NAMESPACE::TensorProto dense_wi_weight_initializer_partition; + if (!PartitionWeightByRow(graph, *dense_wi_weight_arg, dense_wi_weight_initializer_partition)) { + continue; + } + + //since the bias doesnt get transposed, partitioning by col + auto dense_wi_bias_arg = biasgelu_node.MutableInputDefs()[1]; + ONNX_NAMESPACE::TensorProto dense_wi_bias_initializer_partition; + if (!PartitionWeightByColumn(graph, *dense_wi_bias_arg, dense_wi_bias_initializer_partition)) { + continue; + } + + auto dense_wo_weight_arg = transpose_op_ptr->MutableInputDefs()[0]; + ONNX_NAMESPACE::TensorProto dense_wo_weight_initializer_partition; + if (!PartitionWeightByColumn(graph, *dense_wo_weight_arg, dense_wo_weight_initializer_partition)) { + continue; + } + + NodeArg& dense_wi_weight_partition_arg = graph_utils::AddInitializer(graph, dense_wi_weight_initializer_partition); + graph_utils::ReplaceNodeInput(*second_op, 0, dense_wi_weight_partition_arg); + updated_weight_names_.insert({dense_wi_weight_arg->Name(), dense_wi_weight_partition_arg.Name()}); + + NodeArg& dense_wi_bias_partition_arg = graph_utils::AddInitializer(graph, dense_wi_bias_initializer_partition); + graph_utils::ReplaceNodeInput(biasgelu_node, 1, dense_wi_bias_partition_arg); + updated_weight_names_.insert({dense_wi_bias_arg->Name(), dense_wi_bias_partition_arg.Name()}); + + NodeArg& dense_wo_weight_partition_arg = graph_utils::AddInitializer(graph, dense_wo_weight_initializer_partition); + graph_utils::ReplaceNodeInput(*transpose_op_ptr, 0, dense_wo_weight_partition_arg); + updated_weight_names_.insert({dense_wo_weight_arg->Name(), dense_wo_weight_partition_arg.Name()}); + + 
graph.RemoveInitializedTensor(dense_wi_weight_arg->Name()); + graph.RemoveInitializedTensor(dense_wi_bias_arg->Name()); + graph.RemoveInitializedTensor(dense_wo_weight_arg->Name()); + + dropout_nodes_to_transform.insert(&dropout_node); + + const std::vector mlp_f_input_defs{node.MutableInputDefs()[0]}; + auto mlp_f_type_info = *node.MutableInputDefs()[0]->TypeAsProto(); + auto& mlp_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("BART_MLP_MegatronF_Output"), &mlp_f_type_info); + Node& mlp_f_node = graph.AddNode(graph.GenerateNodeName("BART_MLP_MegatronF"), + "MegatronF", + "MLP MegatronF", + mlp_f_input_defs, + {&mlp_f_out_arg}, {}, kMSDomain); + counter++; + mlp_f_node.SetExecutionProviderType(node.GetExecutionProviderType()); + const Node::EdgeEnd* edge = graph_utils::GetInputEdge(node, 0); + if (nullptr == edge) { // handle input/initializer + graph_utils::ReplaceNodeInput(node, 0, *(mlp_f_node.MutableOutputDefs()[0])); + } else { + auto input_node = const_cast(&edge->GetNode()); + graph_utils::ReplaceDownstreamNodeInput(graph, *input_node, edge->GetSrcArgIndex(), mlp_f_node, 0); + } + + const std::vector mlp_g_input_defs{matmul2_node.MutableOutputDefs()[0]}; + auto mlp_g_type_info = *matmul2_node.MutableOutputDefs()[0]->TypeAsProto(); + auto& mlp_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("BART_MLP_MegatronG_Output"), &mlp_g_type_info); + Node& mlp_g_node = graph.AddNode(graph.GenerateNodeName("BART_MLP_MegatronG"), + "MegatronG", + "MLP MegatronG", + mlp_g_input_defs, + {&mlp_g_out_arg}, {}, kMSDomain); + mlp_g_node.AddAttribute("group_type", static_cast(training::WorkerGroupType::HorizontalParallel)); + mlp_g_node.SetExecutionProviderType(node.GetExecutionProviderType()); + graph_utils::ReplaceDownstreamNodeInput(graph, matmul2_node, 0, mlp_g_node, 0); + modified = true; } return Status::OK(); @@ -357,7 +507,8 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph Status 
MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger, std::vector& nodes_to_clear_shape, - std::unordered_set& self_attention_dropout_nodes) const { + std::unordered_set& dropout_nodes_to_transform, + int32_t& counter) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); @@ -406,7 +557,7 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified, // Get all useful nodes here as more vector push back below will change the index. Node& add_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 15]; Node& split_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 14]; - Node& transpose_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 12]; + Node& k_transpose_after_reshape_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 12]; Node* matmul_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 11]; Node* dropout_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 6]; Node* matmul_node_ptr1 = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 5]; @@ -414,7 +565,7 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified, Node& matmul_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 2]; // Transpose node attribute checking. - if (!optimizer_utils::IsAttributeWithExpectedValues(transpose_node, "perm", {0LL, 2LL, 1LL, 3LL}) || + if (!optimizer_utils::IsAttributeWithExpectedValues(k_transpose_after_reshape_node, "perm", {0LL, 2LL, 1LL, 3LL}) || !optimizer_utils::IsAttributeWithExpectedValues(transpose_node1, "perm", {0LL, 2LL, 1LL, 3LL})) { continue; } @@ -518,12 +669,15 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified, // Replace by the partition weights. 
NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializer(graph, qkv_weight_initializer_partition); graph_utils::ReplaceNodeInput(node, 1, qkv_weight_partition_arg); + updated_weight_names_.insert({qkv_weight_arg->Name(), qkv_weight_partition_arg.Name()}); NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializer(graph, qkv_bias_initializer_partition); graph_utils::ReplaceNodeInput(add_node, 1, qkv_bias_partition_arg); + updated_weight_names_.insert({qkv_bias_arg->Name(), qkv_bias_partition_arg.Name()}); NodeArg& dense_weight_partition_arg = graph_utils::AddInitializer(graph, dense_weight_initializer_partition); graph_utils::ReplaceNodeInput(matmul_node, 1, dense_weight_partition_arg); + updated_weight_names_.insert({dense_weight_arg->Name(), dense_weight_partition_arg.Name()}); graph.RemoveInitializedTensor(qkv_weight_arg->Name()); graph.RemoveInitializedTensor(qkv_bias_arg->Name()); @@ -554,16 +708,16 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified, } if (dropout_node_ptr != nullptr) { - self_attention_dropout_nodes.insert(dropout_node_ptr); + dropout_nodes_to_transform.insert(dropout_node_ptr); } // Add MegatronF before the 1st MatMul and MegatronG before the last Add. 
const std::vector sa_f_input_defs{node.MutableInputDefs()[0]}; auto sa_f_type_info = *node.MutableInputDefs()[0]->TypeAsProto(); - auto& sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("SelfAttention_MegatronF_Output"), &sa_f_type_info); - Node& sa_f_node = graph.AddNode(graph.GenerateNodeName("SelfAttention_MegatronF"), + auto& sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("SeftAttention_MegatronF_Output"), &sa_f_type_info); + Node& sa_f_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "SeftAttention_MegatronF"), "MegatronF", - "SelfAttention MegatronF", + "SeftAttention MegatronF", sa_f_input_defs, {&sa_f_out_arg}, {}, kMSDomain); sa_f_node.SetExecutionProviderType(node.GetExecutionProviderType()); @@ -577,23 +731,400 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified, const std::vector sa_g_input_defs{matmul_node.MutableOutputDefs()[0]}; auto sa_g_type_info = *matmul_node.MutableOutputDefs()[0]->TypeAsProto(); // copy - auto& sa_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("SelfAttention_MegatronG_Output"), &sa_g_type_info); - Node& sa_g_node = graph.AddNode(graph.GenerateNodeName("SelfAttention_MegatronG"), + auto& sa_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("SeftAttention_MegatronG_Output"), &sa_g_type_info); + Node& sa_g_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "SelfAttention_MegatronG"), "MegatronG", - "SelfAttention MegatronG", + "Attention MegatronG", sa_g_input_defs, {&sa_g_out_arg}, {}, kMSDomain); sa_g_node.AddAttribute("group_type", static_cast(training::WorkerGroupType::HorizontalParallel)); sa_g_node.SetExecutionProviderType(node.GetExecutionProviderType()); graph_utils::ReplaceDownstreamNodeInput(graph, matmul_node, 0, sa_g_node, 0); modified = true; + counter++; + } + + return Status::OK(); +} + +Status MegatronTransformer::TransformBARTSelfAttention(Graph& graph, bool& modified, int graph_level, + const 
logging::Logger& logger, + std::vector& nodes_to_clear_shape, + std::unordered_set& dropout_nodes_to_transform, + int32_t& counter) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + // Self attention sub-graph. + // + // MatMul->Add->Mul->Reshape->Transpose->MatMul->Reshape->Where->Reshape->Softmax->Dropout->MatMul->Transpose->Reshape->MatMul->Add->Droupout + // MatMul->Add->Reshape->Transpose-------> | | + // MatMul->Add->Reshape->Transpose----------------------------------------------------------> | + for (auto node_index : node_topology_list) { + auto& node = *graph.GetNode(node_index); + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", opset_v9_13) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || + node.GetOutputEdgesCount() != 1) { + continue; + } + + Node* q_matmul_input_node_ptr = const_cast(graph.GetProducerNode(node.MutableInputDefs()[0]->Name())); + if (q_matmul_input_node_ptr != nullptr && q_matmul_input_node_ptr->OpType().compare("MegatronF") == 0) { + continue; + } + std::vector sub_graph_node_ptrs; + sub_graph_node_ptrs.push_back(&node); + ProviderType provider_type = node.GetExecutionProviderType(); + + std::vector linear_pattern = { + NodeInfo({add_info}), + NodeInfo({mul_info}), + NodeInfo({reshape_info}), + NodeInfo({transpose_info}), + NodeInfo({matmul_info}), + NodeInfo({add_info}, false), // -13 + NodeInfo({reshape_info}), + NodeInfo({where_info}), + NodeInfo({reshape_info}), + NodeInfo({softmax_info}), + NodeInfo({dropout_info}, false), // -8 + NodeInfo({matmul_info}), + NodeInfo({add_info}, false), // -6 + NodeInfo({transpose_info}), + NodeInfo({reshape_info}), + NodeInfo({matmul_info}), // -3 + NodeInfo({add_info}), + NodeInfo({dropout_info}, false)}; // -1 + if (!MatchLinearPattern(graph, &node, provider_type, linear_pattern, sub_graph_node_ptrs)) 
{ + continue; + } + // Get all useful nodes here as more vector push back below will change the index. + // Other than the optional nodes in the pattern, all other node pointers are valid + // if they match the linear pattern. + Node* q_biasadd_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 18]; + Node* q_transpose_after_reshape_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 15]; + Node* qk_matmul_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 14]; + Node* dropout_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 8]; + Node* qkv_matmul_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 7]; + Node* transpose_node1_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 5]; + Node& dense_matmul_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 3]; + + // Transpose node attribute checking. + if (!optimizer_utils::IsAttributeWithExpectedValues(*q_transpose_after_reshape_node_ptr, "perm", {1LL, 0LL, 2LL}) || + !optimizer_utils::IsAttributeWithExpectedValues(*transpose_node1_ptr, "perm", {1LL, 0LL, 2LL})) { + continue; + } + // map between reshape node and dim of reshape that must be modified + std::unordered_map reshape_node_ptrs; + reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 16]] = 1; + reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 12]] = 1; + reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 10]] = 0; + reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 4]] = 2; + // till now node should be q matmul operation + + std::vector weight_transpose_node_ptrs; + std::vector bias_add_node_ptrs; + + Node* q_transpose_ptr = const_cast(graph.GetProducerNode(node.MutableInputDefs()[1]->Name())); + if (q_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*q_transpose_ptr, transpose_info, provider_type)) { + continue; + } + weight_transpose_node_ptrs.push_back(q_transpose_ptr); + sub_graph_node_ptrs.push_back(q_transpose_ptr); + 
bias_add_node_ptrs.push_back(q_biasadd_node_ptr); + + Node* k_transpose_ptr = const_cast(graph.GetProducerNode(qk_matmul_node_ptr->MutableInputDefs()[1]->Name())); + if (k_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*k_transpose_ptr, transpose_info, provider_type)) { + continue; + } + sub_graph_node_ptrs.push_back(k_transpose_ptr); + + Node* k_reshape_ptr = const_cast(graph.GetProducerNode(k_transpose_ptr->MutableInputDefs()[0]->Name())); + if (k_reshape_ptr == nullptr || !IsExpectedOpAndProvider(*k_reshape_ptr, reshape_info, provider_type)) { + continue; + } + reshape_node_ptrs[k_reshape_ptr] = 1; + sub_graph_node_ptrs.push_back(k_reshape_ptr); + + Node* k_add_ptr = const_cast(graph.GetProducerNode(k_reshape_ptr->MutableInputDefs()[0]->Name())); + if (k_add_ptr == nullptr || !IsExpectedOpAndProvider(*k_add_ptr, add_info, provider_type)) { + continue; + } + sub_graph_node_ptrs.push_back(k_add_ptr); + bias_add_node_ptrs.push_back(k_add_ptr); + + Node* k_matmul_ptr = const_cast(graph.GetProducerNode(k_add_ptr->MutableInputDefs()[0]->Name())); + if (k_matmul_ptr == nullptr || !IsExpectedOpAndProvider(*k_matmul_ptr, matmul_info, provider_type)) { + continue; + } + sub_graph_node_ptrs.push_back(k_matmul_ptr); + + Node* k_weight_transpose_ptr = const_cast(graph.GetProducerNode(k_matmul_ptr->MutableInputDefs()[1]->Name())); + if (k_weight_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*k_weight_transpose_ptr, transpose_info, provider_type)) { + continue; + } + sub_graph_node_ptrs.push_back(k_weight_transpose_ptr); + weight_transpose_node_ptrs.push_back(k_weight_transpose_ptr); + + Node* v_transpose_ptr = const_cast(graph.GetProducerNode(qkv_matmul_node_ptr->MutableInputDefs()[1]->Name())); + if (v_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*v_transpose_ptr, transpose_info, provider_type)) { + continue; + } + sub_graph_node_ptrs.push_back(v_transpose_ptr); + + Node* v_reshape_ptr = 
const_cast(graph.GetProducerNode(v_transpose_ptr->MutableInputDefs()[0]->Name())); + if (v_reshape_ptr == nullptr || !IsExpectedOpAndProvider(*v_reshape_ptr, reshape_info, provider_type)) { + continue; + } + reshape_node_ptrs[v_reshape_ptr] = 1; + sub_graph_node_ptrs.push_back(v_reshape_ptr); + + Node* v_add_ptr = const_cast(graph.GetProducerNode(v_reshape_ptr->MutableInputDefs()[0]->Name())); + if (v_add_ptr == nullptr || !IsExpectedOpAndProvider(*v_add_ptr, add_info, provider_type)) { + continue; + } + sub_graph_node_ptrs.push_back(v_add_ptr); + bias_add_node_ptrs.push_back(v_add_ptr); + + Node* v_matmul_ptr = const_cast(graph.GetProducerNode(v_add_ptr->MutableInputDefs()[0]->Name())); + if (k_matmul_ptr == nullptr || !IsExpectedOpAndProvider(*k_matmul_ptr, matmul_info, provider_type)) { + continue; + } + sub_graph_node_ptrs.push_back(v_matmul_ptr); + + Node* v_weight_transpose_ptr = const_cast(graph.GetProducerNode(v_matmul_ptr->MutableInputDefs()[1]->Name())); + if (v_weight_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*v_weight_transpose_ptr, transpose_info, provider_type)) { + continue; + } + sub_graph_node_ptrs.push_back(v_weight_transpose_ptr); + weight_transpose_node_ptrs.push_back(v_weight_transpose_ptr); + + // K and V matmul must have the same input + Node* q_matmul_ptr = &node; + if (k_matmul_ptr->MutableInputDefs()[0]->Name() != v_matmul_ptr->MutableInputDefs()[0]->Name()) { + continue; + } + + // Check the constant value in the Reshape nodes. 
+ bool is_reshape_valid = true; + for (auto x : reshape_node_ptrs) { + Node* node_ptr = x.first; + int64_t idx = x.second; + auto shape_arg = node_ptr->MutableInputDefs()[1]; + const ONNX_NAMESPACE::TensorProto* tensor; + if (!graph.GetInitializedTensor(shape_arg->Name(), tensor)) { + is_reshape_valid = false; + break; + } + auto data_type = tensor->data_type(); + if (data_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { + is_reshape_valid = false; + break; + } + // The number of the values should be more than idx, and the idx'th value should be divisible by parallel size, + // i.e., the attention head number should be divisible by parallel size. + auto init_const = onnxruntime::make_unique(*tensor, graph.ModelPath()); + if (init_const->size() <= idx) { + is_reshape_valid = false; + break; + } + const int64_t* val = init_const->data(); + if (val[idx] % horizontal_parallel_size_ != 0) { + LOGS_DEFAULT(WARNING) << "dim[" << idx << "]: " << val[idx] + << " is not divisible by horizontal_parallel_size_ " + << horizontal_parallel_size_ << ", not supported currently."; + is_reshape_valid = false; + break; + } + } + + if (!is_reshape_valid) { + continue; + } + + // Partition weights. If any of them fails, skip transforming the rest. + std::vector qkv_weight_initializer_partitions; + for (auto trans_ptr : weight_transpose_node_ptrs) { + auto qkv_weight_arg = trans_ptr->MutableInputDefs()[0]; + ONNX_NAMESPACE::TensorProto qkv_weight_initializer_partition; + if (!PartitionWeightByRow(graph, *qkv_weight_arg, qkv_weight_initializer_partition)) { + break; + } + qkv_weight_initializer_partitions.push_back(qkv_weight_initializer_partition); + } + + // Partition bias. If any of them fails, skip transforming the rest. 
+ std::vector qkv_bias_initializer_partitions; + for (auto add_ptr : bias_add_node_ptrs) { + auto qkv_bias_arg = add_ptr->MutableInputDefs()[1]; + ONNX_NAMESPACE::TensorProto qkv_bias_initializer_partition; + if (!PartitionWeightByColumn(graph, *qkv_bias_arg, qkv_bias_initializer_partition)) { + break; + } + qkv_bias_initializer_partitions.push_back(qkv_bias_initializer_partition); + } + + // if all the weights or biases weren't transformed, skip transforming this subgraph + if (weight_transpose_node_ptrs.size() != qkv_weight_initializer_partitions.size()) { + continue; + } + if (bias_add_node_ptrs.size() != qkv_bias_initializer_partitions.size()) { + continue; + } + + // transform the dense weight. If it fails, skip transforming this subgraph. + Node* last_transpose = const_cast(graph.GetProducerNode(dense_matmul_node.MutableInputDefs()[1]->Name())); + auto dense_weight_arg = last_transpose->MutableInputDefs()[0]; + ONNX_NAMESPACE::TensorProto dense_weight_initializer_partition; + if (!PartitionWeightByColumn(graph, *dense_weight_arg, dense_weight_initializer_partition)) { + continue; + } + + // Ready to transform the sub-graph when reach here. 
+ // Replace node inputs + size_t i = 0; + for (auto trans_ptr : weight_transpose_node_ptrs) { + auto weight_name = trans_ptr->MutableInputDefs()[0]->Name(); + NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializer(graph, qkv_weight_initializer_partitions[i]); + graph_utils::ReplaceNodeInput(*trans_ptr, 0, qkv_weight_partition_arg); + graph.RemoveInitializedTensor(weight_name); + updated_weight_names_.insert({weight_name, qkv_weight_partition_arg.Name()}); + i++; + } + i = 0; + for (auto add_ptr : bias_add_node_ptrs) { + auto bias_name = add_ptr->MutableInputDefs()[1]->Name(); + NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializer(graph, qkv_bias_initializer_partitions[i]); + graph_utils::ReplaceNodeInput(*add_ptr, 1, qkv_bias_partition_arg); + graph.RemoveInitializedTensor(bias_name); + updated_weight_names_.insert({bias_name, qkv_bias_partition_arg.Name()}); + i++; + } + + NodeArg& dense_weight_partition_arg = graph_utils::AddInitializer(graph, dense_weight_initializer_partition); + graph_utils::ReplaceNodeInput(*last_transpose, 0, dense_weight_partition_arg); + graph.RemoveInitializedTensor(dense_weight_arg->Name()); + updated_weight_names_.insert({dense_weight_arg->Name(), dense_weight_partition_arg.Name()}); + + // It's possible that the node vector contains nullptr due to some optional node infos during linear pattern matching. + std::copy_if(sub_graph_node_ptrs.begin(), sub_graph_node_ptrs.end(), + std::back_inserter(nodes_to_clear_shape), + [](Node* node_ptr) { return node_ptr != nullptr; }); + + // Change the constant for the reshape nodes. 
+ for (auto x : reshape_node_ptrs) { + Node* node_ptr = x.first; + int64_t idx = x.second; + auto shape_arg = node_ptr->MutableInputDefs()[1]; + const ONNX_NAMESPACE::TensorProto* tensor; + graph.GetInitializedTensor(shape_arg->Name(), tensor); + auto data_type = tensor->data_type(); + auto init_const = onnxruntime::make_unique(*tensor, graph.ModelPath()); + const int64_t* val = init_const->data(); + int64_t size = init_const->size(); + ONNX_NAMESPACE::TensorProto tensor_partition; + tensor_partition.set_name(graph.GenerateNodeArgName("partition_" + shape_arg->Name())); + tensor_partition.set_data_type(data_type); + tensor_partition.add_dims(size); + + std::vector val_partition; + val_partition.reserve(size); + val_partition.insert(val_partition.end(), val, val + size); + val_partition[idx] /= horizontal_parallel_size_; + tensor_partition.set_raw_data(val_partition.data(), size * sizeof(int64_t)); + NodeArg& node_arg_partition = graph_utils::AddInitializer(graph, tensor_partition); + graph_utils::ReplaceNodeInput(*node_ptr, 1, node_arg_partition); + graph.RemoveInitializedTensor(shape_arg->Name()); + } + + if (dropout_node_ptr != nullptr) { + dropout_nodes_to_transform.insert(dropout_node_ptr); + } + + // Add MegatronF before the 1st MatMul and MegatronG before the last Add. + + NodeArg* prev_input_node_ptr = k_matmul_ptr->MutableInputDefs()[0]; + std::vector new_consumer_nodes; + const auto& node_consumers = graph.GetConsumerNodes(prev_input_node_ptr->Name()); + for (auto& n : node_consumers) { + if (n->Index() == k_matmul_ptr->Index() || n->Index() == v_matmul_ptr->Index() || n->Index() == q_matmul_ptr->Index()) { + continue; + } + new_consumer_nodes.emplace_back(const_cast(n)); + } + + bool shared_same_input = k_matmul_ptr->MutableInputDefs()[0]->Name().compare(q_matmul_ptr->MutableInputDefs()[0]->Name()) == 0; + + // K and V share one MegatronF node; Q reuses it only when it has the same input, otherwise Q gets its own MegatronF node. 
+ { + const std::vector sa_f_input_defs{prev_input_node_ptr}; + auto sa_f_type_info = *prev_input_node_ptr->TypeAsProto(); + auto& sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(k_matmul_ptr->Name() + "BARTAttention_MegatronF_Output"), &sa_f_type_info); + Node& sa_f_node = graph.AddNode(graph.GenerateNodeName(k_matmul_ptr->Name() + "BARTAttention_MegatronF"), + "MegatronF", + k_matmul_ptr->Name() + " BARTAttention MegatronF", + sa_f_input_defs, + {&sa_f_out_arg}, {}, kMSDomain); + sa_f_node.SetExecutionProviderType(k_matmul_ptr->GetExecutionProviderType()); + graph_utils::ReplaceNodeInput(*k_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0])); + graph_utils::ReplaceNodeInput(*v_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0])); + if (shared_same_input) { + graph_utils::ReplaceNodeInput(*q_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0])); + } + new_consumer_nodes.push_back(&sa_f_node); + } + graph.UpdateConsumerNodes(prev_input_node_ptr->Name(), new_consumer_nodes); + counter++; + if (!shared_same_input) { + { + NodeArg* q_prev_input_node_ptr = q_matmul_ptr->MutableInputDefs()[0]; + std::vector q_new_consumer_nodes; + const auto& q_node_consumers = graph.GetConsumerNodes(q_prev_input_node_ptr->Name()); + for (auto& n : q_node_consumers) { + if (n->Index() == k_matmul_ptr->Index() || n->Index() == v_matmul_ptr->Index() || n->Index() == q_matmul_ptr->Index()) { + continue; + } + q_new_consumer_nodes.emplace_back(const_cast(n)); + } + + const std::vector q_sa_f_input_defs{q_matmul_ptr->MutableInputDefs()[0]}; + auto q_sa_f_type_info = *q_matmul_ptr->MutableInputDefs()[0]->TypeAsProto(); + auto& q_sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(q_matmul_ptr->Name() + "BARTAttention_MegatronF_Output"), &q_sa_f_type_info); + Node& q_sa_f_node = graph.AddNode(graph.GenerateNodeName(q_matmul_ptr->Name() + "BARTAttention_MegatronF"), + "MegatronF", + q_matmul_ptr->Name() + " BARTAttention MegatronF", + q_sa_f_input_defs, + 
{&q_sa_f_out_arg}, {}, kMSDomain); + q_sa_f_node.SetExecutionProviderType(q_matmul_ptr->GetExecutionProviderType()); + + graph_utils::ReplaceNodeInput(*q_matmul_ptr, 0, *(q_sa_f_node.MutableOutputDefs()[0])); + q_new_consumer_nodes.push_back(&q_sa_f_node); + graph.UpdateConsumerNodes(q_prev_input_node_ptr->Name(), q_new_consumer_nodes); + // todo: need update the consumer node for the input_node as well. + } + } + + const std::vector sa_g_input_defs{dense_matmul_node.MutableOutputDefs()[0]}; + auto sa_g_type_info = *dense_matmul_node.MutableOutputDefs()[0]->TypeAsProto(); // copy + auto& sa_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("BARTAttention_MegatronG_Output"), &sa_g_type_info); + Node& sa_g_node = graph.AddNode(graph.GenerateNodeName(k_matmul_ptr->Name() + "BARTAttention_MegatronG"), + "MegatronG", + "BARTAttention MegatronG", + sa_g_input_defs, + {&sa_g_out_arg}, {}, kMSDomain); + sa_g_node.AddAttribute("group_type", static_cast(training::WorkerGroupType::HorizontalParallel)); + sa_g_node.SetExecutionProviderType(k_matmul_ptr->GetExecutionProviderType()); + graph_utils::ReplaceDownstreamNodeInput(graph, dense_matmul_node, 0, sa_g_node, 0); + + modified = true; } return Status::OK(); } Status MegatronTransformer::TransformDropout(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger, - std::unordered_set& self_attention_dropout_nodes) const { + std::unordered_set& dropout_nodes_to_transform, int32_t& counter) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); for (auto node_index : node_topology_list) { @@ -610,18 +1141,17 @@ Status MegatronTransformer::TransformDropout(Graph& graph, bool& modified, int g } // Only need to set the seed if it's a transformed self-attention dropout, or the seed attribute is not set. 
- if (self_attention_dropout_nodes.find(&node) != self_attention_dropout_nodes.end() || - graph_utils::GetNodeAttribute(node, "seed") == nullptr) { + if (dropout_nodes_to_transform.find(&node) != dropout_nodes_to_transform.end()) { int64_t seed = static_cast(HashName(node.MutableOutputDefs()[0]->Name())) + utils::GetRandomSeed(); - if (self_attention_dropout_nodes.find(&node) != self_attention_dropout_nodes.end()) { + if (dropout_nodes_to_transform.find(&node) != dropout_nodes_to_transform.end()) { seed += horizontal_parallel_rank_; } if (graph_utils::GetNodeAttribute(node, "seed") != nullptr) { node.ClearAttribute("seed"); } - node.AddAttribute("seed", seed); + counter++; modified = true; } } @@ -635,11 +1165,19 @@ Status MegatronTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_le } std::vector nodes_to_clear_shape; - std::unordered_set self_attention_dropout_nodes; + std::unordered_set dropout_nodes_to_transform; - ORT_ENFORCE(TransformMLP(graph, modified, graph_level, logger, nodes_to_clear_shape).IsOK()); - ORT_ENFORCE(TransformSelfAttention(graph, modified, graph_level, logger, nodes_to_clear_shape, self_attention_dropout_nodes).IsOK()); - ORT_ENFORCE(TransformDropout(graph, modified, graph_level, logger, self_attention_dropout_nodes).IsOK()); + int32_t partitioned_mlp_count = 0; + int32_t partitioned_bart_mlp_count = 0; + int32_t partitioned_attention_count = 0; + int32_t partitioned_bart_attention_count = 0; + int32_t dropout_changed = 0; + + ORT_ENFORCE(TransformMLP(graph, modified, graph_level, logger, nodes_to_clear_shape, partitioned_mlp_count).IsOK()); + ORT_ENFORCE(TransformBARTMLP(graph, modified, graph_level, logger, nodes_to_clear_shape, dropout_nodes_to_transform, partitioned_bart_mlp_count).IsOK()); + ORT_ENFORCE(TransformSelfAttention(graph, modified, graph_level, logger, nodes_to_clear_shape, dropout_nodes_to_transform, partitioned_attention_count).IsOK()); + ORT_ENFORCE(TransformBARTSelfAttention(graph, modified, graph_level, 
logger, nodes_to_clear_shape, dropout_nodes_to_transform, partitioned_bart_attention_count).IsOK()); + ORT_ENFORCE(TransformDropout(graph, modified, graph_level, logger, dropout_nodes_to_transform, dropout_changed).IsOK()); auto& graph_inputs = graph.GetInputs(); for (auto& node : nodes_to_clear_shape) { @@ -653,13 +1191,28 @@ Status MegatronTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_le output->ClearShape(); } + for (auto x : updated_weight_names_) { + auto old_initializer_name = x.first; + auto new_initializer_name = x.second; + if (weights_to_train_.find(old_initializer_name) != weights_to_train_.end()) { + weights_to_train_.erase(old_initializer_name); + weights_to_train_.insert(new_initializer_name); + } + } + if (modified) { graph.SetGraphResolveNeeded(); auto ret = graph.Resolve(); + LOGS_DEFAULT(WARNING) << "Megatron transformer result : Partitioned " << partitioned_mlp_count << " MLP Blocks, " + << partitioned_bart_mlp_count << " BART MLP Blocks, " << partitioned_attention_count << " Attention Blocks, " + << partitioned_bart_attention_count << " BART Attention Blocks; Reset seed for " << dropout_changed + << " Dropout nodes. 
Error Message: " << ret.ErrorMessage() << std::endl; ORT_ENFORCE(ret.IsOK()); + } else { + LOGS_DEFAULT(WARNING) << "Megatron transformer result : unmodified\n"; } return Status::OK(); } -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/core/optimizer/megatron_transformer.h b/orttraining/orttraining/core/optimizer/megatron_transformer.h index 168d70748a..02690d00cd 100644 --- a/orttraining/orttraining/core/optimizer/megatron_transformer.h +++ b/orttraining/orttraining/core/optimizer/megatron_transformer.h @@ -11,33 +11,50 @@ namespace onnxruntime { class MegatronTransformer : public GraphTransformer { public: MegatronTransformer(int32_t horizontal_parallel_rank, int32_t horizontal_parallel_size, + std::unordered_map& updated_weight_names, + std::unordered_set& weights_to_train, const std::unordered_set& compatible_execution_providers = {}) noexcept : GraphTransformer("MegatronTransformer", compatible_execution_providers), horizontal_parallel_rank_(horizontal_parallel_rank), - horizontal_parallel_size_(horizontal_parallel_size) {} + horizontal_parallel_size_(horizontal_parallel_size), + updated_weight_names_(updated_weight_names), + weights_to_train_(weights_to_train) {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; private: Status TransformMLP(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger, - std::vector& nodes_to_clear_shape) const; + std::vector& nodes_to_clear_shape, + int32_t& counter) const; Status TransformSelfAttention(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger, std::vector& nodes_to_clear_shape, - std::unordered_set& self_attention_dropout_nodes) const; + std::unordered_set& dropout_nodes_to_transform, + int32_t& counter) const; + + Status TransformBARTSelfAttention(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger, + std::vector& 
nodes_to_clear_shape, + std::unordered_set& dropout_nodes_to_transform, int32_t& counter) const; + + Status TransformBARTMLP(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger, + std::vector& nodes_to_clear_shape, + std::unordered_set& dropout_nodes_to_transform, int32_t& counter) const; Status TransformDropout(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger, - std::unordered_set& self_attention_dropout_nodes) const; + std::unordered_set& dropout_nodes_to_transform, int32_t& counter) const; bool PartitionWeightByColumn(const Graph& graph, const NodeArg& input_arg, - ONNX_NAMESPACE::TensorProto& initializer_partition, int stride = 1) const; + ONNX_NAMESPACE::TensorProto& initializer_partition, + int stride = 1) const; - bool PartitionWeightByRow(const Graph& graph, const NodeArg& input_arg, - ONNX_NAMESPACE::TensorProto& initializer_partition) const; + bool PartitionWeightByRow(const Graph& graph, const NodeArg& input_arg, ONNX_NAMESPACE::TensorProto& initializer_partition) const; const int32_t horizontal_parallel_rank_; const int32_t horizontal_parallel_size_; + std::unordered_map& updated_weight_names_; + std::unordered_set& weights_to_train_; }; -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index b49db28db2..9b52a55749 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -43,24 +43,36 @@ Status SetupOptimizerParams( const optional& loss_scale_input_name, const TrainingSession::TrainingConfiguration& config, OptimizerGraphConfig& opt_graph_config_result, - std::unordered_map& opt_node_configs_result) { + std::unordered_map& opt_node_configs_result, + std::unordered_map& weight_name_map_after_graph_transform) { 
ORT_RETURN_IF_NOT(config.optimizer_config.has_value()); const auto& optimizer_config = config.optimizer_config.value(); + // This is the mapping from the new weight name to the original weight name + // It is required to look up the optimizer config for the original weight + // passed in the training session config + std::unordered_map reversed_weight_names_map; + for (auto& p : weight_name_map_after_graph_transform) { + reversed_weight_names_map.insert({p.second, p.first}); + } + std::unordered_map opt_node_configs{}; for (const auto& weight_name : weight_names_to_train) { OptimizerNodeConfig opt_node_config{}; opt_node_config.name = optimizer_config.name; opt_node_config.lr_feed_name = optimizer_config.learning_rate_input_name; - + std::string original_weight_name = weight_name; + if (reversed_weight_names_map.find(original_weight_name) != reversed_weight_names_map.end()) { + original_weight_name = reversed_weight_names_map.at(original_weight_name); + } try { - opt_node_config.attributes = optimizer_config.weight_attributes_generator(weight_name); + opt_node_config.attributes = optimizer_config.weight_attributes_generator(original_weight_name); } catch (const std::exception& ex) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what()); } try { - opt_node_config.int_attributes = optimizer_config.weight_int_attributes_generator(weight_name); + opt_node_config.int_attributes = optimizer_config.weight_int_attributes_generator(original_weight_name); } catch (const std::exception& ex) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what()); } @@ -232,7 +244,8 @@ Status TrainingSession::ConfigureForTraining( } } - ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.graph_transformer_config)); + ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.graph_transformer_config, + config_result)); if (IsRootNode(config) && config.model_with_loss_function_path.has_value()) { ORT_IGNORE_RETURN_VALUE(Save( @@ -244,8 
+257,14 @@ Status TrainingSession::ConfigureForTraining( !filtered_config_weight_names_to_train.empty() ? filtered_config_weight_names_to_train : GetTrainableModelInitializers(config.immutable_weights, loss_name); + for (const auto& weight_name_to_not_train : config.weight_names_to_not_train) { - weight_names_to_train.erase(weight_name_to_not_train); + if (config_result.weight_name_map_after_graph_transform.find(weight_name_to_not_train) != + config_result.weight_name_map_after_graph_transform.end()) { + weight_names_to_train.erase(config_result.weight_name_map_after_graph_transform.at(weight_name_to_not_train)); + } else { + weight_names_to_train.erase(weight_name_to_not_train); + } } { @@ -316,7 +335,7 @@ Status TrainingSession::ConfigureForTraining( std::unordered_map opt_node_configs{}; ORT_RETURN_IF_ERROR(SetupOptimizerParams( weights_to_train_, fp32_weight_name_to_mixed_precision_node_arg, - loss_scale_input_name, config, opt_graph_config, opt_node_configs)); + loss_scale_input_name, config, opt_graph_config, opt_node_configs, config_result.weight_name_map_after_graph_transform)); TrainingConfigurationResult::OptimizerConfigurationResult optimizer_config_result{}; ORT_RETURN_IF_ERROR(BuildOptimizer( opt_graph_config, opt_node_configs, @@ -512,8 +531,9 @@ static Status AddGradientAccumulationNodes(Graph& graph, return GraphAugmenter::AugmentGraph(graph, graph_defs); } -Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set& weights_to_train, - const TrainingConfiguration::GraphTransformerConfiguration& config) { +Status TrainingSession::ApplyTransformationsToMainGraph(std::unordered_set& weights_to_train, + const TrainingConfiguration::GraphTransformerConfiguration& config, + TrainingConfigurationResult& config_result_out) { GraphTransformerManager graph_transformation_mgr{2}; // TODO: ideally we can just reuse the CPU EP registered with the session, but in the training session case // the EPs are registered after ConfigureForTraining 
and before Initialize is called. Hence we don't have access @@ -522,7 +542,7 @@ Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set // Create execution frame for executing constant nodes. std::unique_ptr cpu_execution_provider = onnxruntime::make_unique(CPUExecutionProviderInfo()); - AddPreTrainingTransformers(*cpu_execution_provider, graph_transformation_mgr, weights_to_train, config); + AddPreTrainingTransformers(*cpu_execution_provider, graph_transformation_mgr, weights_to_train, config, config_result_out); // apply transformers Graph& graph = model_->MainGraph(); @@ -536,15 +556,16 @@ Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set // Registers all the pre transformers with transformer manager void TrainingSession::AddPreTrainingTransformers(const IExecutionProvider& execution_provider, GraphTransformerManager& transformer_manager, - const std::unordered_set& weights_to_train, + std::unordered_set& weights_to_train, const TrainingConfiguration::GraphTransformerConfiguration& config, + TrainingConfigurationResult& config_result_out, TransformerLevel graph_optimization_level, const std::vector& custom_list) { auto add_transformers = [&](TransformerLevel level) { // Generate and register transformers for level auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers( - level, weights_to_train, config, execution_provider, custom_list); + level, weights_to_train, config, execution_provider, config_result_out.weight_name_map_after_graph_transform, custom_list); for (auto& entry : transformers_to_register) { transformer_manager.Register(std::move(entry), level); } diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index 58eb1d526c..1b687b8820 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -245,6 +245,9 @@ class 
TrainingSession : public InferenceSession { // The pipeline configuration output. // This is only set if an pipeline is enabled. optional pipeline_config_result; + + // Mapped initialized names after weight partitioning for example MegatronTransformer + std::unordered_map weight_name_map_after_graph_transform{}; }; /** @@ -392,14 +395,16 @@ class TrainingSession : public InferenceSession { common::Status InsertPipelineOps(const std::unordered_set& initializer_names_to_preserve, pipeline::PipelineTensorNames& pipeline_tensor_names); - common::Status ApplyTransformationsToMainGraph(const std::unordered_set& weights_to_train, - const TrainingConfiguration::GraphTransformerConfiguration& config); + common::Status ApplyTransformationsToMainGraph(std::unordered_set& weights_to_train, + const TrainingConfiguration::GraphTransformerConfiguration& config, + TrainingConfigurationResult& config_result_out); /** configure initial transformers for training */ void AddPreTrainingTransformers(const IExecutionProvider& execution_provider, // for constant folding GraphTransformerManager& transformer_manager, - const std::unordered_set& weights_to_train, + std::unordered_set& weights_to_train, const TrainingConfiguration::GraphTransformerConfiguration& config, + TrainingConfigurationResult& config_result_out, TransformerLevel graph_optimization_level = TransformerLevel::MaxLevel, const std::vector& custom_list = {}); diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index fd9f0be0e4..2ae57b30b9 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -43,6 +43,7 @@ struct TrainingParameters { int gradient_accumulation_steps = 1; int data_parallel_size = 1; int horizontal_parallel_size = 1; + int pipeline_parallel_size = 1; int deepspeed_zero_stage = 0; bool enable_grad_norm_clip = true; bool 
set_gradients_as_graph_outputs = false; @@ -65,11 +66,16 @@ TrainingConfigurationResult ConfigureSessionForTraining( //TODO tix, refactor the mpi related code to populate all fields correctly by default. ORT_ENFORCE(parameters.horizontal_parallel_size <= parameters.world_size); ORT_ENFORCE(parameters.data_parallel_size <= parameters.world_size); + + if (parameters.world_size % parameters.data_parallel_size != 0) { + throw std::runtime_error("Cannot split data parallel group because world_size is not divisible"); + } + if (parameters.world_size % parameters.horizontal_parallel_size != 0) { throw std::runtime_error("Cannot split horizontal parallel group because world_size is not divisible"); } - auto data_group_size = parameters.world_size / parameters.horizontal_parallel_size; + auto data_group_size = parameters.world_size / (parameters.horizontal_parallel_size * parameters.pipeline_parallel_size); if (data_group_size != parameters.data_parallel_size) { LOGS(*(sess->GetLogger()), WARNING) << "data_parallel_size is not correct, tuned automatically to " << data_group_size; @@ -201,7 +207,10 @@ void addObjectMethodsForTraining(py::module& m) { .def_readwrite("attn_dropout_recompute", &TrainingParameters::attn_dropout_recompute) .def_readwrite("gelu_recompute", &TrainingParameters::gelu_recompute) .def_readwrite("transformer_layer_recompute", &TrainingParameters::transformer_layer_recompute) - .def_readwrite("number_recompute_layers", &TrainingParameters::number_recompute_layers); + .def_readwrite("number_recompute_layers", &TrainingParameters::number_recompute_layers) + .def_readwrite("data_parallel_size", &TrainingParameters::data_parallel_size) + .def_readwrite("horizontal_parallel_size", &TrainingParameters::horizontal_parallel_size) + .def_readwrite("pipeline_parallel_size", &TrainingParameters::pipeline_parallel_size); #if defined(USE_MPI) m.def("get_mpi_context_local_rank", []() -> int { return MPIContext::GetInstance().GetLocalRank(); }); diff --git 
a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 6be077c7f3..340a4c07bc 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -163,7 +163,9 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionRank0) { ASSERT_TRUE(ret.IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(onnxruntime::make_unique(0, 2), TransformerLevel::Level1); + std::unordered_map updated_weight_names; + std::unordered_set weights_to_train; + graph_transformation_mgr.Register(onnxruntime::make_unique(0, 2, updated_weight_names, weights_to_train), TransformerLevel::Level1); ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); @@ -231,7 +233,9 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionRank1) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(onnxruntime::make_unique(1, 2), TransformerLevel::Level1); + std::unordered_map updated_weight_names; + std::unordered_set weights_to_train; + graph_transformation_mgr.Register(onnxruntime::make_unique(1, 2, updated_weight_names, weights_to_train), TransformerLevel::Level1); ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); @@ -298,7 +302,9 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank0) { ASSERT_TRUE(ret.IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(onnxruntime::make_unique(0, 2), TransformerLevel::Level1); + std::unordered_map updated_weight_names; + std::unordered_set weights_to_train; + 
graph_transformation_mgr.Register(onnxruntime::make_unique(0, 2, updated_weight_names, weights_to_train), TransformerLevel::Level1); ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); @@ -363,7 +369,9 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(onnxruntime::make_unique(1, 2), TransformerLevel::Level1); + std::unordered_map updated_weight_names; + std::unordered_set weights_to_train; + graph_transformation_mgr.Register(onnxruntime::make_unique(1, 2, updated_weight_names, weights_to_train), TransformerLevel::Level1); ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); @@ -449,27 +457,35 @@ TEST_F(GraphTransformationTests, BiasGeluRecomputeTest) { // We only tested on CUDA run. #if defined(USE_CUDA) -TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { - auto model_uri = MODEL_FOLDER "model_parallel/mlp_megatron_basic_test.onnx"; - const int total_rank = 4; +static void RunPartitionCorrectnessTest(std::string model_path, +const logging::Logger& logger, +const int total_rank, +std::vector input_names, +std::vector> input_dims) { + const PathString model_uri = ToPathString(model_path) + ORT_TSTR(".onnx"); + // const int total_rank = 4; std::vector graphs; std::vector> p_models(total_rank); for (auto i = 0; i < total_rank; i++) { - auto ret = Model::Load(model_uri, p_models[i], nullptr, *logger_); - ASSERT_TRUE(ret.IsOK()); + Status ret = Model::Load(model_uri, p_models[i], nullptr, logger); + ORT_ENFORCE(ret.IsOK()); Graph& graph = p_models[i]->MainGraph(); - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(onnxruntime::make_unique(i, total_rank), TransformerLevel::Level1); - ret = 
graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); - ASSERT_TRUE(ret.IsOK()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + std::unordered_map updated_weight_names; + std::unordered_set weights_to_train; + graph_transformation_mgr.Register(onnxruntime::make_unique(i, total_rank, updated_weight_names, weights_to_train), TransformerLevel::Level1); + ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, logger); + ORT_ENFORCE(ret.IsOK()); graphs.push_back(&graph); + auto model_uri2 = ToPathString(model_path) + ORT_TSTR("_partition_rank_") + ToPathString(std::to_string(i)) + ORT_TSTR(".onnx"); + Model::Save(*p_models[i], model_uri2); } - onnxruntime::Model combine_model("combine_graph", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}, {kMSDomain, 1}}, {}, *logger_); + onnxruntime::Model combine_model("combine_graph", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}, {kMSDomain, 1}}, {}, logger); auto& combine_graph = combine_model.MainGraph(); auto ret = horizontal_parallel_test_utils::MergeGraphsOnAllWorkers(graphs, combine_graph); ORT_ENFORCE(ret.IsOK()); - auto model_uri2 = "mlp_megatron_basic_test_partition_combine.onnx"; + auto model_uri2 = ToPathString(model_path) + ORT_TSTR("_partition_combine.onnx"); Model::Save(combine_model, model_uri2); float scale = 1.f; @@ -479,10 +495,18 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { std::default_random_engine generator{gsl::narrow_cast(seed)}; std::normal_distribution distribution{mean, scale}; - std::vector dims_X = {8, 16, 4}; - std::vector values_X(TensorShape(dims_X).Size()); - std::for_each(values_X.begin(), values_X.end(), - [&generator, &distribution](float& value) { value = distribution(generator); }); + ORT_ENFORCE(input_names.size() == input_dims.size()); + NameMLValMap feeds; + for(size_t i = 0; i< 
input_dims.size();i++){ + std::vector dims_X = input_dims[i]; + std::vector values_X(TensorShape(dims_X).Size()); + std::for_each(values_X.begin(), values_X.end(), + [&generator, &distribution](float& value) { value = distribution(generator); }); + + OrtValue ml_value; + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value); + feeds.insert(std::make_pair(input_names[i], ml_value)); + } std::vector expected_ort_values; { @@ -497,11 +521,6 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { ASSERT_TRUE((st = session_object.Load(model_uri)).IsOK()) << st; ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st; - OrtValue ml_value; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value); - NameMLValMap feeds; - feeds.insert(std::make_pair("input", ml_value)); - // prepare outputs std::vector output_names; output_names.push_back("output"); @@ -527,11 +546,6 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { ASSERT_TRUE((st = session_object.Load(model_uri2)).IsOK()) << st; ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st; - OrtValue ml_value; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value); - NameMLValMap feeds; - feeds.insert(std::make_pair("input", ml_value)); - // prepare outputs std::vector output_names; for (auto i = 0; i < total_rank; i++) { @@ -554,137 +568,21 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { } } +TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { + RunPartitionCorrectnessTest("testdata/transform/model_parallel/mlp_megatron_basic_test", *logger_, 4, {"input"}, {{8, 16, 4}}); +} + +TEST_F(GraphTransformationTests, MegatronBARTMLPPartitionCorrectnessTest) { + RunPartitionCorrectnessTest("testdata/transform/model_parallel/bart_mlp_megatron_basic_test", *logger_, 4, {"input"}, 
{{8, 16, 4}}); +} + TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest) { - auto model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx"; - const int total_rank = 2; // The test graph is too small to partition to 4, so use 2 instead here. - std::vector graphs; - std::vector> p_models(total_rank); - for (auto i = 0; i < total_rank; i++) { - auto ret = Model::Load(model_uri, p_models[i], nullptr, *logger_); - ASSERT_TRUE(ret.IsOK()); - Graph& graph = p_models[i]->MainGraph(); - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(onnxruntime::make_unique(i, total_rank), TransformerLevel::Level1); - ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); - ASSERT_TRUE(ret.IsOK()); - graphs.push_back(&graph); - } - - // Dropout seed checking. - const AttributeProto* attr = graph_utils::GetNodeAttribute(*GetNodeByName(*graphs[0], "dropout1"), "seed"); - ORT_ENFORCE(attr != nullptr && attr->has_i()); - int64_t dropout1_rank0_seed = attr->i(); - attr = graph_utils::GetNodeAttribute(*GetNodeByName(*graphs[0], "dropout2"), "seed"); - ORT_ENFORCE(attr != nullptr && attr->has_i()); - int64_t dropout2_rank0_seed = attr->i(); - for (auto i = 1; i < total_rank; i++) { - attr = graph_utils::GetNodeAttribute(*GetNodeByName(*graphs[i], "dropout1"), "seed"); - ORT_ENFORCE(attr != nullptr && attr->has_i() && attr->i() == dropout1_rank0_seed + i); - attr = graph_utils::GetNodeAttribute(*GetNodeByName(*graphs[i], "dropout2"), "seed"); - ORT_ENFORCE(attr != nullptr && attr->has_i() && attr->i() == dropout2_rank0_seed); - } - - onnxruntime::Model combine_model("combine_graph", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}, {kMSDomain, 1}}, {}, *logger_); - auto& combine_graph = combine_model.MainGraph(); - auto ret = horizontal_parallel_test_utils::MergeGraphsOnAllWorkers(graphs, combine_graph); - 
ORT_ENFORCE(ret.IsOK()); - auto model_uri2 = "self_attention_megatron_basic_test_partition_combine.onnx"; - Model::Save(combine_model, model_uri2); - - float scale = 1.f; - float mean = 0.f; - float seed = 123.f; - - std::default_random_engine generator{gsl::narrow_cast(seed)}; - std::normal_distribution distribution{mean, scale}; - - std::vector dims_X = {8, 16, 4}; - std::vector values_X(TensorShape(dims_X).Size()); - std::for_each(values_X.begin(), values_X.end(), - [&generator, &distribution](float& value) { value = distribution(generator); }); - - std::vector dims_Mask = {8, 1, 16, 16}; - std::vector values_Mask(TensorShape(dims_Mask).Size()); - std::for_each(values_Mask.begin(), values_Mask.end(), - [&generator, &distribution](float& value) { value = distribution(generator); }); - - std::vector expected_ort_values; - { - SessionOptions so; - so.session_logid = "RawGraphRun"; - - InferenceSession session_object{so, GetEnvironment()}; - std::unique_ptr execution_provider = DefaultCudaExecutionProvider(); - EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); - - Status st; - ASSERT_TRUE((st = session_object.Load(model_uri)).IsOK()) << st; - ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st; - - NameMLValMap feeds; - - OrtValue ml_value; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value); - feeds.insert(std::make_pair("input", ml_value)); - - OrtValue mask_value; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_Mask, values_Mask, &mask_value); - feeds.insert(std::make_pair("mask", mask_value)); - - // prepare outputs - std::vector output_names; - output_names.push_back("output"); - - // Now run - RunOptions run_options; - run_options.training_mode = true; - st = session_object.Run(run_options, feeds, output_names, &expected_ort_values); - EXPECT_TRUE(st.IsOK()); - } - - std::vector actual_ort_values; - { - 
SessionOptions so; - so.session_logid = "SplitThenCombineRun"; - - InferenceSession session_object{so, GetEnvironment()}; - std::unique_ptr execution_provider = DefaultCudaExecutionProvider(); - EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); - - Status st; - ASSERT_TRUE((st = session_object.Load(model_uri2)).IsOK()) << st; - ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st; - - NameMLValMap feeds; - OrtValue ml_value; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value); - feeds.insert(std::make_pair("input", ml_value)); - - OrtValue mask_value; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_Mask, values_Mask, &mask_value); - feeds.insert(std::make_pair("mask", mask_value)); - - // prepare outputs - std::vector output_names; - for (auto i = 0; i < total_rank; i++) { - output_names.push_back("output_rank_" + std::to_string(i)); - } - - // Now run - RunOptions run_options; - run_options.training_mode = true; - st = session_object.Run(run_options, feeds, output_names, &actual_ort_values); - EXPECT_TRUE(st.IsOK()); - } - - auto& expected_val = expected_ort_values[0].Get(); - for (auto i = 0; i < total_rank; i++) { - auto& actual_val = actual_ort_values[i].Get(); - horizontal_parallel_test_utils::VerifyOutputs(expected_val, actual_val, true); - horizontal_parallel_test_utils::VerifyOutputs(expected_val, actual_val, false); - } + RunPartitionCorrectnessTest("testdata/transform/model_parallel/self_attention_megatron_basic_test", *logger_, 2, {"input", "mask"}, {{8, 16, 4}, {8, 1, 16, 16}}); } +TEST_F(GraphTransformationTests, MegatronBARTSelfAttentionPartitionCorrectnessTest) { + RunPartitionCorrectnessTest("testdata/transform/model_parallel/bart_self_attention_megatron_basic_test", *logger_, 2, {"input"}, {{6, 8, 4}}); +} // end of USE_CUDA #endif