mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
Make LayerNorm fusion support opset 11 (#2545)
This commit is contained in:
parent
eeb28a80c0
commit
34beafc51c
1 changed file with 10 additions and 10 deletions
|
|
@ -16,7 +16,7 @@ static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(
|
|||
static bool IsSupportedDataType(const Node& node) {
|
||||
for (const auto& input_arg : node.InputDefs()) {
|
||||
if (std::find(supported_data_types.begin(), supported_data_types.end(),
|
||||
*(input_arg->Type())) == supported_data_types.end()) {
|
||||
*(input_arg->Type())) == supported_data_types.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -56,7 +56,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
Node& reduce_mean_node = *p_reduce_mean;
|
||||
ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11}) ||
|
||||
!graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) ||
|
||||
(reduce_mean_node.GetOutputEdgesCount() != 1 && reduce_mean_node.GetOutputEdgesCount() != 2) ||
|
||||
!IsSupportedDataType(reduce_mean_node)) {
|
||||
|
|
@ -95,7 +95,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
continue;
|
||||
}
|
||||
nodes_to_remove.push_back(sub_node);
|
||||
|
||||
|
||||
// Find the "Div" node after "Sub".
|
||||
const Node* p_div = nullptr;
|
||||
p_div = graph_utils::FirstChildByType(sub_node, "Div");
|
||||
|
|
@ -110,7 +110,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
continue;
|
||||
}
|
||||
nodes_to_remove.push_back(sub_node_dup);
|
||||
// Find Div node after the duplicated sub node if it's not found after the first sub node.
|
||||
// Find Div node after the duplicated sub node if it's not found after the first sub node.
|
||||
if (p_div == nullptr) {
|
||||
p_div = graph_utils::FirstChildByType(sub_node_dup, "Div");
|
||||
}
|
||||
|
|
@ -138,7 +138,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sqrt_node, "Sqrt", {6}) ||
|
||||
sqrt_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
sqrt_node.GetOutputEdgesCount() != 1 ||
|
||||
!IsSupportedDataType(sqrt_node) ||
|
||||
!IsSupportedDataType(sqrt_node) ||
|
||||
sqrt_node.GetInputEdgesCount() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -162,10 +162,10 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
continue;
|
||||
}
|
||||
Node& reduce_mean2_node = *graph.GetNode(p_reduce_mean2->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11}) ||
|
||||
reduce_mean2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
reduce_mean2_node.GetOutputEdgesCount() != 1 ||
|
||||
!IsSupportedDataType(reduce_mean2_node) ||
|
||||
!IsSupportedDataType(reduce_mean2_node) ||
|
||||
reduce_mean2_node.GetInputEdgesCount() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -222,7 +222,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
NodeArg* scale = nullptr;
|
||||
NodeArg* bias = nullptr;
|
||||
for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
|
||||
if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
|
||||
if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
|
||||
graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
|
||||
// Scale must be 1d.
|
||||
if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
|
||||
|
|
@ -244,7 +244,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
continue;
|
||||
}
|
||||
|
||||
// Scale and bias must have the same dimension.
|
||||
// Scale and bias must have the same dimension.
|
||||
if (scale->Shape()->dim(0).dim_value() != bias->Shape()->dim(0).dim_value()) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -267,4 +267,4 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue