update optimizers for opset14 (#7722)

* update optimizers for opset14

* plus 1 more

* fix reshape fusion
This commit is contained in:
Ashwini Khade 2021-05-18 11:58:14 -07:00 committed by GitHub
parent 26a472c948
commit 7834ca983c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 59 additions and 54 deletions

View file

@ -1280,7 +1280,7 @@ TODO: replace Gemm_Subgraph by MatMul + Add
bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::map<std::string, NodeArg*>& mask_int32_map, bool use_shared_node, const logging::Logger& logger) {
DEBUG_LOG("Start FuseGptAttention");
const Node* parent_node = graph_utils::GetInputNode(layer_norm, 0);
if (nullptr == parent_node || !graph_utils::IsSupportedOptypeVersionAndDomain(*parent_node, "Add", {7, 13}, kOnnxDomain)) {
if (nullptr == parent_node || !graph_utils::IsSupportedOptypeVersionAndDomain(*parent_node, "Add", {7, 13, 14}, kOnnxDomain)) {
return false;
}

View file

@ -29,7 +29,7 @@ void FuseResidualAddIfAny(Graph& graph, const Node& dropout_node,
for (auto last_node_itr = dropout_node.OutputNodesBegin(); last_node_itr != dropout_node.OutputNodesEnd(); ++last_node_itr) {
const Node& last_node = (*last_node_itr);
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Add", {7, 13}) &&
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Add", {7, 13, 14}) &&
last_node.GetExecutionProviderType() == dropout_node.GetExecutionProviderType()) {
const TensorShapeProto* input1_shape = last_node.InputDefs()[0]->Shape();
const TensorShapeProto* input2_shape = last_node.InputDefs()[1]->Shape();
@ -90,7 +90,7 @@ Status BiasDropoutFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve
std::vector<std::reference_wrapper<Node>> nodes_to_fuse;
// matching for bias Add node
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13, 14}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
node.GetOutputEdgesCount() != 1) {
continue;

View file

@ -24,7 +24,7 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13, 14}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
!optimizer_utils::CheckOutputEdges(graph, node, 1)) {
continue;

View file

@ -43,7 +43,7 @@ bool TryBiasSoftmaxSubgraphMatch(Graph& graph, Node& start, Node*& add, Node*& s
add = softmax = nullptr;
// check node is add and has single output
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13, 14}) ||
!graph_utils::IsSupportedProvider(node, {kCudaExecutionProvider, kRocmExecutionProvider}) ||
!optimizer_utils::CheckOutputEdges(graph, node, 1)) {
return false;

View file

@ -105,7 +105,7 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
continue;
}
if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13})) {
if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13, 14})) {
Node& conv_node = *node;
Node& act_node = *graph.GetNode(next_node.Index());
auto node_name = graph.GenerateNodeName(conv_node.Name() + "_" + act_node.Name());
@ -120,12 +120,12 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
fused_conv.AddAttribute("activation", "Relu");
graph_utils::FinalizeNodeFusion(graph, {conv_node, act_node}, fused_conv);
modified = true;
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {6, 7, 13})) {
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {6, 7, 13, 14})) {
const auto& last_node = *(next_node.OutputNodesBegin());
if (last_node.GetExecutionProviderType() != node->GetExecutionProviderType()) {
continue;
}
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Relu", {6, 13}) &&
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Relu", {6, 13, 14}) &&
next_node.GetOutputEdgesCount() == 1) {
Node& conv_node = *node;
Node& add_node = *graph.GetNode(next_node.Index());
@ -158,7 +158,7 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
// Test if this is an activation that can be fused and also extract the
// activation's parameters.
std::vector<float> activation_params;
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13}) &&
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13, 14}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Sigmoid", {6, 13}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Tanh", {6, 13})) {
if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "LeakyRelu", {6})) {

View file

@ -111,7 +111,7 @@ bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const
}
const auto& next_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13, 14}) ||
next_node.GetInputEdgesCount() != 1 ||
// Make sure the two nodes do not span execution providers.
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {

View file

@ -151,7 +151,7 @@ bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const
}
const auto& next_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "BatchNormalization", {7, 9}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "BatchNormalization", {7, 9, 14}) ||
next_node.GetInputEdgesCount() != 1 ||
// Make sure the two nodes do not span execution providers.
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {

View file

@ -119,7 +119,7 @@ bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const
}
const auto& next_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13, 14}) ||
next_node.GetInputEdgesCount() != 1 ||
// Make sure the two nodes do not span execution providers.
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {

View file

@ -17,13 +17,13 @@ when the first input to Div is 1.
1 / x1 * x2 -> x2 / x1
*/
bool DivMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Div", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Div", {7, 13, 14}) ||
node.GetOutputEdgesCount() != 1) {
return false;
}
const auto& next_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13, 14}) ||
// Make sure the two nodes do not span execution providers.
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
return false;

View file

@ -37,7 +37,7 @@ static bool CheckNode(Graph& graph, const Node& node, const std::string& op_name
MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const {
MatchResult matchResult{false, nullptr, nullptr};
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul1_node, "Mul", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul1_node, "Mul", {7, 13, 14}) ||
!graph_utils::IsSupportedProvider(mul1_node, GetCompatibleExecutionProviders()) ||
mul1_node.GetOutputEdgesCount() != 1 ||
!IsSupportedDataType(mul1_node)) {
@ -60,7 +60,7 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
Node& mul2_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index());
input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *mul1_node.MutableOutputDefs()[0]);
if (!CheckNode(graph, mul2_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true) ||
if (!CheckNode(graph, mul2_node, "Mul", {7, 13, 14}, mul1_node.GetExecutionProviderType(), true) ||
mul2_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != gelu_without_bias_input_arg->Name()) {
return matchResult;
}
@ -68,14 +68,14 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
Node& add1_node = *graph.GetNode(mul2_node.OutputNodesBegin()->Index());
input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul2_node.MutableOutputDefs()[0]);
if (!CheckNode(graph, add1_node, "Add", {7, 13}, mul1_node.GetExecutionProviderType(), true) ||
if (!CheckNode(graph, add1_node, "Add", {7, 13, 14}, mul1_node.GetExecutionProviderType(), true) ||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(add1_node.InputDefs()[(input_index + 1) % 2]), 1.0f, true)) {
return matchResult;
}
nodes_to_fuse.push_back(add1_node);
Node& mul3_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index());
if (!CheckNode(graph, mul3_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true)) {
if (!CheckNode(graph, mul3_node, "Mul", {7, 13, 14}, mul1_node.GetExecutionProviderType(), true)) {
return matchResult;
}
nodes_to_fuse.push_back(mul3_node);
@ -84,7 +84,7 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
const Node* p_mul3_input_node = graph_utils::GetInputNode(mul3_node, (input_index + 1) % 2);
if (p_mul3_input_node == nullptr) return matchResult;
Node& mul4_node = const_cast<Node&>(*p_mul3_input_node);
if (!CheckNode(graph, mul4_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true)) {
if (!CheckNode(graph, mul4_node, "Mul", {7, 13, 14}, mul1_node.GetExecutionProviderType(), true)) {
return matchResult;
}
@ -126,7 +126,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
Node& mul1_node = *graph.GetNode(pow1_node.OutputNodesBegin()->Index());
auto input_index = optimizer_utils::IndexOfNodeInput(mul1_node, *pow1_node.MutableOutputDefs()[0]);
if (!CheckNode(graph, mul1_node, "Mul", {7, 13}, pow1_node.GetExecutionProviderType(), true) ||
if (!CheckNode(graph, mul1_node, "Mul", {7, 13, 14}, pow1_node.GetExecutionProviderType(), true) ||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul1_node.InputDefs()[(input_index + 1) % 2]),
0.044714998453855515f, true)) {
return matchResult;
@ -135,7 +135,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
Node& add1_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index());
input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul1_node.MutableOutputDefs()[0]);
if (!CheckNode(graph, add1_node, "Add", {7, 13}, pow1_node.GetExecutionProviderType(), true) ||
if (!CheckNode(graph, add1_node, "Add", {7, 13, 14}, pow1_node.GetExecutionProviderType(), true) ||
add1_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != pow_input_arg->Name()) {
return matchResult;
}
@ -162,7 +162,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
Node& mul2_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index());
input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *add1_node.MutableOutputDefs()[0]);
if (!CheckNode(graph, mul2_node, "Mul", {7, 13}, pow1_node.GetExecutionProviderType(), true) ||
if (!CheckNode(graph, mul2_node, "Mul", {7, 13, 14}, pow1_node.GetExecutionProviderType(), true) ||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul2_node.InputDefs()[(input_index + 1) % 2]),
0.7978845834732056f, true)) {
return matchResult;
@ -220,7 +220,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
Node& add2_node = *graph.GetNode(tanh_node.OutputNodesBegin()->Index());
if (!CheckNode(graph, add2_node, "Add", {7, 13}, node.GetExecutionProviderType(), true)) {
if (!CheckNode(graph, add2_node, "Add", {7, 13, 14}, node.GetExecutionProviderType(), true)) {
continue;
}
@ -231,7 +231,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
Node& mul5_node = *graph.GetNode(add2_node.OutputNodesBegin()->Index());
// This is the output of the Gelu subgraph, we don't need check it has single edge.
if (!CheckNode(graph, mul5_node, "Mul", {7, 13}, node.GetExecutionProviderType(), false)) {
if (!CheckNode(graph, mul5_node, "Mul", {7, 13, 14}, node.GetExecutionProviderType(), false)) {
continue;
}
@ -263,7 +263,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
Node& mul6_node = const_cast<Node&>(*p_mul5_input_node);
if (!CheckNode(graph, mul6_node, "Mul", {7, 13}, node.GetExecutionProviderType(), false)) {
if (!CheckNode(graph, mul6_node, "Mul", {7, 13, 14}, node.GetExecutionProviderType(), false)) {
continue;
}

View file

@ -55,7 +55,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
Node& div = *p_div;
ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7, 13, 14}) ||
!graph_utils::IsSupportedProvider(div, GetCompatibleExecutionProviders()) ||
!optimizer_utils::CheckOutputEdges(graph, div, 1) ||
!IsSupportedDataType(div)) {
@ -79,7 +79,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
}
Node& add_node = *graph.GetNode(erf_node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13, 14}) ||
add_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, add_node, 1) ||
!IsSupportedDataType(add_node)) {
@ -95,7 +95,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
Node& mul_node = *graph.GetNode(add_node.OutputNodesBegin()->Index());
// note: output edges count doesn't matter as the new Gelu node will produce outputs with the same names
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
!IsSupportedDataType(mul_node)) {
continue;
@ -106,7 +106,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
if (p_mul2_node != nullptr) {
// Match subgraph pattern 1
Node& mul2_node = *graph.GetNode(p_mul2_node->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13, 14}) ||
mul2_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, mul2_node, 1) ||
!IsSupportedDataType(mul2_node)) {
@ -139,7 +139,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
continue;
Node& mul2_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13, 14}) ||
mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
!IsSupportedDataType(mul_node)) {
continue;

View file

@ -24,7 +24,7 @@ bool IsFusableActivation(const Node& node) {
return IsSupportedOptypeVersionAndDomain(node, "Elu", {6}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "HardSigmoid", {6}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", {6}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Selu", {6}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Softplus", {1}, kOnnxDomain) ||

View file

@ -109,7 +109,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
Node& sub_node = *graph.GetNode(p_sub_node->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node, "Sub", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node, "Sub", {7, 13, 14}) ||
sub_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, sub_node, subCnt == 1 ? 2u : 1u) ||
!IsSupportedDataType(sub_node)) {
@ -124,7 +124,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// Find the sub_dup node if exist
if (p_sub_node_dup != nullptr) {
Node& sub_node_dup = *graph.GetNode(p_sub_node_dup->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node_dup, "Sub", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node_dup, "Sub", {7, 13, 14}) ||
sub_node_dup.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, sub_node, 1) ||
!IsSupportedDataType(sub_node_dup)) {
@ -141,7 +141,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
Node& div_node = *graph.GetNode(p_div->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13, 14}) ||
div_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, div_node, 1) ||
!IsSupportedDataType(div_node)) {
@ -167,7 +167,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// Traceback the sqrt node to find add --> sqrt
Node& add2_node = *graph.GetNode(sqrt_node.InputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7, 13, 14}) ||
add2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, add2_node, 1) ||
!IsSupportedDataType(add2_node)) {
@ -224,7 +224,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// div --> mul
Node& mul_node = *graph.GetNode(div_node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
mul_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, mul_node, 1) ||
!IsSupportedDataType(mul_node)) {
@ -235,7 +235,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// mul --> add
// Need not check output edges of last node since they will be moved to fused node.
Node& last_add_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(last_add_node, "Add", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(last_add_node, "Add", {7, 13, 14}) ||
last_add_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!IsSupportedDataType(last_add_node)) {
continue;
@ -404,7 +404,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
continue;
}
Node& add_node = *graph.GetNode(p_add->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13, 14}) ||
add_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, add_node, 1) ||
!IsSupportedDataType(add_node)) {
@ -432,7 +432,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
continue;
}
Node& div_node = *graph.GetNode(p_div->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13, 14}) ||
div_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, div_node, 1) ||
!IsSupportedDataType(div_node)) {
@ -488,7 +488,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
}
Node& mul_node = *next_node;
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
mul_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() ||
!IsSupportedDataType(mul_node)) {
continue;

View file

@ -40,7 +40,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
const Node& next_node = (*next_node_itr);
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13, 14}) ||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
continue;
}

View file

@ -64,7 +64,7 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g
ORT_RETURN_IF_ERROR(Recurse(mul_node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
!graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders())) {
continue;
}

View file

@ -62,7 +62,7 @@ optional<std::pair<float, int>> GetScaleFromNode(
return excluded_initializer_names.find(node_arg.Name()) != excluded_initializer_names.end();
};
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Div", {7, 13})) {
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Div", {7, 13, 14})) {
// (x / scale_reciprocal)
const auto div_inputs = scale_node.InputDefs();
ORT_ENFORCE(div_inputs.size() == 2);
@ -79,7 +79,7 @@ optional<std::pair<float, int>> GetScaleFromNode(
return {std::make_pair(1.0f / divisor.value(), scale_reciprocal_arg_index)};
}
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Mul", {7, 13})) {
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Mul", {7, 13, 14})) {
// (x * scale) or (scale * x)
const auto mul_inputs = scale_node.InputDefs();
ORT_ENFORCE(mul_inputs.size() == 2);

View file

@ -1170,18 +1170,18 @@ void NchwcTransformerImpl::Transform(Node& node) {
// node may already have all inputs converted to NCHWc format and is not
// needed for correct operation. This avoids doing extra string checks for
// nodes unrelated to this transformer.
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) ||
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sum", {6, 8, 13})) {
TransformBinary(node, true);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13})) {
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14})) {
TransformBinary(node, false);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11, 13})) {
TransformConcat(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13}) ||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13})) {
TransformActivation(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9})) {
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9, 14})) {
TransformBatchNormalization(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13})) {
TransformTransposeToNhwc(node);

View file

@ -14,7 +14,7 @@ namespace onnxruntime {
static bool CanNodePropagate(const Node& node) {
return graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {12}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {5, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {5, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13});
}

View file

@ -11,7 +11,7 @@ using namespace onnxruntime::common;
namespace onnxruntime {
bool ReluQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& /*logger*/) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) ||
!optimizer_utils::CheckOutputEdges(graph, node, 1)) {
return false;
}

View file

@ -115,7 +115,7 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
}
bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13})) {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14})) {
return false;
}

View file

@ -23,11 +23,16 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c
Node& reshape = *p_reshape;
ORT_RETURN_IF_ERROR(Recurse(reshape, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reshape, "Reshape", {5, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reshape, "Reshape", {5, 13, 14}) ||
!graph_utils::IsSupportedProvider(reshape, GetCompatibleExecutionProviders())) {
continue;
}
const auto* attr_proto = graph_utils::GetNodeAttribute(reshape, "allowzero");
if ((nullptr != attr_proto) && attr_proto->has_i() && attr_proto->i() != 0) {
continue;
}
if (ReshapeFusion::Fuse_Subgraph(reshape, graph, logger)) {
fused_count++;
LOGS(logger, INFO) << "Fused reshape node: " << reshape.OutputDefs()[0]->Name();
@ -255,11 +260,11 @@ bool ReshapeFusion::Is_One_Element_Output_Subgraph(Graph& graph, const NodeArg&
std::vector<graph_utils::EdgeEndToMatch> div_path{
{0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Div", {7, 13}, kOnnxDomain}};
{0, 0, "Div", {7, 13, 14}, kOnnxDomain}};
std::vector<graph_utils::EdgeEndToMatch> mul_path{
{0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Mul", {7, 13}, kOnnxDomain}};
{0, 0, "Mul", {7, 13, 14}, kOnnxDomain}};
std::vector<graph_utils::EdgeEndToMatch> unsqueeze_path{
{0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain}};