support opset13 on transformers. (#4837)

Co-authored-by: Vincent Wang <weicwang@microsoft.com>
This commit is contained in:
Vincent Wang 2020-08-19 11:13:37 +08:00 committed by GitHub
parent 61a5502af0
commit 5eaac31faa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 137 additions and 126 deletions

View file

@ -139,6 +139,10 @@ struct BFloat16 {
}
return result;
}
// Implicit conversion to float; simply forwards to ToFloat().
operator float() const { return ToFloat(); }
};
inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) {

View file

@ -24,7 +24,7 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
!optimizer_utils::CheckOutputEdges(graph, node, 1)) {
continue;

View file

@ -103,12 +103,12 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
// Test if this is an activation that can be fused and also extract the
// activation's parameters.
std::vector<float> activation_params;
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Sigmoid", {6}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Tanh", {6})) {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Sigmoid", {6, 13}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Tanh", {6, 13})) {
if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "LeakyRelu", {6})) {
activation_params.push_back(graph_utils::GetNodeAttribute(next_node, "alpha")->f());
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12})) {
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12, 13})) {
float min, max;
if (GetClipConstantMinMax(graph, next_node, min, max)) {
activation_params.push_back(min);

View file

@ -110,7 +110,7 @@ bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const
}
const auto& next_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13}) ||
next_node.GetInputEdgesCount() != 1 ||
// Make sure the two nodes do not span execution providers.
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {

View file

@ -121,7 +121,7 @@ bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const
}
const auto& next_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13}) ||
next_node.GetInputEdgesCount() != 1 ||
// Make sure the two nodes do not span execution providers.
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {

View file

@ -21,7 +21,7 @@ Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule
bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
// We currently support elimination for Dropout operator v1, v6, v7, v10, v12 and v13.
// REVIEW(mzs): v10 implementation does not exist.
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {1, 6, 7, 10, 12})) {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {1, 6, 7, 10, 12, 13})) {
return false;
}

View file

@ -86,19 +86,19 @@ Status DynamicQuantizeMatMulFusion::ApplyImpl(Graph& graph, bool& modified, int
ORT_RETURN_IF_ERROR(Recurse(mul_node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) ||
!graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders())) {
continue;
}
// Left Parents path
std::vector<graph_utils::EdgeEndToMatch> left_parent_path{
{0, 0, "Cast", {6, 9}, kOnnxDomain},
{0, 0, "Cast", {6, 9, 13}, kOnnxDomain},
{0, 0, "MatMulInteger", {10}, kOnnxDomain},
{0, 0, "DynamicQuantizeLinear", {11}, kOnnxDomain}};
std::vector<graph_utils::EdgeEndToMatch> right_parent_path{
{0, 1, "Mul", {7}, kOnnxDomain},
{0, 1, "Mul", {7, 13}, kOnnxDomain},
{1, 0, "DynamicQuantizeLinear", {11}, kOnnxDomain}};
std::vector<std::reference_wrapper<Node>> left_nodes;

View file

@ -501,20 +501,20 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
}
// Find ReduceSum --> Attention
std::vector<const Node::EdgeEnd*> edges;
if (!graph_utils::FindPath(attention_node, true, {{0, 3, "ReduceSum", {1, 11}, kOnnxDomain}}, edges, logger)) {
if (!graph_utils::FindPath(attention_node, true, {{0, 3, "ReduceSum", {1, 11, 13}, kOnnxDomain}}, edges, logger)) {
continue;
}
Node& reduce_sum_node = *graph.GetNode(edges[0]->GetNode().Index());
// Find Add --> LayerNormalization
if (!graph_utils::FindPath(layer_norm_node, true, {{0, 0, "Add", {7}, kOnnxDomain}}, edges, logger)) {
if (!graph_utils::FindPath(layer_norm_node, true, {{0, 0, "Add", {7, 13}, kOnnxDomain}}, edges, logger)) {
continue;
}
Node& layer_norm_add_node = *graph.GetNode(edges[0]->GetNode().Index());
// Trace back to find the Gather for segment embedding.
std::vector<graph_utils::EdgeEndToMatch> segment_embedding_path{
{0, 1, "Gather", {1, 11}, kOnnxDomain}};
{0, 1, "Gather", {1, 11, 13}, kOnnxDomain}};
if (!graph_utils::FindPath(layer_norm_add_node, true, segment_embedding_path, edges, logger)) {
continue;
}
@ -534,8 +534,8 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
// Trace back to find Gather --> Add --> LayerNormalization
std::vector<graph_utils::EdgeEndToMatch> word_embedding_path{
{0, 0, "Add", {7}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain}};
{0, 0, "Add", {7, 13}, kOnnxDomain},
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain}};
if (!graph_utils::FindPath(layer_norm_add_node, true, word_embedding_path, edges, logger)) {
continue;
}

View file

@ -13,7 +13,7 @@ using namespace onnxruntime::common;
namespace onnxruntime {
// FastGelu supports limited data types.
static std::vector<std::string> gpu_supported_data_types{"tensor(float16)", "tensor(float)"};
static std::vector<std::string> gpu_supported_data_types{"tensor(float16)", "tensor(float)", "tensor(bfloat16)"};
static std::vector<std::string> cpu_supported_data_types{"tensor(float)"};
static bool IsSupportedDataType(const Node& node) {
@ -24,9 +24,10 @@ static bool IsSupportedDataType(const Node& node) {
}
}
static bool CheckNode(Graph& graph, const Node& node, const std::string& op_name, int32_t opset_version, ProviderType provider,
bool require_single_output) {
return graph_utils::IsSupportedOptypeVersionAndDomain(node, op_name, {opset_version}) &&
static bool CheckNode(Graph& graph, const Node& node, const std::string& op_name,
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion>& opset_versions,
ProviderType provider, bool require_single_output) {
return graph_utils::IsSupportedOptypeVersionAndDomain(node, op_name, opset_versions) &&
node.GetExecutionProviderType() == provider &&
IsSupportedDataType(node) &&
(!require_single_output || node.GetOutputEdgesCount() == 1) &&
@ -36,7 +37,7 @@ static bool CheckNode(Graph& graph, const Node& node, const std::string& op_name
MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const {
MatchResult matchResult{false, nullptr, nullptr};
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul1_node, "Mul", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul1_node, "Mul", {7, 13}) ||
!graph_utils::IsSupportedProvider(mul1_node, GetCompatibleExecutionProviders()) ||
mul1_node.GetOutputEdgesCount() != 1 ||
!IsSupportedDataType(mul1_node)) {
@ -59,7 +60,7 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
Node& mul2_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index());
input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *mul1_node.MutableOutputDefs()[0]);
if (!CheckNode(graph, mul2_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true) ||
if (!CheckNode(graph, mul2_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true) ||
mul2_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != gelu_without_bias_input_arg->Name()) {
return matchResult;
}
@ -67,14 +68,14 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
Node& add1_node = *graph.GetNode(mul2_node.OutputNodesBegin()->Index());
input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul2_node.MutableOutputDefs()[0]);
if (!CheckNode(graph, add1_node, "Add", 7, mul1_node.GetExecutionProviderType(), true) ||
if (!CheckNode(graph, add1_node, "Add", {7, 13}, mul1_node.GetExecutionProviderType(), true) ||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(add1_node.InputDefs()[(input_index + 1) % 2]), 1.0f, true)) {
return matchResult;
}
nodes_to_fuse.push_back(add1_node);
Node& mul3_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index());
if (!CheckNode(graph, mul3_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true)) {
if (!CheckNode(graph, mul3_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true)) {
return matchResult;
}
nodes_to_fuse.push_back(mul3_node);
@ -83,7 +84,7 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
const Node* p_mul3_input_node = graph_utils::GetInputNode(mul3_node, (input_index + 1) % 2);
if (p_mul3_input_node == nullptr) return matchResult;
Node& mul4_node = const_cast<Node&>(*p_mul3_input_node);
if (!CheckNode(graph, mul4_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true)) {
if (!CheckNode(graph, mul4_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true)) {
return matchResult;
}
@ -109,7 +110,7 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const {
MatchResult matchResult{false, nullptr, nullptr};
if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow1_node, "Pow", {7, 12}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow1_node, "Pow", {7, 12, 13}) ||
!graph_utils::IsSupportedProvider(pow1_node, GetCompatibleExecutionProviders()) ||
pow1_node.GetOutputEdgesCount() != 1 ||
!IsSupportedDataType(pow1_node)) {
@ -125,7 +126,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
Node& mul1_node = *graph.GetNode(pow1_node.OutputNodesBegin()->Index());
auto input_index = optimizer_utils::IndexOfNodeInput(mul1_node, *pow1_node.MutableOutputDefs()[0]);
if (!CheckNode(graph, mul1_node, "Mul", 7, pow1_node.GetExecutionProviderType(), true) ||
if (!CheckNode(graph, mul1_node, "Mul", {7, 13}, pow1_node.GetExecutionProviderType(), true) ||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul1_node.InputDefs()[(input_index + 1) % 2]),
0.044714998453855515f, true)) {
return matchResult;
@ -134,7 +135,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
Node& add1_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index());
input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul1_node.MutableOutputDefs()[0]);
if (!CheckNode(graph, add1_node, "Add", 7, pow1_node.GetExecutionProviderType(), true) ||
if (!CheckNode(graph, add1_node, "Add", {7, 13}, pow1_node.GetExecutionProviderType(), true) ||
add1_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != pow_input_arg->Name()) {
return matchResult;
}
@ -142,7 +143,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
Node& mul2_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index());
input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *add1_node.MutableOutputDefs()[0]);
if (!CheckNode(graph, mul2_node, "Mul", 7, pow1_node.GetExecutionProviderType(), true) ||
if (!CheckNode(graph, mul2_node, "Mul", {7, 13}, pow1_node.GetExecutionProviderType(), true) ||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul2_node.InputDefs()[(input_index + 1) % 2]),
0.7978845834732056f, true)) {
return matchResult;
@ -177,12 +178,12 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
};
Node& tanh_node = *graph.GetNode(matchRet.tanh_input_node->OutputNodesBegin()->Index());
if (!CheckNode(graph, tanh_node, "Tanh", 6, node.GetExecutionProviderType(), true)) {
if (!CheckNode(graph, tanh_node, "Tanh", {6, 13}, node.GetExecutionProviderType(), true)) {
continue;
}
Node& add2_node = *graph.GetNode(tanh_node.OutputNodesBegin()->Index());
if (!CheckNode(graph, add2_node, "Add", 7, node.GetExecutionProviderType(), true)) {
if (!CheckNode(graph, add2_node, "Add", {7, 13}, node.GetExecutionProviderType(), true)) {
continue;
}
@ -193,7 +194,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
Node& mul5_node = *graph.GetNode(add2_node.OutputNodesBegin()->Index());
// This is the output of the Gelu subgraph, we don't need check it has single edge.
if (!CheckNode(graph, mul5_node, "Mul", 7, node.GetExecutionProviderType(), false)) {
if (!CheckNode(graph, mul5_node, "Mul", {7, 13}, node.GetExecutionProviderType(), false)) {
continue;
}
@ -201,7 +202,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const Node* p_mul5_input_node = graph_utils::GetInputNode(mul5_node, (input_index + 1) % 2);
if (p_mul5_input_node == nullptr) continue;
Node& mul6_node = const_cast<Node&>(*p_mul5_input_node);
if (!CheckNode(graph, mul6_node, "Mul", 7, node.GetExecutionProviderType(), false)) {
if (!CheckNode(graph, mul6_node, "Mul", {7, 13}, node.GetExecutionProviderType(), false)) {
continue;
}

View file

@ -13,7 +13,7 @@ using namespace onnxruntime::common;
namespace onnxruntime {
// FastGelu supports limited data types.
static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(float)"};
static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(float)", "tensor(bfloat16)"};
static bool IsSupportedDataType(const Node& node) {
for (const auto& input_arg : node.InputDefs()) {
@ -51,7 +51,7 @@ static bool CheckInputShape(const Node& node, const NodeArg& input, const NodeAr
// it means that the shape of MatMul output is good for FastGelu.
const Node* parent_node = graph_utils::GetInputNode(node, 0);
if (nullptr != parent_node &&
graph_utils::IsSupportedOptypeVersionAndDomain(*parent_node, "MatMul", {1, 9}, kOnnxDomain)) {
graph_utils::IsSupportedOptypeVersionAndDomain(*parent_node, "MatMul", {1, 9, 13}, kOnnxDomain)) {
const NodeArg& input_b = *(parent_node->InputDefs()[1]);
if (optimizer_utils::ValidateShape(input_b, {-1, bias_length})) {
return true;

View file

@ -55,7 +55,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
Node& div = *p_div;
ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7, 13}) ||
!graph_utils::IsSupportedProvider(div, GetCompatibleExecutionProviders()) ||
!optimizer_utils::CheckOutputEdges(graph, div, 1) ||
!IsSupportedDataType(div)) {
@ -71,7 +71,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
}
Node& erf_node = *graph.GetNode(div.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(erf_node, "Erf", {9}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(erf_node, "Erf", {9, 13}) ||
erf_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, erf_node, 1) ||
!IsSupportedDataType(erf_node)) {
@ -79,7 +79,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
}
Node& add_node = *graph.GetNode(erf_node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13}) ||
add_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, add_node, 1) ||
!IsSupportedDataType(add_node)) {
@ -95,7 +95,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
Node& mul_node = *graph.GetNode(add_node.OutputNodesBegin()->Index());
// note: output edges count doesn't matter as the new Gelu node will produce outputs with the same names
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) ||
mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
!IsSupportedDataType(mul_node)) {
continue;
@ -106,7 +106,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
if (p_mul2_node != nullptr) {
// Match subgraph pattern 1
Node& mul2_node = *graph.GetNode(p_mul2_node->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13}) ||
mul2_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, mul2_node, 1) ||
!IsSupportedDataType(mul2_node)) {
@ -139,7 +139,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
continue;
Node& mul2_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) ||
// Pattern 2: validate the trailing Mul itself. The provider and data-type
// checks must inspect mul2_node (the node just fetched from mul_node's
// output), not mul_node, which was already validated earlier in this pass.
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13}) ||
    mul2_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
    !IsSupportedDataType(mul2_node)) {
continue;

View file

@ -24,12 +24,12 @@ bool IsFusableActivation(const Node& node) {
return IsSupportedOptypeVersionAndDomain(node, "Elu", {6}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "HardSigmoid", {6}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", {6}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Relu", {6}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Selu", {6}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Softplus", {1}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Softsign", {1}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Tanh", {6}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13}, kOnnxDomain) ||
#ifndef DISABLE_CONTRIB_OPS
IsSupportedOptypeVersionAndDomain(node, "ScaledTanh", {1}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "ParametricSoftplus", {1}, kOnnxDomain) ||
@ -51,7 +51,7 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
auto& node = *node_ptr;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gemm", {7, 9, 11}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gemm", {7, 9, 11, 13}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || node.GetOutputEdgesCount() != 1) {
continue;
}

View file

@ -69,7 +69,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
Node& reduce_mean_node = *p_reduce_mean;
ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13}) ||
!graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) ||
(reduce_mean_node.GetOutputEdgesCount() != 1 && reduce_mean_node.GetOutputEdgesCount() != 2) ||
!graph.GetNodeOutputsInGraphOutputs(reduce_mean_node).empty() ||
@ -102,7 +102,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
Node& sub_node = *graph.GetNode(p_sub_node->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node, "Sub", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node, "Sub", {7, 13}) ||
sub_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, sub_node, subCnt == 1 ? 2u : 1u) ||
!IsSupportedDataType(sub_node)) {
@ -117,7 +117,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// Find the sub_dup node if exist
if (p_sub_node_dup != nullptr) {
Node& sub_node_dup = *graph.GetNode(p_sub_node_dup->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node_dup, "Sub", {7}) ||
// Validate the duplicate Sub node on its own terms: op/version, execution
// provider, output fan-out and data type. The edge check previously inspected
// sub_node instead of sub_node_dup, leaving the duplicate's outputs unchecked.
// NOTE(review): a single output edge is assumed for sub_node_dup, mirroring
// the other interior nodes of this pattern - confirm against the fused subgraph.
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node_dup, "Sub", {7, 13}) ||
    sub_node_dup.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
    !optimizer_utils::CheckOutputEdges(graph, sub_node_dup, 1) ||
    !IsSupportedDataType(sub_node_dup)) {
@ -134,7 +134,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
Node& div_node = *graph.GetNode(p_div->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13}) ||
div_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, div_node, 1) ||
!IsSupportedDataType(div_node)) {
@ -149,7 +149,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
Node& sqrt_node = *graph.GetNode(p_sqrt->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sqrt_node, "Sqrt", {6}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sqrt_node, "Sqrt", {6, 13}) ||
sqrt_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, sqrt_node, 1) ||
!IsSupportedDataType(sqrt_node) ||
@ -160,7 +160,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// Traceback the sqrt node to find add --> sqrt
Node& add2_node = *graph.GetNode(sqrt_node.InputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7, 13}) ||
add2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, add2_node, 1) ||
!IsSupportedDataType(add2_node)) {
@ -175,7 +175,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
Node& reduce_mean2_node = *graph.GetNode(p_reduce_mean2->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11, 13}) ||
reduce_mean2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, reduce_mean2_node, 1) ||
!IsSupportedDataType(reduce_mean2_node) ||
@ -186,7 +186,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// Traceback the reduceMean node to find pow --> reduceMean
Node& pow_node = *graph.GetNode(reduce_mean2_node.InputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow_node, "Pow", {7, 12}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow_node, "Pow", {7, 12, 13}) ||
pow_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, pow_node, 1) ||
!IsSupportedDataType(pow_node)) {
@ -198,7 +198,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const Node* p_cast_node = graph_utils::FirstParentByType(pow_node, "Cast");
if (p_cast_node != nullptr) {
Node& cast_node = *graph.GetNode(pow_node.InputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9}) ||
// The optional Cast must satisfy the same constraints as the other nodes in
// the pattern. Its execution provider is compared against the anchoring
// ReduceMean node, as every sibling check does; comparing cast_node with
// itself was always false and never rejected a cross-provider Cast.
if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13}) ||
    cast_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
    !optimizer_utils::CheckOutputEdges(graph, cast_node, 1) ||
    !IsSupportedDataType(cast_node)) {
@ -216,7 +216,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// div --> mul
Node& mul_node = *graph.GetNode(div_node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) ||
mul_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, mul_node, 1) ||
!IsSupportedDataType(mul_node)) {
@ -227,7 +227,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// mul --> add
// Need not check output edges of last node since they will be moved to fused node.
Node& last_add_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(last_add_node, "Add", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(last_add_node, "Add", {7, 13}) ||
last_add_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!IsSupportedDataType(last_add_node)) {
continue;

View file

@ -24,7 +24,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9, 13}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
node.GetOutputEdgesCount() != 1) {
continue;
@ -40,7 +40,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
const Node& next_node = (*next_node_itr);
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13}) ||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
continue;
}
@ -58,7 +58,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
if ((*matmul_type) != (*add_type)) {
continue;
}
if ((*matmul_type) != "tensor(float)" && (*matmul_type) != "tensor(float16)") {
if ((*matmul_type) != "tensor(float)" && (*matmul_type) != "tensor(float16)" && (*matmul_type) != "tensor(bfloat16)") {
continue;
}

View file

@ -46,7 +46,7 @@ optional<float> GetScalarConstantInitializer(const Graph& graph, const NodeArg&
float scalar=0.f;
utils::MLTypeCallDispatcherRet<
Status, ExtractScalarAsFloatDispatchTarget,
uint32_t, uint64_t, int32_t, int64_t, MLFloat16, float, double>
uint32_t, uint64_t, int32_t, int64_t, MLFloat16, float, double, BFloat16>
dispatcher{initializer->data_type()};
ORT_THROW_IF_ERROR(dispatcher.Invoke(*initializer, scalar));
@ -62,7 +62,7 @@ optional<std::pair<float, int>> GetScaleFromNode(
return excluded_initializer_names.find(node_arg.Name()) != excluded_initializer_names.end();
};
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Div", {7})) {
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Div", {7, 13})) {
// (x / scale_reciprocal)
const auto div_inputs = scale_node.InputDefs();
ORT_ENFORCE(div_inputs.size() == 2);
@ -79,7 +79,7 @@ optional<std::pair<float, int>> GetScaleFromNode(
return {std::make_pair(1.0f / divisor.value(), scale_reciprocal_arg_index)};
}
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Mul", {7})) {
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Mul", {7, 13})) {
// (x * scale) or (scale * x)
const auto mul_inputs = scale_node.InputDefs();
ORT_ENFORCE(mul_inputs.size() == 2);
@ -170,7 +170,7 @@ std::vector<ScaleMergeInfo> GetOutputNodeMerges(
Status ProcessNode(
Graph& graph, Node& node, bool& modified,
const std::unordered_set<std::string>& excluded_initializer_names) {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9}) &&
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9, 13}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeScaleMatMul", {1}, kMSDomain)) {
return Status::OK();
}

View file

@ -96,8 +96,8 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
auto& node = *graph.GetNode(node_index);
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
if ((!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeScaleMatMul", {1}, kMSDomain)) ||
// Only MatMul (ONNX domain, opset 9 or 13) and TransposeScaleMatMul
// (kMSDomain, version 1) on a compatible execution provider are candidates.
// The version list for TransposeScaleMatMul stays {1}, matching the check in
// the MatMul scale fusion: 13 is the ONNX-domain opset being enabled here.
// NOTE(review): presumably kMSDomain ops are versioned independently of the
// ONNX opset - confirm against the contrib op schema.
if ((!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9, 13}) &&
     !graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeScaleMatMul", {1}, kMSDomain)) ||
    !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
  continue;
}

View file

@ -937,23 +937,23 @@ void NchwcTransformerImpl::Transform(Node& node) {
// node may already have all inputs converted to NCHWc format and is not
// needed for correct operation. This avoids doing extra string checks for
// nodes unrelated to this transformer.
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sum", {6, 8})) {
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sum", {6, 8, 13})) {
TransformBinary(node, true);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7})) {
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13})) {
TransformBinary(node, false);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11})) {
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11, 13})) {
TransformConcat(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6})) {
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13})) {
TransformActivation(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9})) {
TransformBatchNormalization(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1})) {
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13})) {
TransformTranspose(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Upsample", {9}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Resize", {10, 11})) {
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Upsample", {9, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Resize", {10, 11, 13})) {
TransformResize(node);
}
}

View file

@ -68,6 +68,11 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
replace_min = true;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
if (i.data<BFloat16>()->ToFloat() < 0.f) {
replace_min = true;
}
break;
default:
ORT_THROW("Unexpected data type for Clip 'min' input of ", initializer->data_type());
}
@ -110,7 +115,7 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
}
bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6})) {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13})) {
return false;
}
@ -122,7 +127,7 @@ bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node, const
// as Clip will apply the minimum. If the Clip 'min' value is < 0 we need
// to update it to 0 to apply what the Relu would have done. We do that in Apply.
const auto& next_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12, 13}) ||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
return false;
}

View file

@ -23,7 +23,7 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c
Node& reshape = *p_reshape;
ORT_RETURN_IF_ERROR(Recurse(reshape, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reshape, "Reshape", {5}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reshape, "Reshape", {5, 13}) ||
!graph_utils::IsSupportedProvider(reshape, GetCompatibleExecutionProviders())) {
continue;
}
@ -48,9 +48,9 @@ bool ReshapeFusion::Match_One_Element_Output_Subgraph_1(Graph& graph, const Node
int index, std::vector<int64_t> shape_value, bool checkOneElementOnly,
const logging::Logger& logger) {
std::vector<graph_utils::EdgeEndToMatch> parent_path{
{0, index, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain}};
{0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain},
{0, 0, "Shape", {1, 13}, kOnnxDomain}};
std::vector<const Node::EdgeEnd*> edges;
if (graph_utils::FindPath(concat, true, parent_path, edges, logger)) {
const Node& unsqueeze = edges[0]->GetNode();
@ -87,9 +87,9 @@ bool ReshapeFusion::Match_One_Element_Output_Subgraph_1(Graph& graph, const Node
bool ReshapeFusion::Match_One_Element_Output_Subgraph_2(Graph& graph, const NodeArg& root_input, const Node& cur_node,
int index, const logging::Logger& logger) {
std::vector<graph_utils::EdgeEndToMatch> parent_path{
{0, index, "Squeeze", {1, 11}, kOnnxDomain},
{0, 0, "Slice", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain}};
{0, index, "Squeeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Slice", {1, 11, 13}, kOnnxDomain},
{0, 0, "Shape", {1, 13}, kOnnxDomain}};
std::vector<const Node::EdgeEnd*> edges;
if (graph_utils::FindPath(cur_node, true, parent_path, edges, logger)) {
const Node& slice = edges[1]->GetNode();
@ -157,15 +157,15 @@ bool ReshapeFusion::Is_One_Element_Output_Subgraph(Graph& graph, const NodeArg&
}
std::vector<graph_utils::EdgeEndToMatch> div_path{
{0, index, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Div", {7}, kOnnxDomain}};
{0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Div", {7, 13}, kOnnxDomain}};
std::vector<graph_utils::EdgeEndToMatch> mul_path{
{0, index, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Mul", {7}, kOnnxDomain}};
{0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Mul", {7, 13}, kOnnxDomain}};
std::vector<graph_utils::EdgeEndToMatch> unsqueeze_path{
{0, index, "Unsqueeze", {1, 11}, kOnnxDomain}};
{0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain}};
std::vector<const Node::EdgeEnd*> edges;
if (graph_utils::FindPath(concat, true, div_path, edges, logger) ||
@ -242,7 +242,7 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo
}
const Node& concat = *p_concat;
if (!graph_utils::IsSupportedOptypeVersionAndDomain(concat, "Concat", {1, 4, 11})) {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(concat, "Concat", {1, 4, 11, 13})) {
return false;
}

View file

@ -50,7 +50,7 @@ Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& ru
}
bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Shape", {1})) {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Shape", {1, 13})) {
return false;
}

View file

@ -12,7 +12,7 @@ using namespace onnxruntime::common;
namespace onnxruntime {
// LayerNorm supports limited data types.
static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(float)"};
static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(float)", "tensor(bfloat16)"};
static bool IsSupportedDataType(const Node& node) {
for (const auto& input_arg : node.InputDefs()) {
@ -163,8 +163,8 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
// Format 1
std::vector<graph_utils::EdgeEndToMatch> format1_parent_path{
{0, 0, "Add", {7}, kOnnxDomain},
{0, 0, "Add", {7}, kOnnxDomain}};
{0, 0, "Add", {7, 13}, kOnnxDomain},
{0, 0, "Add", {7, 13}, kOnnxDomain}};
std::vector<const Node::EdgeEnd*> edges;
if (graph_utils::FindPath(ln_node, true, format1_parent_path, edges, logger)) {
@ -182,8 +182,8 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
if (matched_format == Format::None) {
// Format 2
std::vector<graph_utils::EdgeEndToMatch> format2_parent_path{
{0, 0, "Add", {7}, kOnnxDomain},
{0, 1, "Add", {7}, kOnnxDomain}};
{0, 0, "Add", {7, 13}, kOnnxDomain},
{0, 1, "Add", {7, 13}, kOnnxDomain}};
if (graph_utils::FindPath(ln_node, true, format2_parent_path, edges, logger)) {
p_add1 = const_cast<Node*>(&edges[0]->GetNode());
@ -201,7 +201,7 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
if (matched_format == Format::None) {
// Format 3
std::vector<graph_utils::EdgeEndToMatch> format3_parent_path{
{0, 0, "Add", {7}, kOnnxDomain}};
{0, 0, "Add", {7, 13}, kOnnxDomain}};
if (graph_utils::FindPath(ln_node, true, format3_parent_path, edges, logger)) {
p_add1 = const_cast<Node*>(&edges[0]->GetNode());

View file

@ -19,7 +19,7 @@ Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_e
bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
// We currently support elimination for Slice operators of opset versions 1, 10, 11 and 13.
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11})) {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13})) {
return false;
}
@ -42,8 +42,8 @@ bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node, cons
if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) {
return false;
}
} else if (graph_utils::MatchesOpSinceVersion(node, {10, 11})) {
// If it is a Slice operator of opset version 10 or 11, starts/ends/axes/steps are provided as node inputs.
} else if (graph_utils::MatchesOpSinceVersion(node, {10, 11, 13})) {
// If it is a Slice operator of opset version >= 10, starts/ends/axes/steps are provided as node inputs.
// Returns a pointer to the corresponding NodeArg if input of the node at this index exists; otherwise, a nullptr.
auto get_input_if_exists = [&node](size_t input_idx) -> const NodeArg* {

View file

@ -18,7 +18,7 @@ void FuseResidualAddIfAny(Graph& graph, const Node& dropout_node,
for (auto last_node_itr = dropout_node.OutputNodesBegin(); last_node_itr != dropout_node.OutputNodesEnd(); ++last_node_itr) {
const Node& last_node = (*last_node_itr);
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Add", {7}) &&
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Add", {7, 13}) &&
last_node.GetExecutionProviderType() == dropout_node.GetExecutionProviderType()) {
const TensorShapeProto* input1_shape = last_node.InputDefs()[0]->Shape();
const TensorShapeProto* input2_shape = last_node.InputDefs()[1]->Shape();
@ -83,7 +83,7 @@ Status BiasDropoutFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve
std::vector<std::reference_wrapper<Node>> nodes_to_fuse;
// matching for bias Add node
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
node.GetOutputEdgesCount() != 1) {
continue;
@ -127,7 +127,7 @@ Status BiasDropoutFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve
}
const Node& next_node = (*next_node_itr);
if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Dropout", {12}, kOnnxDomain) ||
if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Dropout", {12, 13}, kOnnxDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "TrainableDropout", {9}, kOnnxDomain)) ||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
continue;

View file

@ -53,7 +53,7 @@ Status InsertSoftmaxCrossEntropyLossOutput::Apply(Graph& graph, Node& node, Rewr
}
bool InsertSoftmaxCrossEntropyLossOutput::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const {
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "SoftmaxCrossEntropyLoss", {12}) &&
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "SoftmaxCrossEntropyLoss", {12, 13}) &&
node.OutputDefs().size() == 1) {
return true;
}

View file

@ -28,24 +28,25 @@ struct OpInfo {
size_t output_count;
};
// Operator-set version lists and per-operator OpInfo descriptors.
// Each list names the operator versions accepted when matching a node; the
// lists are passed to the OpInfo constructors below and, further down, to
// graph_utils::IsSupportedOptypeVersionAndDomain.
// Each OpInfo is built from an operator type name, a version list and,
// where given, a domain and a trailing count.
// NOTE(review): the trailing 3 on the Split descriptors presumably sets
// OpInfo::output_count — confirm against the OpInfo constructor.
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v1 = {1};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v1_11 = {1, 11};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v2_11 = {2, 11};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v5 = {5};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v7 = {7};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v1_13 = {1, 13};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v1_11_13 = {1, 11, 13};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v2_11_13 = {2, 11, 13};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v5_13 = {5, 13};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v7_13 = {7, 13};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v9 = {9};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v12 = {12};
const OpInfo add_info = OpInfo("Add", opset_v7);
const OpInfo split_info = OpInfo("Split", opset_v2_11, kOnnxDomainAlias, 3);
const OpInfo reshape_info = OpInfo("Reshape", opset_v5);
const OpInfo transpose_info = OpInfo("Transpose", opset_v1);
const OpInfo matmul_info = OpInfo("MatMul", opset_v9);
const OpInfo div_info = OpInfo("Div", opset_v7);
const OpInfo mul_info = OpInfo("Mul", opset_v7);
const OpInfo sub_info = OpInfo("Sub", opset_v7);
const OpInfo softmax_info = OpInfo("Softmax", opset_v1_11);
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v9_13 = {9, 13};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v12_13 = {12, 13};
const OpInfo add_info = OpInfo("Add", opset_v7_13);
const OpInfo split_info = OpInfo("Split", opset_v2_11_13, kOnnxDomainAlias, 3);
const OpInfo reshape_info = OpInfo("Reshape", opset_v5_13);
const OpInfo transpose_info = OpInfo("Transpose", opset_v1_13);
const OpInfo matmul_info = OpInfo("MatMul", opset_v9_13);
const OpInfo div_info = OpInfo("Div", opset_v7_13);
const OpInfo mul_info = OpInfo("Mul", opset_v7_13);
const OpInfo sub_info = OpInfo("Sub", opset_v7_13);
const OpInfo softmax_info = OpInfo("Softmax", opset_v1_11_13);
const OpInfo trainable_dropout_info = OpInfo("TrainableDropout", opset_v9, kOnnxDomain);
const OpInfo dropout_info = OpInfo("Dropout", opset_v12);
const OpInfo dropout_info = OpInfo("Dropout", opset_v12_13);
struct NodeInfo {
NodeInfo(const std::vector<OpInfo>& op_infos,
@ -243,7 +244,7 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph
auto& node = *graph.GetNode(node_index);
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9, 13}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
node.GetOutputEdgesCount() != 1) {
continue;
@ -257,7 +258,7 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph
}
Node& add_node = *graph.GetNode(node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13}) ||
add_node.GetExecutionProviderType() != node.GetExecutionProviderType() ||
add_node.GetOutputEdgesCount() != 1) {
continue;
@ -272,14 +273,14 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph
}
Node& matmul2_node = *graph.GetNode(gelu_node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(matmul2_node, "MatMul", {9}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(matmul2_node, "MatMul", {9, 13}) ||
matmul2_node.GetExecutionProviderType() != node.GetExecutionProviderType() ||
matmul2_node.GetOutputEdgesCount() != 1) {
continue;
}
Node& add2_node = *graph.GetNode(matmul2_node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7, 13}) ||
add2_node.GetExecutionProviderType() != node.GetExecutionProviderType() ||
add2_node.GetOutputEdgesCount() != 1) {
continue;
@ -368,7 +369,7 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified,
auto& node = *graph.GetNode(node_index);
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", opset_v9) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", opset_v9_13) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
node.GetOutputEdgesCount() != 1) {
continue;
@ -603,7 +604,7 @@ Status MegatronTransformer::TransformDropout(Graph& graph, bool& modified, int g
continue;
}
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", opset_v12) &&
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", opset_v12_13) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(node, "TrainableDropout", opset_v9, kOnnxDomain)) {
continue;
}