make layernorm fusion to support opset 11 (#2545)

This commit is contained in:
Yufeng Li 2019-12-06 13:06:36 -08:00 committed by GitHub
parent eeb28a80c0
commit 34beafc51c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -16,7 +16,7 @@ static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(
static bool IsSupportedDataType(const Node& node) {
for (const auto& input_arg : node.InputDefs()) {
if (std::find(supported_data_types.begin(), supported_data_types.end(),
*(input_arg->Type())) == supported_data_types.end()) {
*(input_arg->Type())) == supported_data_types.end()) {
return false;
}
}
@ -56,7 +56,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
Node& reduce_mean_node = *p_reduce_mean;
ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11}) ||
!graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) ||
(reduce_mean_node.GetOutputEdgesCount() != 1 && reduce_mean_node.GetOutputEdgesCount() != 2) ||
!IsSupportedDataType(reduce_mean_node)) {
@ -95,7 +95,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
nodes_to_remove.push_back(sub_node);
// Find the "Div" node after "Sub".
const Node* p_div = nullptr;
p_div = graph_utils::FirstChildByType(sub_node, "Div");
@ -110,7 +110,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
nodes_to_remove.push_back(sub_node_dup);
// Find Div node after the duplicated sub node if it's not found after the first sub node.
// Find Div node after the duplicated sub node if it's not found after the first sub node.
if (p_div == nullptr) {
p_div = graph_utils::FirstChildByType(sub_node_dup, "Div");
}
@ -138,7 +138,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sqrt_node, "Sqrt", {6}) ||
sqrt_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
sqrt_node.GetOutputEdgesCount() != 1 ||
!IsSupportedDataType(sqrt_node) ||
!IsSupportedDataType(sqrt_node) ||
sqrt_node.GetInputEdgesCount() == 0) {
continue;
}
@ -162,10 +162,10 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
Node& reduce_mean2_node = *graph.GetNode(p_reduce_mean2->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11}) ||
reduce_mean2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
reduce_mean2_node.GetOutputEdgesCount() != 1 ||
!IsSupportedDataType(reduce_mean2_node) ||
!IsSupportedDataType(reduce_mean2_node) ||
reduce_mean2_node.GetInputEdgesCount() == 0) {
continue;
}
@ -222,7 +222,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
NodeArg* scale = nullptr;
NodeArg* bias = nullptr;
for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
// Scale must be 1d.
if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
@ -244,7 +244,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
// Scale and bias must have the same dimension.
// Scale and bias must have the same dimension.
if (scale->Shape()->dim(0).dim_value() != bias->Shape()->dim(0).dim_value()) {
continue;
}
@ -267,4 +267,4 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
return Status::OK();
}
} // namespace onnxruntime
} // namespace onnxruntime