mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-23 02:38:28 +00:00
support opset13 on transformers. (#4837)
Co-authored-by: Vincent Wang <weicwang@microsoft.com>
This commit is contained in:
parent
61a5502af0
commit
5eaac31faa
25 changed files with 137 additions and 126 deletions
|
|
@ -139,6 +139,10 @@ struct BFloat16 {
|
|||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Implicit conversion to float: returns the result of ToFloat().
operator float() const {
|
||||
return ToFloat();
|
||||
}
|
||||
};
|
||||
|
||||
inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) {
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, node, 1)) {
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -103,12 +103,12 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
// Test if this is an activation that can be fused and also extract the
|
||||
// activation's parameters.
|
||||
std::vector<float> activation_params;
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6}) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Sigmoid", {6}) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Tanh", {6})) {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13}) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Sigmoid", {6, 13}) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Tanh", {6, 13})) {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "LeakyRelu", {6})) {
|
||||
activation_params.push_back(graph_utils::GetNodeAttribute(next_node, "alpha")->f());
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12})) {
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12, 13})) {
|
||||
float min, max;
|
||||
if (GetClipConstantMinMax(graph, next_node, min, max)) {
|
||||
activation_params.push_back(min);
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const
|
|||
}
|
||||
|
||||
const auto& next_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13}) ||
|
||||
next_node.GetInputEdgesCount() != 1 ||
|
||||
// Make sure the two nodes do not span execution providers.
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const
|
|||
}
|
||||
|
||||
const auto& next_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13}) ||
|
||||
next_node.GetInputEdgesCount() != 1 ||
|
||||
// Make sure the two nodes do not span execution providers.
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule
|
|||
bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
|
||||
// We currently support elimination for Dropout operator v1, v6, v7, v10, v12 and v13.
|
||||
// REVIEW(mzs): v10 implementation does not exist.
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {1, 6, 7, 10, 12})) {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {1, 6, 7, 10, 12, 13})) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -86,19 +86,19 @@ Status DynamicQuantizeMatMulFusion::ApplyImpl(Graph& graph, bool& modified, int
|
|||
|
||||
ORT_RETURN_IF_ERROR(Recurse(mul_node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Left Parents path
|
||||
std::vector<graph_utils::EdgeEndToMatch> left_parent_path{
|
||||
{0, 0, "Cast", {6, 9}, kOnnxDomain},
|
||||
{0, 0, "Cast", {6, 9, 13}, kOnnxDomain},
|
||||
{0, 0, "MatMulInteger", {10}, kOnnxDomain},
|
||||
{0, 0, "DynamicQuantizeLinear", {11}, kOnnxDomain}};
|
||||
|
||||
std::vector<graph_utils::EdgeEndToMatch> right_parent_path{
|
||||
{0, 1, "Mul", {7}, kOnnxDomain},
|
||||
{0, 1, "Mul", {7, 13}, kOnnxDomain},
|
||||
{1, 0, "DynamicQuantizeLinear", {11}, kOnnxDomain}};
|
||||
|
||||
std::vector<std::reference_wrapper<Node>> left_nodes;
|
||||
|
|
|
|||
|
|
@ -501,20 +501,20 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
}
|
||||
// Find ReduceSum --> Attention
|
||||
std::vector<const Node::EdgeEnd*> edges;
|
||||
if (!graph_utils::FindPath(attention_node, true, {{0, 3, "ReduceSum", {1, 11}, kOnnxDomain}}, edges, logger)) {
|
||||
if (!graph_utils::FindPath(attention_node, true, {{0, 3, "ReduceSum", {1, 11, 13}, kOnnxDomain}}, edges, logger)) {
|
||||
continue;
|
||||
}
|
||||
Node& reduce_sum_node = *graph.GetNode(edges[0]->GetNode().Index());
|
||||
|
||||
// Find Add --> LayerNormalization
|
||||
if (!graph_utils::FindPath(layer_norm_node, true, {{0, 0, "Add", {7}, kOnnxDomain}}, edges, logger)) {
|
||||
if (!graph_utils::FindPath(layer_norm_node, true, {{0, 0, "Add", {7, 13}, kOnnxDomain}}, edges, logger)) {
|
||||
continue;
|
||||
}
|
||||
Node& layer_norm_add_node = *graph.GetNode(edges[0]->GetNode().Index());
|
||||
|
||||
// Trace back to find the Gather for segment embedding.
|
||||
std::vector<graph_utils::EdgeEndToMatch> segment_embedding_path{
|
||||
{0, 1, "Gather", {1, 11}, kOnnxDomain}};
|
||||
{0, 1, "Gather", {1, 11, 13}, kOnnxDomain}};
|
||||
if (!graph_utils::FindPath(layer_norm_add_node, true, segment_embedding_path, edges, logger)) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -534,8 +534,8 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
|
||||
// Trace back to find Gather --> Add --> LayerNormalization
|
||||
std::vector<graph_utils::EdgeEndToMatch> word_embedding_path{
|
||||
{0, 0, "Add", {7}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11}, kOnnxDomain}};
|
||||
{0, 0, "Add", {7, 13}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain}};
|
||||
if (!graph_utils::FindPath(layer_norm_add_node, true, word_embedding_path, edges, logger)) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ using namespace onnxruntime::common;
|
|||
namespace onnxruntime {
|
||||
|
||||
// FastGelu supports limited data types.
|
||||
static std::vector<std::string> gpu_supported_data_types{"tensor(float16)", "tensor(float)"};
|
||||
static std::vector<std::string> gpu_supported_data_types{"tensor(float16)", "tensor(float)", "tensor(bfloat16)"};
|
||||
static std::vector<std::string> cpu_supported_data_types{"tensor(float)"};
|
||||
|
||||
static bool IsSupportedDataType(const Node& node) {
|
||||
|
|
@ -24,9 +24,10 @@ static bool IsSupportedDataType(const Node& node) {
|
|||
}
|
||||
}
|
||||
|
||||
static bool CheckNode(Graph& graph, const Node& node, const std::string& op_name, int32_t opset_version, ProviderType provider,
|
||||
bool require_single_output) {
|
||||
return graph_utils::IsSupportedOptypeVersionAndDomain(node, op_name, {opset_version}) &&
|
||||
static bool CheckNode(Graph& graph, const Node& node, const std::string& op_name,
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion>& opset_versions,
|
||||
ProviderType provider, bool require_single_output) {
|
||||
return graph_utils::IsSupportedOptypeVersionAndDomain(node, op_name, opset_versions) &&
|
||||
node.GetExecutionProviderType() == provider &&
|
||||
IsSupportedDataType(node) &&
|
||||
(!require_single_output || node.GetOutputEdgesCount() == 1) &&
|
||||
|
|
@ -36,7 +37,7 @@ static bool CheckNode(Graph& graph, const Node& node, const std::string& op_name
|
|||
MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
|
||||
std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const {
|
||||
MatchResult matchResult{false, nullptr, nullptr};
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul1_node, "Mul", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul1_node, "Mul", {7, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(mul1_node, GetCompatibleExecutionProviders()) ||
|
||||
mul1_node.GetOutputEdgesCount() != 1 ||
|
||||
!IsSupportedDataType(mul1_node)) {
|
||||
|
|
@ -59,7 +60,7 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
|
|||
|
||||
Node& mul2_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index());
|
||||
input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *mul1_node.MutableOutputDefs()[0]);
|
||||
if (!CheckNode(graph, mul2_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true) ||
|
||||
if (!CheckNode(graph, mul2_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true) ||
|
||||
mul2_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != gelu_without_bias_input_arg->Name()) {
|
||||
return matchResult;
|
||||
}
|
||||
|
|
@ -67,14 +68,14 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
|
|||
|
||||
Node& add1_node = *graph.GetNode(mul2_node.OutputNodesBegin()->Index());
|
||||
input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul2_node.MutableOutputDefs()[0]);
|
||||
if (!CheckNode(graph, add1_node, "Add", 7, mul1_node.GetExecutionProviderType(), true) ||
|
||||
if (!CheckNode(graph, add1_node, "Add", {7, 13}, mul1_node.GetExecutionProviderType(), true) ||
|
||||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(add1_node.InputDefs()[(input_index + 1) % 2]), 1.0f, true)) {
|
||||
return matchResult;
|
||||
}
|
||||
nodes_to_fuse.push_back(add1_node);
|
||||
|
||||
Node& mul3_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index());
|
||||
if (!CheckNode(graph, mul3_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true)) {
|
||||
if (!CheckNode(graph, mul3_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true)) {
|
||||
return matchResult;
|
||||
}
|
||||
nodes_to_fuse.push_back(mul3_node);
|
||||
|
|
@ -83,7 +84,7 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
|
|||
const Node* p_mul3_input_node = graph_utils::GetInputNode(mul3_node, (input_index + 1) % 2);
|
||||
if (p_mul3_input_node == nullptr) return matchResult;
|
||||
Node& mul4_node = const_cast<Node&>(*p_mul3_input_node);
|
||||
if (!CheckNode(graph, mul4_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true)) {
|
||||
if (!CheckNode(graph, mul4_node, "Mul", {7, 13}, mul1_node.GetExecutionProviderType(), true)) {
|
||||
return matchResult;
|
||||
}
|
||||
|
||||
|
|
@ -109,7 +110,7 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
|
|||
MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
|
||||
std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const {
|
||||
MatchResult matchResult{false, nullptr, nullptr};
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow1_node, "Pow", {7, 12}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow1_node, "Pow", {7, 12, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(pow1_node, GetCompatibleExecutionProviders()) ||
|
||||
pow1_node.GetOutputEdgesCount() != 1 ||
|
||||
!IsSupportedDataType(pow1_node)) {
|
||||
|
|
@ -125,7 +126,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
|
|||
|
||||
Node& mul1_node = *graph.GetNode(pow1_node.OutputNodesBegin()->Index());
|
||||
auto input_index = optimizer_utils::IndexOfNodeInput(mul1_node, *pow1_node.MutableOutputDefs()[0]);
|
||||
if (!CheckNode(graph, mul1_node, "Mul", 7, pow1_node.GetExecutionProviderType(), true) ||
|
||||
if (!CheckNode(graph, mul1_node, "Mul", {7, 13}, pow1_node.GetExecutionProviderType(), true) ||
|
||||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul1_node.InputDefs()[(input_index + 1) % 2]),
|
||||
0.044714998453855515f, true)) {
|
||||
return matchResult;
|
||||
|
|
@ -134,7 +135,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
|
|||
|
||||
Node& add1_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index());
|
||||
input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul1_node.MutableOutputDefs()[0]);
|
||||
if (!CheckNode(graph, add1_node, "Add", 7, pow1_node.GetExecutionProviderType(), true) ||
|
||||
if (!CheckNode(graph, add1_node, "Add", {7, 13}, pow1_node.GetExecutionProviderType(), true) ||
|
||||
add1_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != pow_input_arg->Name()) {
|
||||
return matchResult;
|
||||
}
|
||||
|
|
@ -142,7 +143,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
|
|||
|
||||
Node& mul2_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index());
|
||||
input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *add1_node.MutableOutputDefs()[0]);
|
||||
if (!CheckNode(graph, mul2_node, "Mul", 7, pow1_node.GetExecutionProviderType(), true) ||
|
||||
if (!CheckNode(graph, mul2_node, "Mul", {7, 13}, pow1_node.GetExecutionProviderType(), true) ||
|
||||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul2_node.InputDefs()[(input_index + 1) % 2]),
|
||||
0.7978845834732056f, true)) {
|
||||
return matchResult;
|
||||
|
|
@ -177,12 +178,12 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
};
|
||||
|
||||
Node& tanh_node = *graph.GetNode(matchRet.tanh_input_node->OutputNodesBegin()->Index());
|
||||
if (!CheckNode(graph, tanh_node, "Tanh", 6, node.GetExecutionProviderType(), true)) {
|
||||
if (!CheckNode(graph, tanh_node, "Tanh", {6, 13}, node.GetExecutionProviderType(), true)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Node& add2_node = *graph.GetNode(tanh_node.OutputNodesBegin()->Index());
|
||||
if (!CheckNode(graph, add2_node, "Add", 7, node.GetExecutionProviderType(), true)) {
|
||||
if (!CheckNode(graph, add2_node, "Add", {7, 13}, node.GetExecutionProviderType(), true)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -193,7 +194,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
|
||||
Node& mul5_node = *graph.GetNode(add2_node.OutputNodesBegin()->Index());
|
||||
// This is the output of the Gelu subgraph, so we don't need to check that it has a single output edge.
|
||||
if (!CheckNode(graph, mul5_node, "Mul", 7, node.GetExecutionProviderType(), false)) {
|
||||
if (!CheckNode(graph, mul5_node, "Mul", {7, 13}, node.GetExecutionProviderType(), false)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -201,7 +202,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
const Node* p_mul5_input_node = graph_utils::GetInputNode(mul5_node, (input_index + 1) % 2);
|
||||
if (p_mul5_input_node == nullptr) continue;
|
||||
Node& mul6_node = const_cast<Node&>(*p_mul5_input_node);
|
||||
if (!CheckNode(graph, mul6_node, "Mul", 7, node.GetExecutionProviderType(), false)) {
|
||||
if (!CheckNode(graph, mul6_node, "Mul", {7, 13}, node.GetExecutionProviderType(), false)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ using namespace onnxruntime::common;
|
|||
namespace onnxruntime {
|
||||
|
||||
// FastGelu supports limited data types.
|
||||
static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(float)"};
|
||||
static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(float)", "tensor(bfloat16)"};
|
||||
|
||||
static bool IsSupportedDataType(const Node& node) {
|
||||
for (const auto& input_arg : node.InputDefs()) {
|
||||
|
|
@ -51,7 +51,7 @@ static bool CheckInputShape(const Node& node, const NodeArg& input, const NodeAr
|
|||
// it means that the shape of MatMul output is good for FastGelu.
|
||||
const Node* parent_node = graph_utils::GetInputNode(node, 0);
|
||||
if (nullptr != parent_node &&
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(*parent_node, "MatMul", {1, 9}, kOnnxDomain)) {
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(*parent_node, "MatMul", {1, 9, 13}, kOnnxDomain)) {
|
||||
const NodeArg& input_b = *(parent_node->InputDefs()[1]);
|
||||
if (optimizer_utils::ValidateShape(input_b, {-1, bias_length})) {
|
||||
return true;
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
|
|||
Node& div = *p_div;
|
||||
ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(div, GetCompatibleExecutionProviders()) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, div, 1) ||
|
||||
!IsSupportedDataType(div)) {
|
||||
|
|
@ -71,7 +71,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
|
|||
}
|
||||
|
||||
Node& erf_node = *graph.GetNode(div.OutputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(erf_node, "Erf", {9}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(erf_node, "Erf", {9, 13}) ||
|
||||
erf_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, erf_node, 1) ||
|
||||
!IsSupportedDataType(erf_node)) {
|
||||
|
|
@ -79,7 +79,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
|
|||
}
|
||||
|
||||
Node& add_node = *graph.GetNode(erf_node.OutputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13}) ||
|
||||
add_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, add_node, 1) ||
|
||||
!IsSupportedDataType(add_node)) {
|
||||
|
|
@ -95,7 +95,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
|
|||
|
||||
Node& mul_node = *graph.GetNode(add_node.OutputNodesBegin()->Index());
|
||||
// note: output edges count doesn't matter as the new Gelu node will produce outputs with the same names
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) ||
|
||||
mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
!IsSupportedDataType(mul_node)) {
|
||||
continue;
|
||||
|
|
@ -106,7 +106,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
|
|||
if (p_mul2_node != nullptr) {
|
||||
// Match subgraph pattern 1
|
||||
Node& mul2_node = *graph.GetNode(p_mul2_node->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13}) ||
|
||||
mul2_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, mul2_node, 1) ||
|
||||
!IsSupportedDataType(mul2_node)) {
|
||||
|
|
@ -139,7 +139,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
|
|||
continue;
|
||||
|
||||
Node& mul2_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7, 13}) ||
|
||||
mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
!IsSupportedDataType(mul_node)) {
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -24,12 +24,12 @@ bool IsFusableActivation(const Node& node) {
|
|||
return IsSupportedOptypeVersionAndDomain(node, "Elu", {6}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "HardSigmoid", {6}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", {6}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "Relu", {6}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "Selu", {6}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "Softplus", {1}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "Softsign", {1}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "Tanh", {6}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13}, kOnnxDomain) ||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
IsSupportedOptypeVersionAndDomain(node, "ScaledTanh", {1}, kOnnxDomain) ||
|
||||
IsSupportedOptypeVersionAndDomain(node, "ParametricSoftplus", {1}, kOnnxDomain) ||
|
||||
|
|
@ -51,7 +51,7 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
auto& node = *node_ptr;
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gemm", {7, 9, 11}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gemm", {7, 9, 11, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
Node& reduce_mean_node = *p_reduce_mean;
|
||||
ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) ||
|
||||
(reduce_mean_node.GetOutputEdgesCount() != 1 && reduce_mean_node.GetOutputEdgesCount() != 2) ||
|
||||
!graph.GetNodeOutputsInGraphOutputs(reduce_mean_node).empty() ||
|
||||
|
|
@ -102,7 +102,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
}
|
||||
|
||||
Node& sub_node = *graph.GetNode(p_sub_node->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node, "Sub", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node, "Sub", {7, 13}) ||
|
||||
sub_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, sub_node, subCnt == 1 ? 2u : 1u) ||
|
||||
!IsSupportedDataType(sub_node)) {
|
||||
|
|
@ -117,7 +117,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
// Find the sub_dup node if it exists
|
||||
if (p_sub_node_dup != nullptr) {
|
||||
Node& sub_node_dup = *graph.GetNode(p_sub_node_dup->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node_dup, "Sub", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node_dup, "Sub", {7, 13}) ||
|
||||
sub_node_dup.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, sub_node, 1) ||
|
||||
!IsSupportedDataType(sub_node_dup)) {
|
||||
|
|
@ -134,7 +134,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
continue;
|
||||
}
|
||||
Node& div_node = *graph.GetNode(p_div->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13}) ||
|
||||
div_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, div_node, 1) ||
|
||||
!IsSupportedDataType(div_node)) {
|
||||
|
|
@ -149,7 +149,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
}
|
||||
Node& sqrt_node = *graph.GetNode(p_sqrt->Index());
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sqrt_node, "Sqrt", {6}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sqrt_node, "Sqrt", {6, 13}) ||
|
||||
sqrt_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, sqrt_node, 1) ||
|
||||
!IsSupportedDataType(sqrt_node) ||
|
||||
|
|
@ -160,7 +160,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
|
||||
// Trace back from the sqrt node to find add --> sqrt
|
||||
Node& add2_node = *graph.GetNode(sqrt_node.InputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7, 13}) ||
|
||||
add2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, add2_node, 1) ||
|
||||
!IsSupportedDataType(add2_node)) {
|
||||
|
|
@ -175,7 +175,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
continue;
|
||||
}
|
||||
Node& reduce_mean2_node = *graph.GetNode(p_reduce_mean2->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11, 13}) ||
|
||||
reduce_mean2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, reduce_mean2_node, 1) ||
|
||||
!IsSupportedDataType(reduce_mean2_node) ||
|
||||
|
|
@ -186,7 +186,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
|
||||
// Trace back from the reduceMean node to find pow --> reduceMean
|
||||
Node& pow_node = *graph.GetNode(reduce_mean2_node.InputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow_node, "Pow", {7, 12}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow_node, "Pow", {7, 12, 13}) ||
|
||||
pow_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, pow_node, 1) ||
|
||||
!IsSupportedDataType(pow_node)) {
|
||||
|
|
@ -198,7 +198,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
const Node* p_cast_node = graph_utils::FirstParentByType(pow_node, "Cast");
|
||||
if (p_cast_node != nullptr) {
|
||||
Node& cast_node = *graph.GetNode(pow_node.InputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13}) ||
|
||||
cast_node.GetExecutionProviderType() != cast_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, cast_node, 1) ||
|
||||
!IsSupportedDataType(cast_node)) {
|
||||
|
|
@ -216,7 +216,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
|
||||
// div --> mul
|
||||
Node& mul_node = *graph.GetNode(div_node.OutputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13}) ||
|
||||
mul_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, mul_node, 1) ||
|
||||
!IsSupportedDataType(mul_node)) {
|
||||
|
|
@ -227,7 +227,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
// mul --> add
|
||||
// Need not check output edges of last node since they will be moved to fused node.
|
||||
Node& last_add_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(last_add_node, "Add", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(last_add_node, "Add", {7, 13}) ||
|
||||
last_add_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!IsSupportedDataType(last_add_node)) {
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
|
|
@ -40,7 +40,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
}
|
||||
|
||||
const Node& next_node = (*next_node_itr);
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7, 13}) ||
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -58,7 +58,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
if ((*matmul_type) != (*add_type)) {
|
||||
continue;
|
||||
}
|
||||
if ((*matmul_type) != "tensor(float)" && (*matmul_type) != "tensor(float16)") {
|
||||
if ((*matmul_type) != "tensor(float)" && (*matmul_type) != "tensor(float16)" && (*matmul_type) != "tensor(bfloat16)") {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ optional<float> GetScalarConstantInitializer(const Graph& graph, const NodeArg&
|
|||
float scalar=0.f;
|
||||
utils::MLTypeCallDispatcherRet<
|
||||
Status, ExtractScalarAsFloatDispatchTarget,
|
||||
uint32_t, uint64_t, int32_t, int64_t, MLFloat16, float, double>
|
||||
uint32_t, uint64_t, int32_t, int64_t, MLFloat16, float, double, BFloat16>
|
||||
dispatcher{initializer->data_type()};
|
||||
ORT_THROW_IF_ERROR(dispatcher.Invoke(*initializer, scalar));
|
||||
|
||||
|
|
@ -62,7 +62,7 @@ optional<std::pair<float, int>> GetScaleFromNode(
|
|||
return excluded_initializer_names.find(node_arg.Name()) != excluded_initializer_names.end();
|
||||
};
|
||||
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Div", {7})) {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Div", {7, 13})) {
|
||||
// (x / scale_reciprocal)
|
||||
const auto div_inputs = scale_node.InputDefs();
|
||||
ORT_ENFORCE(div_inputs.size() == 2);
|
||||
|
|
@ -79,7 +79,7 @@ optional<std::pair<float, int>> GetScaleFromNode(
|
|||
return {std::make_pair(1.0f / divisor.value(), scale_reciprocal_arg_index)};
|
||||
}
|
||||
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Mul", {7})) {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Mul", {7, 13})) {
|
||||
// (x * scale) or (scale * x)
|
||||
const auto mul_inputs = scale_node.InputDefs();
|
||||
ORT_ENFORCE(mul_inputs.size() == 2);
|
||||
|
|
@ -170,7 +170,7 @@ std::vector<ScaleMergeInfo> GetOutputNodeMerges(
|
|||
Status ProcessNode(
|
||||
Graph& graph, Node& node, bool& modified,
|
||||
const std::unordered_set<std::string>& excluded_initializer_names) {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9}) &&
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9, 13}) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeScaleMatMul", {1}, kMSDomain)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -96,8 +96,8 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
|
|||
auto& node = *graph.GetNode(node_index);
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if ((!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9}) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeScaleMatMul", {1}, kMSDomain)) ||
|
||||
if ((!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9, 13}) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeScaleMatMul", {1, 13}, kMSDomain)) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -937,23 +937,23 @@ void NchwcTransformerImpl::Transform(Node& node) {
|
|||
// node may already have all inputs converted to NCHWc format and is not
|
||||
// needed for correct operation. This avoids doing extra string checks for
|
||||
// nodes unrelated to this transformer.
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sum", {6, 8})) {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sum", {6, 8, 13})) {
|
||||
TransformBinary(node, true);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7})) {
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13})) {
|
||||
TransformBinary(node, false);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11})) {
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11, 13})) {
|
||||
TransformConcat(node);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6})) {
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13})) {
|
||||
TransformActivation(node);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9})) {
|
||||
TransformBatchNormalization(node);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1})) {
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13})) {
|
||||
TransformTranspose(node);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Upsample", {9}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Resize", {10, 11})) {
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Upsample", {9, 13}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Resize", {10, 11, 13})) {
|
||||
TransformResize(node);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -68,6 +68,11 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
|
|||
replace_min = true;
|
||||
}
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
|
||||
if (i.data<BFloat16>()->ToFloat() < 0.f) {
|
||||
replace_min = true;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
ORT_THROW("Unexpected data type for Clip 'min' input of ", initializer->data_type());
|
||||
}
|
||||
|
|
@ -110,7 +115,7 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
|
|||
}
|
||||
|
||||
bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6})) {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13})) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -122,7 +127,7 @@ bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node, const
|
|||
// as Clip will apply the minimum. If the Clip 'min' value is < 0 we need
|
||||
// to update it to 0 to apply what the Relu would have done. We do that in Apply.
|
||||
const auto& next_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12, 13}) ||
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c
|
|||
Node& reshape = *p_reshape;
|
||||
ORT_RETURN_IF_ERROR(Recurse(reshape, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reshape, "Reshape", {5}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reshape, "Reshape", {5, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(reshape, GetCompatibleExecutionProviders())) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -48,9 +48,9 @@ bool ReshapeFusion::Match_One_Element_Output_Subgraph_1(Graph& graph, const Node
|
|||
int index, std::vector<int64_t> shape_value, bool checkOneElementOnly,
|
||||
const logging::Logger& logger) {
|
||||
std::vector<graph_utils::EdgeEndToMatch> parent_path{
|
||||
{0, index, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1}, kOnnxDomain}};
|
||||
{0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1, 13}, kOnnxDomain}};
|
||||
std::vector<const Node::EdgeEnd*> edges;
|
||||
if (graph_utils::FindPath(concat, true, parent_path, edges, logger)) {
|
||||
const Node& unsqueeze = edges[0]->GetNode();
|
||||
|
|
@ -87,9 +87,9 @@ bool ReshapeFusion::Match_One_Element_Output_Subgraph_1(Graph& graph, const Node
|
|||
bool ReshapeFusion::Match_One_Element_Output_Subgraph_2(Graph& graph, const NodeArg& root_input, const Node& cur_node,
|
||||
int index, const logging::Logger& logger) {
|
||||
std::vector<graph_utils::EdgeEndToMatch> parent_path{
|
||||
{0, index, "Squeeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Slice", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1}, kOnnxDomain}};
|
||||
{0, index, "Squeeze", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Slice", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1, 13}, kOnnxDomain}};
|
||||
std::vector<const Node::EdgeEnd*> edges;
|
||||
if (graph_utils::FindPath(cur_node, true, parent_path, edges, logger)) {
|
||||
const Node& slice = edges[1]->GetNode();
|
||||
|
|
@ -157,15 +157,15 @@ bool ReshapeFusion::Is_One_Element_Output_Subgraph(Graph& graph, const NodeArg&
|
|||
}
|
||||
|
||||
std::vector<graph_utils::EdgeEndToMatch> div_path{
|
||||
{0, index, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Div", {7}, kOnnxDomain}};
|
||||
{0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Div", {7, 13}, kOnnxDomain}};
|
||||
|
||||
std::vector<graph_utils::EdgeEndToMatch> mul_path{
|
||||
{0, index, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Mul", {7}, kOnnxDomain}};
|
||||
{0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
|
||||
{0, 0, "Mul", {7, 13}, kOnnxDomain}};
|
||||
|
||||
std::vector<graph_utils::EdgeEndToMatch> unsqueeze_path{
|
||||
{0, index, "Unsqueeze", {1, 11}, kOnnxDomain}};
|
||||
{0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain}};
|
||||
|
||||
std::vector<const Node::EdgeEnd*> edges;
|
||||
if (graph_utils::FindPath(concat, true, div_path, edges, logger) ||
|
||||
|
|
@ -242,7 +242,7 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo
|
|||
}
|
||||
const Node& concat = *p_concat;
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(concat, "Concat", {1, 4, 11})) {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(concat, "Concat", {1, 4, 11, 13})) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& ru
|
|||
}
|
||||
|
||||
bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Shape", {1})) {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Shape", {1, 13})) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ using namespace onnxruntime::common;
|
|||
namespace onnxruntime {
|
||||
|
||||
// LayerNorm supports limited data types.
|
||||
static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(float)"};
|
||||
static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(float)", "tensor(bfloat16)"};
|
||||
|
||||
static bool IsSupportedDataType(const Node& node) {
|
||||
for (const auto& input_arg : node.InputDefs()) {
|
||||
|
|
@ -163,8 +163,8 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
|
||||
// Format 1
|
||||
std::vector<graph_utils::EdgeEndToMatch> format1_parent_path{
|
||||
{0, 0, "Add", {7}, kOnnxDomain},
|
||||
{0, 0, "Add", {7}, kOnnxDomain}};
|
||||
{0, 0, "Add", {7, 13}, kOnnxDomain},
|
||||
{0, 0, "Add", {7, 13}, kOnnxDomain}};
|
||||
|
||||
std::vector<const Node::EdgeEnd*> edges;
|
||||
if (graph_utils::FindPath(ln_node, true, format1_parent_path, edges, logger)) {
|
||||
|
|
@ -182,8 +182,8 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
if (matched_format == Format::None) {
|
||||
// Format 2
|
||||
std::vector<graph_utils::EdgeEndToMatch> format2_parent_path{
|
||||
{0, 0, "Add", {7}, kOnnxDomain},
|
||||
{0, 1, "Add", {7}, kOnnxDomain}};
|
||||
{0, 0, "Add", {7, 13}, kOnnxDomain},
|
||||
{0, 1, "Add", {7, 13}, kOnnxDomain}};
|
||||
|
||||
if (graph_utils::FindPath(ln_node, true, format2_parent_path, edges, logger)) {
|
||||
p_add1 = const_cast<Node*>(&edges[0]->GetNode());
|
||||
|
|
@ -201,7 +201,7 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
if (matched_format == Format::None) {
|
||||
// Format 3
|
||||
std::vector<graph_utils::EdgeEndToMatch> format3_parent_path{
|
||||
{0, 0, "Add", {7}, kOnnxDomain}};
|
||||
{0, 0, "Add", {7, 13}, kOnnxDomain}};
|
||||
|
||||
if (graph_utils::FindPath(ln_node, true, format3_parent_path, edges, logger)) {
|
||||
p_add1 = const_cast<Node*>(&edges[0]->GetNode());
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_e
|
|||
|
||||
bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
|
||||
// We currently support elimination for the Slice operator at the opset versions checked below.
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11})) {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13})) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -42,8 +42,8 @@ bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node, cons
|
|||
if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) {
|
||||
return false;
|
||||
}
|
||||
} else if (graph_utils::MatchesOpSinceVersion(node, {10, 11})) {
|
||||
// If it is a Slice operator of opset version 10 or 11, starts/ends/axes/steps are provided as node inputs.
|
||||
} else if (graph_utils::MatchesOpSinceVersion(node, {10, 11, 13})) {
|
||||
// If it is a Slice operator of opset version >= 10, starts/ends/axes/steps are provided as node inputs.
|
||||
|
||||
// Returns a pointer to the corresponding NodeArg if input of the node at this index exists; otherwise, a nullptr.
|
||||
auto get_input_if_exists = [&node](size_t input_idx) -> const NodeArg* {
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ void FuseResidualAddIfAny(Graph& graph, const Node& dropout_node,
|
|||
for (auto last_node_itr = dropout_node.OutputNodesBegin(); last_node_itr != dropout_node.OutputNodesEnd(); ++last_node_itr) {
|
||||
const Node& last_node = (*last_node_itr);
|
||||
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Add", {7}) &&
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Add", {7, 13}) &&
|
||||
last_node.GetExecutionProviderType() == dropout_node.GetExecutionProviderType()) {
|
||||
const TensorShapeProto* input1_shape = last_node.InputDefs()[0]->Shape();
|
||||
const TensorShapeProto* input2_shape = last_node.InputDefs()[1]->Shape();
|
||||
|
|
@ -83,7 +83,7 @@ Status BiasDropoutFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve
|
|||
std::vector<std::reference_wrapper<Node>> nodes_to_fuse;
|
||||
|
||||
// matching for bias Add node
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
|
|
@ -127,7 +127,7 @@ Status BiasDropoutFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve
|
|||
}
|
||||
|
||||
const Node& next_node = (*next_node_itr);
|
||||
if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Dropout", {12}, kOnnxDomain) ||
|
||||
if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Dropout", {12, 13}, kOnnxDomain) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "TrainableDropout", {9}, kOnnxDomain)) ||
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ Status InsertSoftmaxCrossEntropyLossOutput::Apply(Graph& graph, Node& node, Rewr
|
|||
}
|
||||
|
||||
bool InsertSoftmaxCrossEntropyLossOutput::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "SoftmaxCrossEntropyLoss", {12}) &&
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "SoftmaxCrossEntropyLoss", {12, 13}) &&
|
||||
node.OutputDefs().size() == 1) {
|
||||
return true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,24 +28,25 @@ struct OpInfo {
|
|||
size_t output_count;
|
||||
};
|
||||
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v1 = {1};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v1_11 = {1, 11};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v2_11 = {2, 11};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v5 = {5};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v7 = {7};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v1_13 = {1, 13};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v1_11_13 = {1, 11, 13};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v2_11_13 = {2, 11, 13};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v5_13 = {5, 13};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v7_13 = {7, 13};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v9 = {9};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v12 = {12};
|
||||
const OpInfo add_info = OpInfo("Add", opset_v7);
|
||||
const OpInfo split_info = OpInfo("Split", opset_v2_11, kOnnxDomainAlias, 3);
|
||||
const OpInfo reshape_info = OpInfo("Reshape", opset_v5);
|
||||
const OpInfo transpose_info = OpInfo("Transpose", opset_v1);
|
||||
const OpInfo matmul_info = OpInfo("MatMul", opset_v9);
|
||||
const OpInfo div_info = OpInfo("Div", opset_v7);
|
||||
const OpInfo mul_info = OpInfo("Mul", opset_v7);
|
||||
const OpInfo sub_info = OpInfo("Sub", opset_v7);
|
||||
const OpInfo softmax_info = OpInfo("Softmax", opset_v1_11);
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v9_13 = {9, 13};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v12_13 = {12, 13};
|
||||
const OpInfo add_info = OpInfo("Add", opset_v7_13);
|
||||
const OpInfo split_info = OpInfo("Split", opset_v2_11_13, kOnnxDomainAlias, 3);
|
||||
const OpInfo reshape_info = OpInfo("Reshape", opset_v5_13);
|
||||
const OpInfo transpose_info = OpInfo("Transpose", opset_v1_13);
|
||||
const OpInfo matmul_info = OpInfo("MatMul", opset_v9_13);
|
||||
const OpInfo div_info = OpInfo("Div", opset_v7_13);
|
||||
const OpInfo mul_info = OpInfo("Mul", opset_v7_13);
|
||||
const OpInfo sub_info = OpInfo("Sub", opset_v7_13);
|
||||
const OpInfo softmax_info = OpInfo("Softmax", opset_v1_11_13);
|
||||
const OpInfo trainable_dropout_info = OpInfo("TrainableDropout", opset_v9, kOnnxDomain);
|
||||
const OpInfo dropout_info = OpInfo("Dropout", opset_v12);
|
||||
const OpInfo dropout_info = OpInfo("Dropout", opset_v12_13);
|
||||
|
||||
struct NodeInfo {
|
||||
NodeInfo(const std::vector<OpInfo>& op_infos,
|
||||
|
|
@ -243,7 +244,7 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph
|
|||
auto& node = *graph.GetNode(node_index);
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
|
|
@ -257,7 +258,7 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph
|
|||
}
|
||||
|
||||
Node& add_node = *graph.GetNode(node.OutputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13}) ||
|
||||
add_node.GetExecutionProviderType() != node.GetExecutionProviderType() ||
|
||||
add_node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
|
|
@ -272,14 +273,14 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph
|
|||
}
|
||||
|
||||
Node& matmul2_node = *graph.GetNode(gelu_node.OutputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(matmul2_node, "MatMul", {9}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(matmul2_node, "MatMul", {9, 13}) ||
|
||||
matmul2_node.GetExecutionProviderType() != node.GetExecutionProviderType() ||
|
||||
matmul2_node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Node& add2_node = *graph.GetNode(matmul2_node.OutputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7, 13}) ||
|
||||
add2_node.GetExecutionProviderType() != node.GetExecutionProviderType() ||
|
||||
add2_node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
|
|
@ -368,7 +369,7 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified,
|
|||
auto& node = *graph.GetNode(node_index);
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", opset_v9) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", opset_v9_13) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
|
|
@ -603,7 +604,7 @@ Status MegatronTransformer::TransformDropout(Graph& graph, bool& modified, int g
|
|||
continue;
|
||||
}
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", opset_v12) &&
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", opset_v12_13) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(node, "TrainableDropout", opset_v9, kOnnxDomain)) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue