mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
update optimizers for opset14 (#7722)
* update optimizers for opset14
* plus 1 more
* fix reshape fusion
This commit is contained in:
parent
26a472c948
commit
7834ca983c
21 changed files with 59 additions and 54 deletions
|
|
@ -1280,7 +1280,7 @@ TODO: replace Gemm_Subgraph by MatMul + Add
|
|||
bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::map<std::string, NodeArg*>& mask_int32_map, bool use_shared_node, const logging::Logger& logger) {
|
||||
DEBUG_LOG("Start FuseGptAttention");
|
||||
const Node* parent_node = graph_utils::GetInputNode(layer_norm, 0);
|
||||
if (nullptr == parent_node || !graph_utils::IsSupportedOptypeVersionAndDomain(*parent_node, "Add", {7, 13}, kOnnxDomain)) {
|
||||
if (nullptr == parent_node || !graph_utils::IsSupportedOptypeVersionAndDomain(*parent_node, "Add", {7, 13, 14}, kOnnxDomain)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ void FuseResidualAddIfAny(Graph& graph, const Node& dropout_node,
|
|||
for (auto last_node_itr = dropout_node.OutputNodesBegin(); last_node_itr != dropout_node.OutputNodesEnd(); ++last_node_itr) {
|
||||
const Node& last_node = (*last_node_itr);
|
||||
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Add", {7, 13}) &&
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Add", {7, 13, 14}) &&
|
||||
last_node.GetExecutionProviderType() == dropout_node.GetExecutionProviderType()) {
|
||||
const TensorShapeProto* input1_shape = last_node.InputDefs()[0]->Shape();
|
||||
const TensorShapeProto* input2_shape = last_node.InputDefs()[1]->Shape();
|
||||
|
|
@ -90,7 +90,7 @@ Status BiasDropoutFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve
|
|||
std::vector<std::reference_wrapper<Node>> nodes_to_fuse;
|
||||
|
||||
// matching for bias Add node
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13, 14}) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13, 14}) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, node, 1)) {
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ bool TryBiasSoftmaxSubgraphMatch(Graph& graph, Node& start, Node*& add, Node*& s
|
|||
add = softmax = nullptr;
|
||||
|
||||
// check node is add and has single output
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13, 14}) ||
|
||||
!graph_utils::IsSupportedProvider(node, {kCudaExecutionProvider, kRocmExecutionProvider}) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, node, 1)) {
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
continue;
|
||||
}
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13})) {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13, 14})) {
|
||||
Node& conv_node = *node;
|
||||
Node& act_node = *graph.GetNode(next_node.Index());
|
||||
auto node_name = graph.GenerateNodeName(conv_node.Name() + "_" + act_node.Name());
|
||||
|
|
@ -120,12 +120,12 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
fused_conv.AddAttribute("activation", "Relu");
|
||||
graph_utils::FinalizeNodeFusion(graph, {conv_node, act_node}, fused_conv);
|
||||
modified = true;
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {6, 7, 13})) {
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {6, 7, 13, 14})) {
|
||||
const auto& last_node = *(next_node.OutputNodesBegin());
|
||||
if (last_node.GetExecutionProviderType() != node->GetExecutionProviderType()) {
|
||||
continue;
|
||||
}
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Relu", {6, 13}) &&
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Relu", {6, 13, 14}) &&
|
||||
next_node.GetOutputEdgesCount() == 1) {
|
||||
Node& conv_node = *node;
|
||||
Node& add_node = *graph.GetNode(next_node.Index());
|
||||
|
|
@ -158,7 +158,7 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
// Test if this is an activation that can be fused and also extract the
|
||||
// activation's parameters.
|
||||
std::vector<float> activation_params;
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13}) &&
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13, 14}) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Sigmoid", {6, 13}) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Tanh", {6, 13})) {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "LeakyRelu", {6})) {
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const
|
|||
}
|
||||
|
||||
const auto& next_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13, 14}) ||
|
||||
next_node.GetInputEdgesCount() != 1 ||
|
||||
// Make sure the two nodes do not span execution providers.
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
|
|
|
|||
|
|
@ -151,7 +151,7 @@ bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const
|
|||
}
|
||||
|
||||
const auto& next_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "BatchNormalization", {7, 9}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "BatchNormalization", {7, 9, 14}) ||
|
||||
next_node.GetInputEdgesCount() != 1 ||
|
||||
// Make sure the two nodes do not span execution providers.
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@ bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const
|
|||
}
|
||||
|
||||
const auto& next_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13, 14}) ||
|
||||
next_node.GetInputEdgesCount() != 1 ||
|
||||
// Make sure the two nodes do not span execution providers.
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
|
|
|
|||
|
|
@ -17,13 +17,13 @@ when the first input to Div is 1.
|
|||
1 / x1 * x2 -> x2 / x1
|
||||
*/
|
||||
bool DivMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Div", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Div", {7, 13, 14}) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& next_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13, 14}) ||
|
||||
// Make sure the two nodes do not span execution providers.
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ static bool CheckNode(Graph& graph, const Node& node, const std::string& op_name
|
|||
MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
|
||||
std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const {
|
||||
MatchResult matchResult{false, nullptr, nullptr};
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul1_node, "Mul", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul1_node, "Mul", {7, 13, 14}) ||
|
||||
!graph_utils::IsSupportedProvider(mul1_node, GetCompatibleExecutionProviders()) ||
|
||||
mul1_node.GetOutputEdgesCount() != 1 ||
|
||||
!IsSupportedDataType(mul1_node)) {
|
||||
|
|
@ -60,7 +60,7 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
|
|||
|
||||
Node& mul2_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index());
|
||||
input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *mul1_node.MutableOutputDefs()[0]);
|
||||
if (!CheckNode(graph, mul2_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true) ||
|
||||
if (!CheckNode(graph, mul2_node, "Mul", {7, 13, 14}, mul1_node.GetExecutionProviderType(), true) ||
|
||||
mul2_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != gelu_without_bias_input_arg->Name()) {
|
||||
return matchResult;
|
||||
}
|
||||
|
|
@ -68,14 +68,14 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
|
|||
|
||||
Node& add1_node = *graph.GetNode(mul2_node.OutputNodesBegin()->Index());
|
||||
input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul2_node.MutableOutputDefs()[0]);
|
||||
if (!CheckNode(graph, add1_node, "Add", {7, 13}, mul1_node.GetExecutionProviderType(), true) ||
|
||||
if (!CheckNode(graph, add1_node, "Add", {7, 13, 14}, mul1_node.GetExecutionProviderType(), true) ||
|
||||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(add1_node.InputDefs()[(input_index + 1) % 2]), 1.0f, true)) {
|
||||
return matchResult;
|
||||
}
|
||||
nodes_to_fuse.push_back(add1_node);
|
||||
|
||||
Node& mul3_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index());
|
||||
if (!CheckNode(graph, mul3_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true)) {
|
||||
if (!CheckNode(graph, mul3_node, "Mul", {7, 13, 14}, mul1_node.GetExecutionProviderType(), true)) {
|
||||
return matchResult;
|
||||
}
|
||||
nodes_to_fuse.push_back(mul3_node);
|
||||
|
|
@ -84,7 +84,7 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
|
|||
const Node* p_mul3_input_node = graph_utils::GetInputNode(mul3_node, (input_index + 1) % 2);
|
||||
if (p_mul3_input_node == nullptr) return matchResult;
|
||||
Node& mul4_node = const_cast<Node&>(*p_mul3_input_node);
|
||||
if (!CheckNode(graph, mul4_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true)) {
|
||||
if (!CheckNode(graph, mul4_node, "Mul", {7, 13, 14}, mul1_node.GetExecutionProviderType(), true)) {
|
||||
return matchResult;
|
||||
}
|
||||
|
||||
|
|
@ -126,7 +126,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
|
|||
|
||||
Node& mul1_node = *graph.GetNode(pow1_node.OutputNodesBegin()->Index());
|
||||
auto input_index = optimizer_utils::IndexOfNodeInput(mul1_node, *pow1_node.MutableOutputDefs()[0]);
|
||||
if (!CheckNode(graph, mul1_node, "Mul", {7, 13}, pow1_node.GetExecutionProviderType(), true) ||
|
||||
if (!CheckNode(graph, mul1_node, "Mul", {7, 13, 14}, pow1_node.GetExecutionProviderType(), true) ||
|
||||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul1_node.InputDefs()[(input_index + 1) % 2]),
|
||||
0.044714998453855515f, true)) {
|
||||
return matchResult;
|
||||
|
|
@ -135,7 +135,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
|
|||
|
||||
Node& add1_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index());
|
||||
input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul1_node.MutableOutputDefs()[0]);
|
||||
if (!CheckNode(graph, add1_node, "Add", {7, 13}, pow1_node.GetExecutionProviderType(), true) ||
|
||||
if (!CheckNode(graph, add1_node, "Add", {7, 13, 14}, pow1_node.GetExecutionProviderType(), true) ||
|
||||
add1_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != pow_input_arg->Name()) {
|
||||
return matchResult;
|
||||
}
|
||||
|
|
@ -162,7 +162,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
|
|||
|
||||
Node& mul2_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index());
|
||||
input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *add1_node.MutableOutputDefs()[0]);
|
||||
if (!CheckNode(graph, mul2_node, "Mul", {7, 13}, pow1_node.GetExecutionProviderType(), true) ||
|
||||
if (!CheckNode(graph, mul2_node, "Mul", {7, 13, 14}, pow1_node.GetExecutionProviderType(), true) ||
|
||||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul2_node.InputDefs()[(input_index + 1) % 2]),
|
||||
0.7978845834732056f, true)) {
|
||||
return matchResult;
|
||||
|
|
@ -220,7 +220,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
}
|
||||
|
||||
Node& add2_node = *graph.GetNode(tanh_node.OutputNodesBegin()->Index());
|
||||
if (!CheckNode(graph, add2_node, "Add", {7, 13}, node.GetExecutionProviderType(), true)) {
|
||||
if (!CheckNode(graph, add2_node, "Add", {7, 13, 14}, node.GetExecutionProviderType(), true)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -231,7 +231,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
|
||||
Node& mul5_node = *graph.GetNode(add2_node.OutputNodesBegin()->Index());
|
||||
// This is the output of the Gelu subgraph, we don't need check it has single edge.
|
||||
if (!CheckNode(graph, mul5_node, "Mul", {7, 13}, node.GetExecutionProviderType(), false)) {
|
||||
if (!CheckNode(graph, mul5_node, "Mul", {7, 13, 14}, node.GetExecutionProviderType(), false)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -263,7 +263,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
}
|
||||
|
||||
Node& mul6_node = const_cast<Node&>(*p_mul5_input_node);
|
||||
if (!CheckNode(graph, mul6_node, "Mul", {7, 13}, node.GetExecutionProviderType(), false)) {
|
||||
if (!CheckNode(graph, mul6_node, "Mul", {7, 13, 14}, node.GetExecutionProviderType(), false)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
|
|||
Node& div = *p_div;
|
||||
ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7, 13, 14}) ||
|
||||
!graph_utils::IsSupportedProvider(div, GetCompatibleExecutionProviders()) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, div, 1) ||
|
||||
!IsSupportedDataType(div)) {
|
||||
|
|
@ -79,7 +79,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
|
|||
}
|
||||
|
||||
Node& add_node = *graph.GetNode(erf_node.OutputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13, 14}) ||
|
||||
add_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, add_node, 1) ||
|
||||
!IsSupportedDataType(add_node)) {
|
||||
|
|
@ -95,7 +95,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
|
|||
|
||||
Node& mul_node = *graph.GetNode(add_node.OutputNodesBegin()->Index());
|
||||
// note: output edges count doesn't matter as the new Gelu node will produce outputs with the same names
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
|
||||
mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
!IsSupportedDataType(mul_node)) {
|
||||
continue;
|
||||
|
|
@ -106,7 +106,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
|
|||
if (p_mul2_node != nullptr) {
|
||||
// Match subgraph pattern 1
|
||||
Node& mul2_node = *graph.GetNode(p_mul2_node->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13, 14}) ||
|
||||
mul2_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, mul2_node, 1) ||
|
||||
!IsSupportedDataType(mul2_node)) {
|
||||
|
|
@ -139,7 +139,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
|
|||
continue;
|
||||
|
||||
Node& mul2_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13, 14}) ||
|
||||
mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
!IsSupportedDataType(mul_node)) {
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ bool IsFusableActivation(const Node& node) {
|
|||
return IsSupportedOptypeVersionAndDomain(node, "Elu", {6}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "HardSigmoid", {6}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", {6}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "Selu", {6}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "Softplus", {1}, kOnnxDomain) ||
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
}
|
||||
|
||||
Node& sub_node = *graph.GetNode(p_sub_node->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node, "Sub", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node, "Sub", {7, 13, 14}) ||
|
||||
sub_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, sub_node, subCnt == 1 ? 2u : 1u) ||
|
||||
!IsSupportedDataType(sub_node)) {
|
||||
|
|
@ -124,7 +124,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
// Find the sub_dup node if exist
|
||||
if (p_sub_node_dup != nullptr) {
|
||||
Node& sub_node_dup = *graph.GetNode(p_sub_node_dup->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node_dup, "Sub", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node_dup, "Sub", {7, 13, 14}) ||
|
||||
sub_node_dup.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, sub_node, 1) ||
|
||||
!IsSupportedDataType(sub_node_dup)) {
|
||||
|
|
@ -141,7 +141,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
continue;
|
||||
}
|
||||
Node& div_node = *graph.GetNode(p_div->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13, 14}) ||
|
||||
div_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, div_node, 1) ||
|
||||
!IsSupportedDataType(div_node)) {
|
||||
|
|
@ -167,7 +167,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
|
||||
// Traceback the sqrt node to find add --> sqrt
|
||||
Node& add2_node = *graph.GetNode(sqrt_node.InputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7, 13, 14}) ||
|
||||
add2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, add2_node, 1) ||
|
||||
!IsSupportedDataType(add2_node)) {
|
||||
|
|
@ -224,7 +224,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
|
||||
// div --> mul
|
||||
Node& mul_node = *graph.GetNode(div_node.OutputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
|
||||
mul_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, mul_node, 1) ||
|
||||
!IsSupportedDataType(mul_node)) {
|
||||
|
|
@ -235,7 +235,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
// mul --> add
|
||||
// Need not check output edges of last node since they will be moved to fused node.
|
||||
Node& last_add_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(last_add_node, "Add", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(last_add_node, "Add", {7, 13, 14}) ||
|
||||
last_add_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!IsSupportedDataType(last_add_node)) {
|
||||
continue;
|
||||
|
|
@ -404,7 +404,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
continue;
|
||||
}
|
||||
Node& add_node = *graph.GetNode(p_add->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13, 14}) ||
|
||||
add_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, add_node, 1) ||
|
||||
!IsSupportedDataType(add_node)) {
|
||||
|
|
@ -432,7 +432,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
continue;
|
||||
}
|
||||
Node& div_node = *graph.GetNode(p_div->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13, 14}) ||
|
||||
div_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, div_node, 1) ||
|
||||
!IsSupportedDataType(div_node)) {
|
||||
|
|
@ -488,7 +488,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
}
|
||||
|
||||
Node& mul_node = *next_node;
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
|
||||
mul_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() ||
|
||||
!IsSupportedDataType(mul_node)) {
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
}
|
||||
|
||||
const Node& next_node = (*next_node_itr);
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13, 14}) ||
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g
|
|||
|
||||
ORT_RETURN_IF_ERROR(Recurse(mul_node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
|
||||
!graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders())) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ optional<std::pair<float, int>> GetScaleFromNode(
|
|||
return excluded_initializer_names.find(node_arg.Name()) != excluded_initializer_names.end();
|
||||
};
|
||||
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Div", {7, 13})) {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Div", {7, 13, 14})) {
|
||||
// (x / scale_reciprocal)
|
||||
const auto div_inputs = scale_node.InputDefs();
|
||||
ORT_ENFORCE(div_inputs.size() == 2);
|
||||
|
|
@ -79,7 +79,7 @@ optional<std::pair<float, int>> GetScaleFromNode(
|
|||
return {std::make_pair(1.0f / divisor.value(), scale_reciprocal_arg_index)};
|
||||
}
|
||||
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Mul", {7, 13})) {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Mul", {7, 13, 14})) {
|
||||
// (x * scale) or (scale * x)
|
||||
const auto mul_inputs = scale_node.InputDefs();
|
||||
ORT_ENFORCE(mul_inputs.size() == 2);
|
||||
|
|
|
|||
|
|
@ -1170,18 +1170,18 @@ void NchwcTransformerImpl::Transform(Node& node) {
|
|||
// node may already have all inputs converted to NCHWc format and is not
|
||||
// needed for correct operation. This avoids doing extra string checks for
|
||||
// nodes unrelated to this transformer.
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) ||
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13, 14}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sum", {6, 8, 13})) {
|
||||
TransformBinary(node, true);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13})) {
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14})) {
|
||||
TransformBinary(node, false);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11, 13})) {
|
||||
TransformConcat(node);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13}) ||
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13})) {
|
||||
TransformActivation(node);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9})) {
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9, 14})) {
|
||||
TransformBatchNormalization(node);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13})) {
|
||||
TransformTransposeToNhwc(node);
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ namespace onnxruntime {
|
|||
|
||||
// Returns true when `node` passes any of the IsSupportedOptypeVersionAndDomain
// checks listed below: MaxPool {12}, Reshape {5, 13} / {5, 13, 14}, or Transpose {1, 13}.
// NOTE(review): this is a diff rendering -- the `{5, 13}` Reshape line is presumably the
// removed line and `{5, 13, 14}` its replacement (the enclosing hunk header reports 7 lines
// on each side); confirm against the actual file before relying on both checks being present.
static bool CanNodePropagate(const Node& node) {
|
||||
return graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {12}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {5, 13}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {5, 13, 14}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ using namespace onnxruntime::common;
|
|||
namespace onnxruntime {
|
||||
|
||||
bool ReluQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& /*logger*/) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, node, 1)) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
|
|||
}
|
||||
|
||||
bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13})) {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14})) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,11 +23,16 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c
|
|||
Node& reshape = *p_reshape;
|
||||
ORT_RETURN_IF_ERROR(Recurse(reshape, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reshape, "Reshape", {5, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reshape, "Reshape", {5, 13, 14}) ||
|
||||
!graph_utils::IsSupportedProvider(reshape, GetCompatibleExecutionProviders())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto* attr_proto = graph_utils::GetNodeAttribute(reshape, "allowzero");
|
||||
if ((nullptr != attr_proto) && attr_proto->has_i() && attr_proto->i() != 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ReshapeFusion::Fuse_Subgraph(reshape, graph, logger)) {
|
||||
fused_count++;
|
||||
LOGS(logger, INFO) << "Fused reshape node: " << reshape.OutputDefs()[0]->Name();
|
||||
|
|
@ -255,11 +260,11 @@ bool ReshapeFusion::Is_One_Element_Output_Subgraph(Graph& graph, const NodeArg&
|
|||
|
||||
std::vector<graph_utils::EdgeEndToMatch> div_path{
|
||||
{0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Div", {7, 13}, kOnnxDomain}};
|
||||
{0, 0, "Div", {7, 13, 14}, kOnnxDomain}};
|
||||
|
||||
std::vector<graph_utils::EdgeEndToMatch> mul_path{
|
||||
{0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Mul", {7, 13}, kOnnxDomain}};
|
||||
{0, 0, "Mul", {7, 13, 14}, kOnnxDomain}};
|
||||
|
||||
std::vector<graph_utils::EdgeEndToMatch> unsqueeze_path{
|
||||
{0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain}};
|
||||
|
|
|
|||
Loading…
Reference in a new issue