Do not apply QuickGeluFusion if an intermediate tensor is a graph output (#15109)

This commit is contained in:
Deokhwan Kim 2023-04-07 02:17:06 +09:00 committed by GitHub
parent 026fb3ca1e
commit 55495cc809
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 69 additions and 2 deletions

View file

@ -30,7 +30,8 @@ Status QuickGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
int alpha_index = -1;
float alpha = 1.0f;
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14}) &&
graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) && node.GetOutputEdgesCount() == 1) {
graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) && node.GetOutputEdgesCount() == 1 &&
!graph.NodeProducesGraphOutput(node)) {
for (int i = 0; i < static_cast<int>(node.InputDefs().size()); ++i) {
const NodeArg& input_arg = *(node.InputDefs()[i]);
if (!optimizer_utils::IsScalar(input_arg)) continue;
@ -68,7 +69,7 @@ Status QuickGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
Node& sigmoid_node = *p_sigmoid_node;
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sigmoid_node, "Sigmoid", {6, 13}) ||
!graph_utils::IsSupportedProvider(sigmoid_node, GetCompatibleExecutionProviders()) ||
sigmoid_node.GetOutputEdgesCount() != 1) {
sigmoid_node.GetOutputEdgesCount() != 1 || graph.NodeProducesGraphOutput(sigmoid_node)) {
continue;
}
nodes_to_fuse.emplace_back(sigmoid_node);

View file

@ -4049,6 +4049,72 @@ TEST_F(GraphTransformationTests, QuickGelu) {
pre_graph_checker, post_graph_checker));
}
// Sigmoid's output is a graph output.
{
constexpr float alpha = 1.702f;
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
auto* alpha_arg = builder.MakeInitializer<float>({}, {alpha});
auto* mul_out_0 = builder.MakeIntermediate();
auto* sigmoid_out = builder.MakeOutput();
auto* mul_out_1 = builder.MakeOutput();
builder.AddNode("Mul", {alpha_arg, input_arg}, {mul_out_0});
builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out});
builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out_1});
};
auto pre_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 2);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sigmoid"] == 1);
return Status::OK();
};
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 2);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sigmoid"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["com.microsoft.QuickGelu"] == 0);
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<QuickGeluFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
}
// First Mul's output is a graph output.
{
constexpr float alpha = 1.702f;
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
auto* alpha_arg = builder.MakeInitializer<float>({}, {alpha});
auto* mul_out_0 = builder.MakeOutput();
auto* sigmoid_out = builder.MakeIntermediate();
auto* mul_out_1 = builder.MakeOutput();
builder.AddNode("Mul", {alpha_arg, input_arg}, {mul_out_0});
builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out});
builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out_1});
};
auto pre_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 2);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sigmoid"] == 1);
return Status::OK();
};
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 2);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sigmoid"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["com.microsoft.QuickGelu"] == 0);
return Status::OK();
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<QuickGeluFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
}
// Sigmoid(x)*x, float
{
auto build_test_case = [&](ModelTestBuilder& builder) {