From 7834ca983c2fafd4705313a8aaa1cd5743724e4d Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Tue, 18 May 2021 11:58:14 -0700 Subject: [PATCH] update optimizers for opset14 (#7722) * update optimizers for opset14 * plus 1 more * fix reshape fusion --- .../core/optimizer/attention_fusion_helper.h | 2 +- .../core/optimizer/bias_dropout_fusion.cc | 4 ++-- .../core/optimizer/bias_gelu_fusion.cc | 2 +- .../core/optimizer/bias_softmax_fusion.cc | 2 +- .../core/optimizer/conv_activation_fusion.cc | 8 +++---- onnxruntime/core/optimizer/conv_add_fusion.cc | 2 +- onnxruntime/core/optimizer/conv_bn_fusion.cc | 2 +- onnxruntime/core/optimizer/conv_mul_fusion.cc | 2 +- onnxruntime/core/optimizer/div_mul_fusion.cc | 4 ++-- .../core/optimizer/fast_gelu_fusion.cc | 22 +++++++++---------- onnxruntime/core/optimizer/gelu_fusion.cc | 10 ++++----- .../core/optimizer/gemm_activation_fusion.cc | 2 +- .../core/optimizer/layer_norm_fusion.cc | 18 +++++++-------- .../core/optimizer/matmul_add_fusion.cc | 2 +- .../core/optimizer/matmul_integer_to_float.cc | 2 +- .../core/optimizer/matmul_scale_fusion.cc | 4 ++-- .../core/optimizer/nchwc_transformer.cc | 8 +++---- .../qdq_transformer/qdq_propagation.cc | 2 +- .../qdq_transformer/relu_quantizelinear.cc | 2 +- .../core/optimizer/relu_clip_fusion.cc | 2 +- onnxruntime/core/optimizer/reshape_fusion.cc | 11 +++++++--- 21 files changed, 59 insertions(+), 54 deletions(-) diff --git a/onnxruntime/core/optimizer/attention_fusion_helper.h b/onnxruntime/core/optimizer/attention_fusion_helper.h index 4a3ccae6c0..97a52bb9b1 100644 --- a/onnxruntime/core/optimizer/attention_fusion_helper.h +++ b/onnxruntime/core/optimizer/attention_fusion_helper.h @@ -1280,7 +1280,7 @@ TODO: replace Gemm_Subgraph by MatMul + Add bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::map& mask_int32_map, bool use_shared_node, const logging::Logger& logger) { DEBUG_LOG("Start FuseGptAttention"); const Node* parent_node = graph_utils::GetInputNode(layer_norm, 0); - if (nullptr == parent_node || !graph_utils::IsSupportedOptypeVersionAndDomain(*parent_node, "Add", {7, 13}, kOnnxDomain)) { + if (nullptr == parent_node || !graph_utils::IsSupportedOptypeVersionAndDomain(*parent_node, "Add", {7, 13, 14}, kOnnxDomain)) { return false; } diff --git a/onnxruntime/core/optimizer/bias_dropout_fusion.cc b/onnxruntime/core/optimizer/bias_dropout_fusion.cc index e27750f6ae..0c59123656 100644 --- a/onnxruntime/core/optimizer/bias_dropout_fusion.cc +++ b/onnxruntime/core/optimizer/bias_dropout_fusion.cc @@ -29,7 +29,7 @@ void FuseResidualAddIfAny(Graph& graph, const Node& dropout_node, for (auto last_node_itr = dropout_node.OutputNodesBegin(); last_node_itr != dropout_node.OutputNodesEnd(); ++last_node_itr) { const Node& last_node = (*last_node_itr); - if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Add", {7, 13}) && + if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Add", {7, 13, 14}) && last_node.GetExecutionProviderType() == dropout_node.GetExecutionProviderType()) { const TensorShapeProto* input1_shape = last_node.InputDefs()[0]->Shape(); const TensorShapeProto* input2_shape = last_node.InputDefs()[1]->Shape(); @@ -90,7 +90,7 @@ Status BiasDropoutFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve std::vector> nodes_to_fuse; // matching for bias Add node - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13, 14}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || node.GetOutputEdgesCount() != 1) { continue; diff --git a/onnxruntime/core/optimizer/bias_gelu_fusion.cc b/onnxruntime/core/optimizer/bias_gelu_fusion.cc index cfc763c880..88ed36c123 100644 --- a/onnxruntime/core/optimizer/bias_gelu_fusion.cc +++ b/onnxruntime/core/optimizer/bias_gelu_fusion.cc @@ -24,7 +24,7 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13, 14}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || !optimizer_utils::CheckOutputEdges(graph, node, 1)) { continue; diff --git a/onnxruntime/core/optimizer/bias_softmax_fusion.cc b/onnxruntime/core/optimizer/bias_softmax_fusion.cc index 6f4e761e56..a235da602e 100644 --- a/onnxruntime/core/optimizer/bias_softmax_fusion.cc +++ b/onnxruntime/core/optimizer/bias_softmax_fusion.cc @@ -43,7 +43,7 @@ bool TryBiasSoftmaxSubgraphMatch(Graph& graph, Node& start, Node*& add, Node*& s add = softmax = nullptr; // check node is add and has single output - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13, 14}) || !graph_utils::IsSupportedProvider(node, {kCudaExecutionProvider, kRocmExecutionProvider}) || !optimizer_utils::CheckOutputEdges(graph, node, 1)) { return false; diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index 933d374eaf..52351ebabc 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -105,7 +105,7 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { continue; } - if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13})) { + if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13, 14})) { Node& conv_node = *node; Node& act_node = *graph.GetNode(next_node.Index()); auto node_name = graph.GenerateNodeName(conv_node.Name() + "_" + act_node.Name()); @@ -120,12 +120,12 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l fused_conv.AddAttribute("activation", "Relu"); graph_utils::FinalizeNodeFusion(graph, {conv_node, act_node}, fused_conv); modified = true; - } else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {6, 7, 13})) { + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {6, 7, 13, 14})) { const auto& last_node = *(next_node.OutputNodesBegin()); if (last_node.GetExecutionProviderType() != node->GetExecutionProviderType()) { continue; } - if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Relu", {6, 13}) && + if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Relu", {6, 13, 14}) && next_node.GetOutputEdgesCount() == 1) { Node& conv_node = *node; Node& add_node = *graph.GetNode(next_node.Index()); @@ -158,7 +158,7 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l // Test if this is an activation that can be fused and also extract the // activation's parameters. std::vector activation_params; - if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13}) && + if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13, 14}) && !graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Sigmoid", {6, 13}) && !graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Tanh", {6, 13})) { if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "LeakyRelu", {6})) { diff --git a/onnxruntime/core/optimizer/conv_add_fusion.cc b/onnxruntime/core/optimizer/conv_add_fusion.cc index f9fb9a3178..57c52643f9 100644 --- a/onnxruntime/core/optimizer/conv_add_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_fusion.cc @@ -111,7 +111,7 @@ bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const } const auto& next_node = *node.OutputNodesBegin(); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13, 14}) || next_node.GetInputEdgesCount() != 1 || // Make sure the two nodes do not span execution providers. next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { diff --git a/onnxruntime/core/optimizer/conv_bn_fusion.cc b/onnxruntime/core/optimizer/conv_bn_fusion.cc index 991249f16d..a3cceb9524 100644 --- a/onnxruntime/core/optimizer/conv_bn_fusion.cc +++ b/onnxruntime/core/optimizer/conv_bn_fusion.cc @@ -151,7 +151,7 @@ bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const } const auto& next_node = *node.OutputNodesBegin(); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "BatchNormalization", {7, 9}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "BatchNormalization", {7, 9, 14}) || next_node.GetInputEdgesCount() != 1 || // Make sure the two nodes do not span execution providers. next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { diff --git a/onnxruntime/core/optimizer/conv_mul_fusion.cc b/onnxruntime/core/optimizer/conv_mul_fusion.cc index 13ba90911f..a70618460c 100644 --- a/onnxruntime/core/optimizer/conv_mul_fusion.cc +++ b/onnxruntime/core/optimizer/conv_mul_fusion.cc @@ -119,7 +119,7 @@ bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const } const auto& next_node = *node.OutputNodesBegin(); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13, 14}) || next_node.GetInputEdgesCount() != 1 || // Make sure the two nodes do not span execution providers. next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { diff --git a/onnxruntime/core/optimizer/div_mul_fusion.cc b/onnxruntime/core/optimizer/div_mul_fusion.cc index 4f191b1b63..39701c3248 100644 --- a/onnxruntime/core/optimizer/div_mul_fusion.cc +++ b/onnxruntime/core/optimizer/div_mul_fusion.cc @@ -17,13 +17,13 @@ when the first input to Div is 1. 1 / x1 * x2 -> x2 / x1 */ bool DivMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Div", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Div", {7, 13, 14}) || node.GetOutputEdgesCount() != 1) { return false; } const auto& next_node = *node.OutputNodesBegin(); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13, 14}) || // Make sure the two nodes do not span execution providers. next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { return false; diff --git a/onnxruntime/core/optimizer/fast_gelu_fusion.cc b/onnxruntime/core/optimizer/fast_gelu_fusion.cc index 6c622e096f..ad2ea5a369 100644 --- a/onnxruntime/core/optimizer/fast_gelu_fusion.cc +++ b/onnxruntime/core/optimizer/fast_gelu_fusion.cc @@ -37,7 +37,7 @@ static bool CheckNode(Graph& graph, const Node& node, const std::string& op_name MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node, std::vector>& nodes_to_fuse) const { MatchResult matchResult{false, nullptr, nullptr}; - if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul1_node, "Mul", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul1_node, "Mul", {7, 13, 14}) || !graph_utils::IsSupportedProvider(mul1_node, GetCompatibleExecutionProviders()) || mul1_node.GetOutputEdgesCount() != 1 || !IsSupportedDataType(mul1_node)) { @@ -60,7 +60,7 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node, Node& mul2_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index()); input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *mul1_node.MutableOutputDefs()[0]); - if (!CheckNode(graph, mul2_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true) || + if (!CheckNode(graph, mul2_node, "Mul", {7, 13, 14}, mul1_node.GetExecutionProviderType(), true) || mul2_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != gelu_without_bias_input_arg->Name()) { return matchResult; } @@ -68,14 +68,14 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node, Node& add1_node = *graph.GetNode(mul2_node.OutputNodesBegin()->Index()); input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul2_node.MutableOutputDefs()[0]); - if (!CheckNode(graph, add1_node, "Add", {7, 13}, mul1_node.GetExecutionProviderType(), true) || + if (!CheckNode(graph, add1_node, "Add", {7, 13, 14}, mul1_node.GetExecutionProviderType(), true) || !optimizer_utils::IsInitializerWithExpectedValue(graph, *(add1_node.InputDefs()[(input_index + 1) % 2]), 1.0f, true)) { return matchResult; } nodes_to_fuse.push_back(add1_node); Node& mul3_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index()); - if (!CheckNode(graph, mul3_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true)) { + if (!CheckNode(graph, mul3_node, "Mul", {7, 13, 14}, mul1_node.GetExecutionProviderType(), true)) { return matchResult; } nodes_to_fuse.push_back(mul3_node); @@ -84,7 +84,7 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node, const Node* p_mul3_input_node = graph_utils::GetInputNode(mul3_node, (input_index + 1) % 2); if (p_mul3_input_node == nullptr) return matchResult; Node& mul4_node = const_cast(*p_mul3_input_node); - if (!CheckNode(graph, mul4_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true)) { + if (!CheckNode(graph, mul4_node, "Mul", {7, 13, 14}, mul1_node.GetExecutionProviderType(), true)) { return matchResult; } @@ -126,7 +126,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node, Node& mul1_node = *graph.GetNode(pow1_node.OutputNodesBegin()->Index()); auto input_index = optimizer_utils::IndexOfNodeInput(mul1_node, *pow1_node.MutableOutputDefs()[0]); - if (!CheckNode(graph, mul1_node, "Mul", {7, 13}, pow1_node.GetExecutionProviderType(), true) || + if (!CheckNode(graph, mul1_node, "Mul", {7, 13, 14}, pow1_node.GetExecutionProviderType(), true) || !optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul1_node.InputDefs()[(input_index + 1) % 2]), 0.044714998453855515f, true)) { return matchResult; @@ -135,7 +135,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node, Node& add1_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index()); input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul1_node.MutableOutputDefs()[0]); - if (!CheckNode(graph, add1_node, "Add", {7, 13}, pow1_node.GetExecutionProviderType(), true) || + if (!CheckNode(graph, add1_node, "Add", {7, 13, 14}, pow1_node.GetExecutionProviderType(), true) || add1_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != pow_input_arg->Name()) { return matchResult; } @@ -162,7 +162,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node, Node& mul2_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index()); input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *add1_node.MutableOutputDefs()[0]); - if (!CheckNode(graph, mul2_node, "Mul", {7, 13}, pow1_node.GetExecutionProviderType(), true) || + if (!CheckNode(graph, mul2_node, "Mul", {7, 13, 14}, pow1_node.GetExecutionProviderType(), true) || !optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul2_node.InputDefs()[(input_index + 1) % 2]), 0.7978845834732056f, true)) { return matchResult; @@ -220,7 +220,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, } Node& add2_node = *graph.GetNode(tanh_node.OutputNodesBegin()->Index()); - if (!CheckNode(graph, add2_node, "Add", {7, 13}, node.GetExecutionProviderType(), true)) { + if (!CheckNode(graph, add2_node, "Add", {7, 13, 14}, node.GetExecutionProviderType(), true)) { continue; } @@ -231,7 +231,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, Node& mul5_node = *graph.GetNode(add2_node.OutputNodesBegin()->Index()); // This is the output of the Gelu subgraph, we don't need check it has single edge. - if (!CheckNode(graph, mul5_node, "Mul", {7, 13}, node.GetExecutionProviderType(), false)) { + if (!CheckNode(graph, mul5_node, "Mul", {7, 13, 14}, node.GetExecutionProviderType(), false)) { continue; } @@ -263,7 +263,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, } Node& mul6_node = const_cast(*p_mul5_input_node); - if (!CheckNode(graph, mul6_node, "Mul", {7, 13}, node.GetExecutionProviderType(), false)) { + if (!CheckNode(graph, mul6_node, "Mul", {7, 13, 14}, node.GetExecutionProviderType(), false)) { continue; } diff --git a/onnxruntime/core/optimizer/gelu_fusion.cc b/onnxruntime/core/optimizer/gelu_fusion.cc index b5ddc68f4f..130f222eb0 100644 --- a/onnxruntime/core/optimizer/gelu_fusion.cc +++ b/onnxruntime/core/optimizer/gelu_fusion.cc @@ -55,7 +55,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons Node& div = *p_div; ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level, logger)); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7, 13, 14}) || !graph_utils::IsSupportedProvider(div, GetCompatibleExecutionProviders()) || !optimizer_utils::CheckOutputEdges(graph, div, 1) || !IsSupportedDataType(div)) { @@ -79,7 +79,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons } Node& add_node = *graph.GetNode(erf_node.OutputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13, 14}) || add_node.GetExecutionProviderType() != div.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, add_node, 1) || !IsSupportedDataType(add_node)) { @@ -95,7 +95,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons Node& mul_node = *graph.GetNode(add_node.OutputNodesBegin()->Index()); // note: output edges count doesn't matter as the new Gelu node will produce outputs with the same names - if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) || mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() || !IsSupportedDataType(mul_node)) { continue; @@ -106,7 +106,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons if (p_mul2_node != nullptr) { // Match subgraph pattern 1 Node& mul2_node = *graph.GetNode(p_mul2_node->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13, 14}) || mul2_node.GetExecutionProviderType() != div.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, mul2_node, 1) || !IsSupportedDataType(mul2_node)) { @@ -139,7 +139,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons continue; Node& mul2_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13, 14}) || mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() || !IsSupportedDataType(mul_node)) { continue; diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc index b2cb4995d2..214d3163fb 100644 --- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc @@ -24,7 +24,7 @@ bool IsFusableActivation(const Node& node) { return IsSupportedOptypeVersionAndDomain(node, "Elu", {6}, kOnnxDomain) || IsSupportedOptypeVersionAndDomain(node, "HardSigmoid", {6}, kOnnxDomain) || IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", {6}, kOnnxDomain) || - IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13}, kOnnxDomain) || + IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}, kOnnxDomain) || IsSupportedOptypeVersionAndDomain(node, "Selu", {6}, kOnnxDomain) || IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}, kOnnxDomain) || IsSupportedOptypeVersionAndDomain(node, "Softplus", {1}, kOnnxDomain) || diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 205df4f971..c6fee2556c 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -109,7 +109,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, } Node& sub_node = *graph.GetNode(p_sub_node->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node, "Sub", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node, "Sub", {7, 13, 14}) || sub_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, sub_node, subCnt == 1 ? 2u : 1u) || !IsSupportedDataType(sub_node)) { @@ -124,7 +124,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // Find the sub_dup node if exist if (p_sub_node_dup != nullptr) { Node& sub_node_dup = *graph.GetNode(p_sub_node_dup->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node_dup, "Sub", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node_dup, "Sub", {7, 13, 14}) || sub_node_dup.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, sub_node, 1) || !IsSupportedDataType(sub_node_dup)) { @@ -141,7 +141,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, continue; } Node& div_node = *graph.GetNode(p_div->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13, 14}) || div_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, div_node, 1) || !IsSupportedDataType(div_node)) { @@ -167,7 +167,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // Traceback the sqrt node to find add --> sqrt Node& add2_node = *graph.GetNode(sqrt_node.InputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7, 13, 14}) || add2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, add2_node, 1) || !IsSupportedDataType(add2_node)) { @@ -224,7 +224,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // div --> mul Node& mul_node = *graph.GetNode(div_node.OutputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) || mul_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, mul_node, 1) || !IsSupportedDataType(mul_node)) { @@ -235,7 +235,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // mul --> add // Need not check output edges of last node since they will be moved to fused node. Node& last_add_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(last_add_node, "Add", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(last_add_node, "Add", {7, 13, 14}) || last_add_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !IsSupportedDataType(last_add_node)) { continue; @@ -404,7 +404,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr continue; } Node& add_node = *graph.GetNode(p_add->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13, 14}) || add_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, add_node, 1) || !IsSupportedDataType(add_node)) { @@ -432,7 +432,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr continue; } Node& div_node = *graph.GetNode(p_div->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13, 14}) || div_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, div_node, 1) || !IsSupportedDataType(div_node)) { @@ -488,7 +488,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr } Node& mul_node = *next_node; - if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) || mul_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() || !IsSupportedDataType(mul_node)) { continue; diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc index 6ed0b3734d..d65a8343c1 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc @@ -40,7 +40,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, } const Node& next_node = (*next_node_itr); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13, 14}) || next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { continue; } diff --git a/onnxruntime/core/optimizer/matmul_integer_to_float.cc b/onnxruntime/core/optimizer/matmul_integer_to_float.cc index 630455c621..758bfdd75e 100644 --- a/onnxruntime/core/optimizer/matmul_integer_to_float.cc +++ b/onnxruntime/core/optimizer/matmul_integer_to_float.cc @@ -64,7 +64,7 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g ORT_RETURN_IF_ERROR(Recurse(mul_node, modified, graph_level, logger)); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) || !graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders())) { continue; } diff --git a/onnxruntime/core/optimizer/matmul_scale_fusion.cc b/onnxruntime/core/optimizer/matmul_scale_fusion.cc index 2909114aad..36affcba9b 100644 --- a/onnxruntime/core/optimizer/matmul_scale_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_scale_fusion.cc @@ -62,7 +62,7 @@ optional> GetScaleFromNode( return excluded_initializer_names.find(node_arg.Name()) != excluded_initializer_names.end(); }; - if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Div", {7, 13})) { + if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Div", {7, 13, 14})) { // (x / scale_reciprocal) const auto div_inputs = scale_node.InputDefs(); ORT_ENFORCE(div_inputs.size() == 2); @@ -79,7 +79,7 @@ optional> GetScaleFromNode( return {std::make_pair(1.0f / divisor.value(), scale_reciprocal_arg_index)}; } - if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Mul", {7, 13})) { + if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Mul", {7, 13, 14})) { // (x * scale) or (scale * x) const auto mul_inputs = scale_node.InputDefs(); ORT_ENFORCE(mul_inputs.size() == 2); diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index 253db85e64..2241474a35 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -1170,18 +1170,18 @@ void NchwcTransformerImpl::Transform(Node& node) { // node may already have all inputs converted to NCHWc format and is not // needed for correct operation. This avoids doing extra string checks for // nodes unrelated to this transformer. - if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) || + if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13, 14}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sum", {6, 8, 13})) { TransformBinary(node, true); - } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13})) { + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14})) { TransformBinary(node, false); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11, 13})) { TransformConcat(node); - } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13}) || + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13})) { TransformActivation(node); - } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9})) { + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9, 14})) { TransformBatchNormalization(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13})) { TransformTransposeToNhwc(node); diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index 894bd034a8..2e3f3c71e5 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -14,7 +14,7 @@ namespace onnxruntime { static bool CanNodePropagate(const Node& node) { return graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {12}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {5, 13}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {5, 13, 14}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13}); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/relu_quantizelinear.cc b/onnxruntime/core/optimizer/qdq_transformer/relu_quantizelinear.cc index 525b244dab..b8f15fb9a2 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/relu_quantizelinear.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/relu_quantizelinear.cc @@ -11,7 +11,7 @@ using namespace onnxruntime::common; namespace onnxruntime { bool ReluQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& /*logger*/) const { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) || !optimizer_utils::CheckOutputEdges(graph, node, 1)) { return false; } diff --git a/onnxruntime/core/optimizer/relu_clip_fusion.cc b/onnxruntime/core/optimizer/relu_clip_fusion.cc index 981ce9b554..a8ada8853d 100644 --- a/onnxruntime/core/optimizer/relu_clip_fusion.cc +++ b/onnxruntime/core/optimizer/relu_clip_fusion.cc @@ -115,7 +115,7 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff } bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13})) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14})) { return false; } diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index bbd6ff5d96..436dcaaf3f 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -23,11 +23,16 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c Node& reshape = *p_reshape; ORT_RETURN_IF_ERROR(Recurse(reshape, modified, graph_level, logger)); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(reshape, "Reshape", {5, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(reshape, "Reshape", {5, 13, 14}) || !graph_utils::IsSupportedProvider(reshape, GetCompatibleExecutionProviders())) { continue; } + const auto* attr_proto = graph_utils::GetNodeAttribute(reshape, "allowzero"); + if ((nullptr != attr_proto) && attr_proto->has_i() && attr_proto->i() != 0) { + continue; + } + if (ReshapeFusion::Fuse_Subgraph(reshape, graph, logger)) { fused_count++; LOGS(logger, INFO) << "Fused reshape node: " << reshape.OutputDefs()[0]->Name(); @@ -255,11 +260,11 @@ bool ReshapeFusion::Is_One_Element_Output_Subgraph(Graph& graph, const NodeArg& std::vector div_path{ {0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Div", {7, 13}, kOnnxDomain}}; + {0, 0, "Div", {7, 13, 14}, kOnnxDomain}}; std::vector mul_path{ {0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Mul", {7, 13}, kOnnxDomain}}; + {0, 0, "Mul", {7, 13, 14}, kOnnxDomain}}; std::vector unsqueeze_path{ {0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain}};