From 5eaac31faa291ec994963c7accbd8f2c4ca5ca5a Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Wed, 19 Aug 2020 11:13:37 +0800 Subject: [PATCH] support opset13 on transformers. (#4837) Co-authored-by: Vincent Wang --- .../onnxruntime/core/framework/data_types.h | 4 ++ .../core/optimizer/bias_gelu_fusion.cc | 2 +- .../core/optimizer/conv_activation_fusion.cc | 8 ++-- onnxruntime/core/optimizer/conv_add_fusion.cc | 2 +- onnxruntime/core/optimizer/conv_mul_fusion.cc | 2 +- .../core/optimizer/dropout_elimination.cc | 2 +- .../dynamic_quantize_matmul_fusion.cc | 6 +-- .../core/optimizer/embed_layer_norm_fusion.cc | 10 ++--- .../core/optimizer/fast_gelu_fusion.cc | 35 ++++++++------- .../core/optimizer/gelu_approximation.cc | 4 +- onnxruntime/core/optimizer/gelu_fusion.cc | 12 ++--- .../core/optimizer/gemm_activation_fusion.cc | 8 ++-- .../core/optimizer/layer_norm_fusion.cc | 22 ++++----- .../core/optimizer/matmul_add_fusion.cc | 6 +-- .../core/optimizer/matmul_scale_fusion.cc | 8 ++-- .../core/optimizer/matmul_transpose_fusion.cc | 4 +- .../core/optimizer/nchwc_transformer.cc | 20 ++++----- .../core/optimizer/relu_clip_fusion.cc | 9 +++- onnxruntime/core/optimizer/reshape_fusion.cc | 26 +++++------ .../core/optimizer/shape_to_initializer.cc | 2 +- .../core/optimizer/skip_layer_norm_fusion.cc | 12 ++--- .../core/optimizer/slice_elimination.cc | 6 +-- .../core/optimizer/bias_dropout_fusion.cc | 6 +-- .../core/optimizer/insert_output_rewriter.cc | 2 +- .../core/optimizer/megatron_transformer.cc | 45 ++++++++++--------- 25 files changed, 137 insertions(+), 126 deletions(-) diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index d89bdc555f..87aa7012b7 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -139,6 +139,10 @@ struct BFloat16 { } return result; } + + operator float() const { + return ToFloat(); + } }; inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) { diff --git a/onnxruntime/core/optimizer/bias_gelu_fusion.cc b/onnxruntime/core/optimizer/bias_gelu_fusion.cc index e7da15aa14..cfc763c880 100644 --- a/onnxruntime/core/optimizer/bias_gelu_fusion.cc +++ b/onnxruntime/core/optimizer/bias_gelu_fusion.cc @@ -24,7 +24,7 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || !optimizer_utils::CheckOutputEdges(graph, node, 1)) { continue; diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index cd0457f32a..138deeeb36 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -103,12 +103,12 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l // Test if this is an activation that can be fused and also extract the // activation's parameters. std::vector activation_params; - if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6}) && - !graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Sigmoid", {6}) && - !graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Tanh", {6})) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13}) && + !graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Sigmoid", {6, 13}) && + !graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Tanh", {6, 13})) { if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "LeakyRelu", {6})) { activation_params.push_back(graph_utils::GetNodeAttribute(next_node, "alpha")->f()); - } else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12})) { + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12, 13})) { float min, max; if (GetClipConstantMinMax(graph, next_node, min, max)) { activation_params.push_back(min); diff --git a/onnxruntime/core/optimizer/conv_add_fusion.cc b/onnxruntime/core/optimizer/conv_add_fusion.cc index d3552c9cf7..5b2520a0dc 100644 --- a/onnxruntime/core/optimizer/conv_add_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_fusion.cc @@ -110,7 +110,7 @@ bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const } const auto& next_node = *node.OutputNodesBegin(); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13}) || next_node.GetInputEdgesCount() != 1 || // Make sure the two nodes do not span execution providers. next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { diff --git a/onnxruntime/core/optimizer/conv_mul_fusion.cc b/onnxruntime/core/optimizer/conv_mul_fusion.cc index 4cfe89d6ac..8f1b4aab95 100644 --- a/onnxruntime/core/optimizer/conv_mul_fusion.cc +++ b/onnxruntime/core/optimizer/conv_mul_fusion.cc @@ -121,7 +121,7 @@ bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const } const auto& next_node = *node.OutputNodesBegin(); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13}) || next_node.GetInputEdgesCount() != 1 || // Make sure the two nodes do not span execution providers. next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { diff --git a/onnxruntime/core/optimizer/dropout_elimination.cc b/onnxruntime/core/optimizer/dropout_elimination.cc index 5f2f097dff..327fb923c9 100644 --- a/onnxruntime/core/optimizer/dropout_elimination.cc +++ b/onnxruntime/core/optimizer/dropout_elimination.cc @@ -21,7 +21,7 @@ Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { // We currently support elimination for Dropout operator v1, v6, v7, v10 and v12. // REVIEW(mzs): v10 implementation does not exist. - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {1, 6, 7, 10, 12})) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {1, 6, 7, 10, 12, 13})) { return false; } diff --git a/onnxruntime/core/optimizer/dynamic_quantize_matmul_fusion.cc b/onnxruntime/core/optimizer/dynamic_quantize_matmul_fusion.cc index e817d099cc..404b2c13cd 100644 --- a/onnxruntime/core/optimizer/dynamic_quantize_matmul_fusion.cc +++ b/onnxruntime/core/optimizer/dynamic_quantize_matmul_fusion.cc @@ -86,19 +86,19 @@ Status DynamicQuantizeMatMulFusion::ApplyImpl(Graph& graph, bool& modified, int ORT_RETURN_IF_ERROR(Recurse(mul_node, modified, graph_level, logger)); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) || !graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders())) { continue; } // Left Parents path std::vector left_parent_path{ - {0, 0, "Cast", {6, 9}, kOnnxDomain}, + {0, 0, "Cast", {6, 9, 13}, kOnnxDomain}, {0, 0, "MatMulInteger", {10}, kOnnxDomain}, {0, 0, "DynamicQuantizeLinear", {11}, kOnnxDomain}}; std::vector right_parent_path{ - {0, 1, "Mul", {7}, kOnnxDomain}, + {0, 1, "Mul", {7, 13}, kOnnxDomain}, {1, 0, "DynamicQuantizeLinear", {11}, kOnnxDomain}}; std::vector> left_nodes; diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc index 0ab7b84b85..079caf405e 100644 --- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc @@ -501,20 +501,20 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l } // Find ReduceSum --> Attention std::vector edges; - if (!graph_utils::FindPath(attention_node, true, {{0, 3, "ReduceSum", {1, 11}, kOnnxDomain}}, edges, logger)) { + if (!graph_utils::FindPath(attention_node, true, {{0, 3, "ReduceSum", {1, 11, 13}, kOnnxDomain}}, edges, logger)) { continue; } Node& reduce_sum_node = *graph.GetNode(edges[0]->GetNode().Index()); // Find Add --> LayerNormalization - if (!graph_utils::FindPath(layer_norm_node, true, {{0, 0, "Add", {7}, kOnnxDomain}}, edges, logger)) { + if (!graph_utils::FindPath(layer_norm_node, true, {{0, 0, "Add", {7, 13}, kOnnxDomain}}, edges, logger)) { continue; } Node& layer_norm_add_node = *graph.GetNode(edges[0]->GetNode().Index()); // Trace back to find the Gather for segment embedding. std::vector segment_embedding_path{ - {0, 1, "Gather", {1, 11}, kOnnxDomain}}; + {0, 1, "Gather", {1, 11, 13}, kOnnxDomain}}; if (!graph_utils::FindPath(layer_norm_add_node, true, segment_embedding_path, edges, logger)) { continue; } @@ -534,8 +534,8 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l // Trace back to find Gather --> Add --> LayerNormalization std::vector word_embedding_path{ - {0, 0, "Add", {7}, kOnnxDomain}, - {0, 0, "Gather", {1, 11}, kOnnxDomain}}; + {0, 0, "Add", {7, 13}, kOnnxDomain}, + {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}}; if (!graph_utils::FindPath(layer_norm_add_node, true, word_embedding_path, edges, logger)) { continue; } diff --git a/onnxruntime/core/optimizer/fast_gelu_fusion.cc b/onnxruntime/core/optimizer/fast_gelu_fusion.cc index a2e8528ff1..57a4e5f86d 100644 --- a/onnxruntime/core/optimizer/fast_gelu_fusion.cc +++ b/onnxruntime/core/optimizer/fast_gelu_fusion.cc @@ -13,7 +13,7 @@ using namespace onnxruntime::common; namespace onnxruntime { // FastGelu supports limited data types. -static std::vector gpu_supported_data_types{"tensor(float16)", "tensor(float)"}; +static std::vector gpu_supported_data_types{"tensor(float16)", "tensor(float)", "tensor(bfloat16)"}; static std::vector cpu_supported_data_types{"tensor(float)"}; static bool IsSupportedDataType(const Node& node) { @@ -24,9 +24,10 @@ static bool IsSupportedDataType(const Node& node) { } } -static bool CheckNode(Graph& graph, const Node& node, const std::string& op_name, int32_t opset_version, ProviderType provider, - bool require_single_output) { - return graph_utils::IsSupportedOptypeVersionAndDomain(node, op_name, {opset_version}) && +static bool CheckNode(Graph& graph, const Node& node, const std::string& op_name, + const std::initializer_list& opset_versions, + ProviderType provider, bool require_single_output) { + return graph_utils::IsSupportedOptypeVersionAndDomain(node, op_name, opset_versions) && node.GetExecutionProviderType() == provider && IsSupportedDataType(node) && (!require_single_output || node.GetOutputEdgesCount() == 1) && @@ -36,7 +37,7 @@ static bool CheckNode(Graph& graph, const Node& node, const std::string& op_name MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node, std::vector>& nodes_to_fuse) const { MatchResult matchResult{false, nullptr, nullptr}; - if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul1_node, "Mul", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul1_node, "Mul", {7, 13}) || !graph_utils::IsSupportedProvider(mul1_node, GetCompatibleExecutionProviders()) || mul1_node.GetOutputEdgesCount() != 1 || !IsSupportedDataType(mul1_node)) { @@ -59,7 +60,7 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node, Node& mul2_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index()); input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *mul1_node.MutableOutputDefs()[0]); - if (!CheckNode(graph, mul2_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true) || + if (!CheckNode(graph, mul2_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true) || mul2_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != gelu_without_bias_input_arg->Name()) { return matchResult; } @@ -67,14 +68,14 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node, Node& add1_node = *graph.GetNode(mul2_node.OutputNodesBegin()->Index()); input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul2_node.MutableOutputDefs()[0]); - if (!CheckNode(graph, add1_node, "Add", 7, mul1_node.GetExecutionProviderType(), true) || + if (!CheckNode(graph, add1_node, "Add", {7, 13}, mul1_node.GetExecutionProviderType(), true) || !optimizer_utils::IsInitializerWithExpectedValue(graph, *(add1_node.InputDefs()[(input_index + 1) % 2]), 1.0f, true)) { return matchResult; } nodes_to_fuse.push_back(add1_node); Node& mul3_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index()); - if (!CheckNode(graph, mul3_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true)) { + if (!CheckNode(graph, mul3_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true)) { return matchResult; } nodes_to_fuse.push_back(mul3_node); @@ -83,7 +84,7 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node, const Node* p_mul3_input_node = graph_utils::GetInputNode(mul3_node, (input_index + 1) % 2); if (p_mul3_input_node == nullptr) return matchResult; Node& mul4_node = const_cast(*p_mul3_input_node); - if (!CheckNode(graph, mul4_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true)) { + if (!CheckNode(graph, mul4_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true)) { return matchResult; } @@ -109,7 +110,7 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node, MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node, std::vector>& nodes_to_fuse) const { MatchResult matchResult{false, nullptr, nullptr}; - if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow1_node, "Pow", {7, 12}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow1_node, "Pow", {7, 12, 13}) || !graph_utils::IsSupportedProvider(pow1_node, GetCompatibleExecutionProviders()) || pow1_node.GetOutputEdgesCount() != 1 || !IsSupportedDataType(pow1_node)) { @@ -125,7 +126,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node, Node& mul1_node = *graph.GetNode(pow1_node.OutputNodesBegin()->Index()); auto input_index = optimizer_utils::IndexOfNodeInput(mul1_node, *pow1_node.MutableOutputDefs()[0]); - if (!CheckNode(graph, mul1_node, "Mul", 7, pow1_node.GetExecutionProviderType(), true) || + if (!CheckNode(graph, mul1_node, "Mul", {7, 13}, pow1_node.GetExecutionProviderType(), true) || !optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul1_node.InputDefs()[(input_index + 1) % 2]), 0.044714998453855515f, true)) { return matchResult; @@ -134,7 +135,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node, Node& add1_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index()); input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul1_node.MutableOutputDefs()[0]); - if (!CheckNode(graph, add1_node, "Add", 7, pow1_node.GetExecutionProviderType(), true) || + if (!CheckNode(graph, add1_node, "Add", {7, 13}, pow1_node.GetExecutionProviderType(), true) || add1_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != pow_input_arg->Name()) { return matchResult; } @@ -142,7 +143,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node, Node& mul2_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index()); input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *add1_node.MutableOutputDefs()[0]); - if (!CheckNode(graph, mul2_node, "Mul", 7, pow1_node.GetExecutionProviderType(), true) || + if (!CheckNode(graph, mul2_node, "Mul", {7, 13}, pow1_node.GetExecutionProviderType(), true) || !optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul2_node.InputDefs()[(input_index + 1) % 2]), 0.7978845834732056f, true)) { return matchResult; @@ -177,12 +178,12 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, }; Node& tanh_node = *graph.GetNode(matchRet.tanh_input_node->OutputNodesBegin()->Index()); - if (!CheckNode(graph, tanh_node, "Tanh", 6, node.GetExecutionProviderType(), true)) { + if (!CheckNode(graph, tanh_node, "Tanh", {6, 13}, node.GetExecutionProviderType(), true)) { continue; } Node& add2_node = *graph.GetNode(tanh_node.OutputNodesBegin()->Index()); - if (!CheckNode(graph, add2_node, "Add", 7, node.GetExecutionProviderType(), true)) { + if (!CheckNode(graph, add2_node, "Add", {7, 13}, node.GetExecutionProviderType(), true)) { continue; } @@ -193,7 +194,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, Node& mul5_node = *graph.GetNode(add2_node.OutputNodesBegin()->Index()); // This is the output of the Gelu subgraph, we don't need check it has single edge. - if (!CheckNode(graph, mul5_node, "Mul", 7, node.GetExecutionProviderType(), false)) { + if (!CheckNode(graph, mul5_node, "Mul", {7, 13}, node.GetExecutionProviderType(), false)) { continue; } @@ -201,7 +202,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const Node* p_mul5_input_node = graph_utils::GetInputNode(mul5_node, (input_index + 1) % 2); if (p_mul5_input_node == nullptr) continue; Node& mul6_node = const_cast(*p_mul5_input_node); - if (!CheckNode(graph, mul6_node, "Mul", 7, node.GetExecutionProviderType(), false)) { + if (!CheckNode(graph, mul6_node, "Mul", {7, 13}, node.GetExecutionProviderType(), false)) { continue; } diff --git a/onnxruntime/core/optimizer/gelu_approximation.cc b/onnxruntime/core/optimizer/gelu_approximation.cc index 494666dd9e..8bd2058349 100644 --- a/onnxruntime/core/optimizer/gelu_approximation.cc +++ b/onnxruntime/core/optimizer/gelu_approximation.cc @@ -13,7 +13,7 @@ using namespace onnxruntime::common; namespace onnxruntime { // FastGelu supports limited data types. -static std::vector supported_data_types{"tensor(float16)", "tensor(float)"}; +static std::vector supported_data_types{"tensor(float16)", "tensor(float)", "tensor(bfloat16)"}; static bool IsSupportedDataType(const Node& node) { for (const auto& input_arg : node.InputDefs()) { @@ -51,7 +51,7 @@ static bool CheckInputShape(const Node& node, const NodeArg& input, const NodeAr // it means that the shape of MatMul output is good for FastGelu. const Node* parent_node = graph_utils::GetInputNode(node, 0); if (nullptr != parent_node && - graph_utils::IsSupportedOptypeVersionAndDomain(*parent_node, "MatMul", {1, 9}, kOnnxDomain)) { + graph_utils::IsSupportedOptypeVersionAndDomain(*parent_node, "MatMul", {1, 9, 13}, kOnnxDomain)) { const NodeArg& input_b = *(parent_node->InputDefs()[1]); if (optimizer_utils::ValidateShape(input_b, {-1, bias_length})) { return true; diff --git a/onnxruntime/core/optimizer/gelu_fusion.cc b/onnxruntime/core/optimizer/gelu_fusion.cc index a93338f6c3..b5ddc68f4f 100644 --- a/onnxruntime/core/optimizer/gelu_fusion.cc +++ b/onnxruntime/core/optimizer/gelu_fusion.cc @@ -55,7 +55,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons Node& div = *p_div; ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level, logger)); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7, 13}) || !graph_utils::IsSupportedProvider(div, GetCompatibleExecutionProviders()) || !optimizer_utils::CheckOutputEdges(graph, div, 1) || !IsSupportedDataType(div)) { @@ -71,7 +71,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons } Node& erf_node = *graph.GetNode(div.OutputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(erf_node, "Erf", {9}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(erf_node, "Erf", {9, 13}) || erf_node.GetExecutionProviderType() != div.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, erf_node, 1) || !IsSupportedDataType(erf_node)) { @@ -79,7 +79,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons } Node& add_node = *graph.GetNode(erf_node.OutputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13}) || add_node.GetExecutionProviderType() != div.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, add_node, 1) || !IsSupportedDataType(add_node)) { @@ -95,7 +95,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons Node& mul_node = *graph.GetNode(add_node.OutputNodesBegin()->Index()); // note: output edges count doesn't matter as the new Gelu node will produce outputs with the same names - if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) || mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() || !IsSupportedDataType(mul_node)) { continue; @@ -106,7 +106,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons if (p_mul2_node != nullptr) { // Match subgraph pattern 1 Node& mul2_node = *graph.GetNode(p_mul2_node->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13}) || mul2_node.GetExecutionProviderType() != div.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, mul2_node, 1) || !IsSupportedDataType(mul2_node)) { @@ -139,7 +139,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons continue; Node& mul2_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13}) || mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() || !IsSupportedDataType(mul_node)) { continue; diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc index b129e43f51..b2cb4995d2 100644 --- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc @@ -24,12 +24,12 @@ bool IsFusableActivation(const Node& node) { return IsSupportedOptypeVersionAndDomain(node, "Elu", {6}, kOnnxDomain) || IsSupportedOptypeVersionAndDomain(node, "HardSigmoid", {6}, kOnnxDomain) || IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", {6}, kOnnxDomain) || - IsSupportedOptypeVersionAndDomain(node, "Relu", {6}, kOnnxDomain) || + IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13}, kOnnxDomain) || IsSupportedOptypeVersionAndDomain(node, "Selu", {6}, kOnnxDomain) || - IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6}, kOnnxDomain) || + IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}, kOnnxDomain) || IsSupportedOptypeVersionAndDomain(node, "Softplus", {1}, kOnnxDomain) || IsSupportedOptypeVersionAndDomain(node, "Softsign", {1}, kOnnxDomain) || - IsSupportedOptypeVersionAndDomain(node, "Tanh", {6}, kOnnxDomain) || + IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13}, kOnnxDomain) || #ifndef DISABLE_CONTRIB_OPS IsSupportedOptypeVersionAndDomain(node, "ScaledTanh", {1}, kOnnxDomain) || IsSupportedOptypeVersionAndDomain(node, "ParametricSoftplus", {1}, kOnnxDomain) || @@ -51,7 +51,7 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l auto& node = *node_ptr; ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gemm", {7, 9, 11}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gemm", {7, 9, 11, 13}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || node.GetOutputEdgesCount() != 1) { continue; } diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index fd82270923..4a3ed60e10 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -69,7 +69,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, Node& reduce_mean_node = *p_reduce_mean; ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level, logger)); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13}) || !graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) || (reduce_mean_node.GetOutputEdgesCount() != 1 && reduce_mean_node.GetOutputEdgesCount() != 2) || !graph.GetNodeOutputsInGraphOutputs(reduce_mean_node).empty() || @@ -102,7 +102,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, } Node& sub_node = *graph.GetNode(p_sub_node->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node, "Sub", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node, "Sub", {7, 13}) || sub_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, sub_node, subCnt == 1 ? 2u : 1u) || !IsSupportedDataType(sub_node)) { @@ -117,7 +117,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // Find the sub_dup node if exist if (p_sub_node_dup != nullptr) { Node& sub_node_dup = *graph.GetNode(p_sub_node_dup->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node_dup, "Sub", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node_dup, "Sub", {7, 13}) || sub_node_dup.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, sub_node, 1) || !IsSupportedDataType(sub_node_dup)) { @@ -134,7 +134,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, continue; } Node& div_node = *graph.GetNode(p_div->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13}) || div_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, div_node, 1) || !IsSupportedDataType(div_node)) { @@ -149,7 +149,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, } Node& sqrt_node = *graph.GetNode(p_sqrt->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(sqrt_node, "Sqrt", {6}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(sqrt_node, "Sqrt", {6, 13}) || sqrt_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, sqrt_node, 1) || !IsSupportedDataType(sqrt_node) || @@ -160,7 +160,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // Traceback the sqrt node to find add --> sqrt Node& add2_node = *graph.GetNode(sqrt_node.InputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7, 13}) || add2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, add2_node, 1) || !IsSupportedDataType(add2_node)) { @@ -175,7 +175,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, continue; } Node& reduce_mean2_node = *graph.GetNode(p_reduce_mean2->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11, 13}) || reduce_mean2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, reduce_mean2_node, 1) || !IsSupportedDataType(reduce_mean2_node) || @@ -186,7 +186,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // Traceback the reduceMean node to find pow --> reduceMean Node& pow_node = *graph.GetNode(reduce_mean2_node.InputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow_node, "Pow", {7, 12}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow_node, "Pow", {7, 12, 13}) || pow_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, pow_node, 1) || !IsSupportedDataType(pow_node)) { @@ -198,7 +198,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const Node* p_cast_node = graph_utils::FirstParentByType(pow_node, "Cast"); if (p_cast_node != nullptr) { Node& cast_node = *graph.GetNode(pow_node.InputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13}) || cast_node.GetExecutionProviderType() != cast_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, cast_node, 1) || !IsSupportedDataType(cast_node)) { @@ -216,7 +216,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // div --> mul Node& mul_node = *graph.GetNode(div_node.OutputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) || mul_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, mul_node, 1) || !IsSupportedDataType(mul_node)) { @@ -227,7 +227,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // mul --> add // Need not check output edges of last node since they will be moved to fused node. Node& last_add_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(last_add_node, "Add", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(last_add_node, "Add", {7, 13}) || last_add_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !IsSupportedDataType(last_add_node)) { continue; diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc index caec86f7f4..de83c32417 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc @@ -24,7 +24,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9, 13}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || node.GetOutputEdgesCount() != 1) { continue; @@ -40,7 +40,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, } const Node& next_node = (*next_node_itr); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13}) || next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { continue; } @@ -58,7 +58,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, if ((*matmul_type) != (*add_type)) { continue; } - if ((*matmul_type) != "tensor(float)" && (*matmul_type) != "tensor(float16)") { + if ((*matmul_type) != "tensor(float)" && (*matmul_type) != "tensor(float16)" && (*matmul_type) != "tensor(bfloat16)") { continue; } diff --git a/onnxruntime/core/optimizer/matmul_scale_fusion.cc b/onnxruntime/core/optimizer/matmul_scale_fusion.cc index 2b540d691b..8411da1b57 100644 --- a/onnxruntime/core/optimizer/matmul_scale_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_scale_fusion.cc @@ -46,7 +46,7 @@ optional GetScalarConstantInitializer(const Graph& graph, const NodeArg& float scalar=0.f; utils::MLTypeCallDispatcherRet< Status, ExtractScalarAsFloatDispatchTarget, - uint32_t, uint64_t, int32_t, int64_t, MLFloat16, float, double> + uint32_t, uint64_t, int32_t, int64_t, MLFloat16, float, double, BFloat16> dispatcher{initializer->data_type()}; ORT_THROW_IF_ERROR(dispatcher.Invoke(*initializer, scalar)); @@ -62,7 +62,7 @@ optional> GetScaleFromNode( return excluded_initializer_names.find(node_arg.Name()) != excluded_initializer_names.end(); }; - if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Div", {7})) { + if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Div", {7, 13})) { // (x / scale_reciprocal) const auto div_inputs = scale_node.InputDefs(); ORT_ENFORCE(div_inputs.size() == 2); @@ -79,7 +79,7 @@ optional> GetScaleFromNode( return {std::make_pair(1.0f / divisor.value(), scale_reciprocal_arg_index)}; } - if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Mul", {7})) { + if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Mul", {7, 13})) { // (x * scale) or (scale * x) const auto mul_inputs = scale_node.InputDefs(); ORT_ENFORCE(mul_inputs.size() == 2); @@ -170,7 +170,7 @@ std::vector GetOutputNodeMerges( Status ProcessNode( Graph& graph, Node& node, bool& modified, const std::unordered_set& excluded_initializer_names) { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9}) && + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9, 13}) && !graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeScaleMatMul", {1}, kMSDomain)) { return Status::OK(); } diff --git a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc index 05c71d4fa2..bdb6329758 100644 --- a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc @@ -96,8 +96,8 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ auto& node = *graph.GetNode(node_index); ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - if ((!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9}) && - !graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeScaleMatMul", {1}, kMSDomain)) || + if ((!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9, 13}) && + !graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeScaleMatMul", {1, 13}, kMSDomain)) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { continue; } diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index f00709dcbc..7414c99e13 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -937,23 +937,23 @@ void NchwcTransformerImpl::Transform(Node& node) { // node may already have all inputs converted to NCHWc format and is not // needed for correct operation. This avoids doing extra string checks for // nodes unrelated to this transformer. - if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sum", {6, 8})) { + if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sum", {6, 8, 13})) { TransformBinary(node, true); - } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7})) { + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13})) { TransformBinary(node, false); - } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11})) { + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11, 13})) { TransformConcat(node); - } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6})) { + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13})) { TransformActivation(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9})) { TransformBatchNormalization(node); - } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1})) { + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13})) { TransformTranspose(node); - } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Upsample", {9}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "Resize", {10, 11})) { + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Upsample", {9, 13}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Resize", {10, 11, 13})) { TransformResize(node); } } diff --git a/onnxruntime/core/optimizer/relu_clip_fusion.cc b/onnxruntime/core/optimizer/relu_clip_fusion.cc index e0b86f9a7f..981ce9b554 100644 --- a/onnxruntime/core/optimizer/relu_clip_fusion.cc +++ b/onnxruntime/core/optimizer/relu_clip_fusion.cc @@ -68,6 +68,11 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff replace_min = true; } break; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + if (i.data()->ToFloat() < 0.f) { + replace_min = true; + } + break; default: ORT_THROW("Unexpected data type for Clip 'min' input of ", initializer->data_type()); } @@ -110,7 +115,7 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff } bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6})) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13})) { return false; } @@ -122,7 +127,7 @@ bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node, const // as Clip will apply the minimum. If the Clip 'min' value is < 0 we need // to update it to 0 to apply what the Relu would have done. We do that in Apply. const auto& next_node = *node.OutputNodesBegin(); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12, 13}) || next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { return false; } diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index c320ed2fb0..3e2389495a 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -23,7 +23,7 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c Node& reshape = *p_reshape; ORT_RETURN_IF_ERROR(Recurse(reshape, modified, graph_level, logger)); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(reshape, "Reshape", {5}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(reshape, "Reshape", {5, 13}) || !graph_utils::IsSupportedProvider(reshape, GetCompatibleExecutionProviders())) { continue; } @@ -48,9 +48,9 @@ bool ReshapeFusion::Match_One_Element_Output_Subgraph_1(Graph& graph, const Node int index, std::vector shape_value, bool checkOneElementOnly, const logging::Logger& logger) { std::vector parent_path{ - {0, index, "Unsqueeze", {1, 11}, kOnnxDomain}, - {0, 0, "Gather", {1, 11}, kOnnxDomain}, - {0, 0, "Shape", {1}, kOnnxDomain}}; + {0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Shape", {1, 13}, kOnnxDomain}}; std::vector edges; if (graph_utils::FindPath(concat, true, parent_path, edges, logger)) { const Node& unsqueeze = edges[0]->GetNode(); @@ -87,9 +87,9 @@ bool ReshapeFusion::Match_One_Element_Output_Subgraph_1(Graph& graph, const Node bool ReshapeFusion::Match_One_Element_Output_Subgraph_2(Graph& graph, const NodeArg& root_input, const Node& cur_node, int index, const logging::Logger& logger) { std::vector parent_path{ - {0, index, "Squeeze", {1, 11}, kOnnxDomain}, - {0, 0, "Slice", {1, 11}, kOnnxDomain}, - {0, 0, "Shape", {1}, kOnnxDomain}}; + {0, index, "Squeeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Slice", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Shape", {1, 13}, kOnnxDomain}}; std::vector edges; if (graph_utils::FindPath(cur_node, true, parent_path, edges, logger)) { const Node& slice = edges[1]->GetNode(); @@ -157,15 +157,15 @@ bool ReshapeFusion::Is_One_Element_Output_Subgraph(Graph& graph, const NodeArg& } std::vector div_path{ - {0, index, "Unsqueeze", {1, 11}, kOnnxDomain}, - {0, 0, "Div", {7}, kOnnxDomain}}; + {0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Div", {7, 13}, kOnnxDomain}}; std::vector mul_path{ - {0, index, "Unsqueeze", {1, 11}, kOnnxDomain}, - {0, 0, "Mul", {7}, kOnnxDomain}}; + {0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Mul", {7, 13}, kOnnxDomain}}; std::vector unsqueeze_path{ - {0, index, "Unsqueeze", {1, 11}, kOnnxDomain}}; + {0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain}}; std::vector edges; if (graph_utils::FindPath(concat, true, div_path, edges, logger) || @@ -242,7 +242,7 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo } const Node& concat = *p_concat; - if (!graph_utils::IsSupportedOptypeVersionAndDomain(concat, "Concat", {1, 4, 11})) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(concat, "Concat", {1, 4, 11, 13})) { return false; } diff --git a/onnxruntime/core/optimizer/shape_to_initializer.cc b/onnxruntime/core/optimizer/shape_to_initializer.cc index 784b1b3541..02aa17d772 100644 --- a/onnxruntime/core/optimizer/shape_to_initializer.cc +++ b/onnxruntime/core/optimizer/shape_to_initializer.cc @@ -50,7 +50,7 @@ Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& ru } bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Shape", {1})) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Shape", {1, 13})) { return false; } diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc index 946be98776..3fa535d084 100644 --- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc @@ -12,7 +12,7 @@ using namespace onnxruntime::common; namespace onnxruntime { // LayerNorm supports limited data types. -static std::vector supported_data_types{"tensor(float16)", "tensor(float)"}; +static std::vector supported_data_types{"tensor(float16)", "tensor(float)", "tensor(bfloat16)"}; static bool IsSupportedDataType(const Node& node) { for (const auto& input_arg : node.InputDefs()) { @@ -163,8 +163,8 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le // Format 1 std::vector format1_parent_path{ - {0, 0, "Add", {7}, kOnnxDomain}, - {0, 0, "Add", {7}, kOnnxDomain}}; + {0, 0, "Add", {7, 13}, kOnnxDomain}, + {0, 0, "Add", {7, 13}, kOnnxDomain}}; std::vector edges; if (graph_utils::FindPath(ln_node, true, format1_parent_path, edges, logger)) { @@ -182,8 +182,8 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (matched_format == Format::None) { // Format 2 std::vector format2_parent_path{ - {0, 0, "Add", {7}, kOnnxDomain}, - {0, 1, "Add", {7}, kOnnxDomain}}; + {0, 0, "Add", {7, 13}, kOnnxDomain}, + {0, 1, "Add", {7, 13}, kOnnxDomain}}; if (graph_utils::FindPath(ln_node, true, format2_parent_path, edges, logger)) { p_add1 = const_cast(&edges[0]->GetNode()); @@ -201,7 +201,7 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (matched_format == Format::None) { // Format 3 std::vector format3_parent_path{ - {0, 0, "Add", {7}, kOnnxDomain}}; + {0, 0, "Add", {7, 13}, kOnnxDomain}}; if (graph_utils::FindPath(ln_node, true, format3_parent_path, edges, logger)) { p_add1 = const_cast(&edges[0]->GetNode()); diff --git a/onnxruntime/core/optimizer/slice_elimination.cc b/onnxruntime/core/optimizer/slice_elimination.cc index 4052d68dff..d87564aafb 100644 --- a/onnxruntime/core/optimizer/slice_elimination.cc +++ b/onnxruntime/core/optimizer/slice_elimination.cc @@ -19,7 +19,7 @@ Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_e bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { // We currently support elimination for Slice operator v1. - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11})) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13})) { return false; } @@ -42,8 +42,8 @@ bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node, cons if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) { return false; } - } else if (graph_utils::MatchesOpSinceVersion(node, {10, 11})) { - // If it is a Slice operator of opset version 10 or 11, starts/ends/axes/steps are provided as node inputs. + } else if (graph_utils::MatchesOpSinceVersion(node, {10, 11, 13})) { + // If it is a Slice operator of opset version >= 10, starts/ends/axes/steps are provided as node inputs. // Returns a pointer to the corresponding NodeArg if input of the node at this index exists; otherwise, a nullptr. auto get_input_if_exists = [&node](size_t input_idx) -> const NodeArg* { diff --git a/orttraining/orttraining/core/optimizer/bias_dropout_fusion.cc b/orttraining/orttraining/core/optimizer/bias_dropout_fusion.cc index d0e467c35a..394b7ee55f 100644 --- a/orttraining/orttraining/core/optimizer/bias_dropout_fusion.cc +++ b/orttraining/orttraining/core/optimizer/bias_dropout_fusion.cc @@ -18,7 +18,7 @@ void FuseResidualAddIfAny(Graph& graph, const Node& dropout_node, for (auto last_node_itr = dropout_node.OutputNodesBegin(); last_node_itr != dropout_node.OutputNodesEnd(); ++last_node_itr) { const Node& last_node = (*last_node_itr); - if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Add", {7}) && + if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Add", {7, 13}) && last_node.GetExecutionProviderType() == dropout_node.GetExecutionProviderType()) { const TensorShapeProto* input1_shape = last_node.InputDefs()[0]->Shape(); const TensorShapeProto* input2_shape = last_node.InputDefs()[1]->Shape(); @@ -83,7 +83,7 @@ Status BiasDropoutFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve std::vector> nodes_to_fuse; // matching for bias Add node - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || node.GetOutputEdgesCount() != 1) { continue; @@ -127,7 +127,7 @@ Status BiasDropoutFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve } const Node& next_node = (*next_node_itr); - if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Dropout", {12}, kOnnxDomain) || + if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Dropout", {12, 13}, kOnnxDomain) || graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "TrainableDropout", {9}, kOnnxDomain)) || next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { continue; diff --git a/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc b/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc index ac8f1caad2..b240aa9ef3 100644 --- a/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc +++ b/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc @@ -53,7 +53,7 @@ Status InsertSoftmaxCrossEntropyLossOutput::Apply(Graph& graph, Node& node, Rewr } bool InsertSoftmaxCrossEntropyLossOutput::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const { - if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "SoftmaxCrossEntropyLoss", {12}) && + if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "SoftmaxCrossEntropyLoss", {12, 13}) && node.OutputDefs().size() == 1) { return true; } diff --git a/orttraining/orttraining/core/optimizer/megatron_transformer.cc b/orttraining/orttraining/core/optimizer/megatron_transformer.cc index 999ab0fbea..385db12f57 100644 --- a/orttraining/orttraining/core/optimizer/megatron_transformer.cc +++ b/orttraining/orttraining/core/optimizer/megatron_transformer.cc @@ -28,24 +28,25 @@ struct OpInfo { size_t output_count; }; -const std::initializer_list opset_v1 = {1}; -const std::initializer_list opset_v1_11 = {1, 11}; -const std::initializer_list opset_v2_11 = {2, 11}; -const std::initializer_list opset_v5 = {5}; -const std::initializer_list opset_v7 = {7}; +const std::initializer_list opset_v1_13 = {1, 13}; +const std::initializer_list opset_v1_11_13 = {1, 11, 13}; +const std::initializer_list opset_v2_11_13 = {2, 11, 13}; +const std::initializer_list opset_v5_13 = {5, 13}; +const std::initializer_list opset_v7_13 = {7, 13}; const std::initializer_list opset_v9 = {9}; -const std::initializer_list opset_v12 = {12}; -const OpInfo add_info = OpInfo("Add", opset_v7); -const OpInfo split_info = OpInfo("Split", opset_v2_11, kOnnxDomainAlias, 3); -const OpInfo reshape_info = OpInfo("Reshape", opset_v5); -const OpInfo transpose_info = OpInfo("Transpose", opset_v1); -const OpInfo matmul_info = OpInfo("MatMul", opset_v9); -const OpInfo div_info = OpInfo("Div", opset_v7); -const OpInfo mul_info = OpInfo("Mul", opset_v7); -const OpInfo sub_info = OpInfo("Sub", opset_v7); -const OpInfo softmax_info = OpInfo("Softmax", opset_v1_11); +const std::initializer_list opset_v9_13 = {9, 13}; +const std::initializer_list opset_v12_13 = {12, 13}; +const OpInfo add_info = OpInfo("Add", opset_v7_13); +const OpInfo split_info = OpInfo("Split", opset_v2_11_13, kOnnxDomainAlias, 3); +const OpInfo reshape_info = OpInfo("Reshape", opset_v5_13); +const OpInfo transpose_info = OpInfo("Transpose", opset_v1_13); +const OpInfo matmul_info = OpInfo("MatMul", opset_v9_13); +const OpInfo div_info = OpInfo("Div", opset_v7_13); +const OpInfo mul_info = OpInfo("Mul", opset_v7_13); +const OpInfo sub_info = OpInfo("Sub", opset_v7_13); +const OpInfo softmax_info = OpInfo("Softmax", opset_v1_11_13); const OpInfo trainable_dropout_info = OpInfo("TrainableDropout", opset_v9, kOnnxDomain); -const OpInfo dropout_info = OpInfo("Dropout", opset_v12); +const OpInfo dropout_info = OpInfo("Dropout", opset_v12_13); struct NodeInfo { NodeInfo(const std::vector& op_infos, @@ -243,7 +244,7 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph auto& node = *graph.GetNode(node_index); ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9, 13}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || node.GetOutputEdgesCount() != 1) { continue; @@ -257,7 +258,7 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph } Node& add_node = *graph.GetNode(node.OutputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13}) || add_node.GetExecutionProviderType() != node.GetExecutionProviderType() || add_node.GetOutputEdgesCount() != 1) { continue; @@ -272,14 +273,14 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph } Node& matmul2_node = *graph.GetNode(gelu_node.OutputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(matmul2_node, "MatMul", {9}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(matmul2_node, "MatMul", {9, 13}) || matmul2_node.GetExecutionProviderType() != node.GetExecutionProviderType() || matmul2_node.GetOutputEdgesCount() != 1) { continue; } Node& add2_node = *graph.GetNode(matmul2_node.OutputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7, 13}) || add2_node.GetExecutionProviderType() != node.GetExecutionProviderType() || add2_node.GetOutputEdgesCount() != 1) { continue; @@ -368,7 +369,7 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified, auto& node = *graph.GetNode(node_index); ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", opset_v9) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", opset_v9_13) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || node.GetOutputEdgesCount() != 1) { continue; @@ -603,7 +604,7 @@ Status MegatronTransformer::TransformDropout(Graph& graph, bool& modified, int g continue; } - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", opset_v12) && + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", opset_v12_13) && !graph_utils::IsSupportedOptypeVersionAndDomain(node, "TrainableDropout", opset_v9, kOnnxDomain)) { continue; }