mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
Do not apply QuickGeluFusion if an intermediate tensor is a graph output (#15109)
This commit is contained in:
parent
026fb3ca1e
commit
55495cc809
2 changed files with 69 additions and 2 deletions
|
|
@ -30,7 +30,8 @@ Status QuickGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
int alpha_index = -1;
|
||||
float alpha = 1.0f;
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14}) &&
|
||||
graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) && node.GetOutputEdgesCount() == 1) {
|
||||
graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) && node.GetOutputEdgesCount() == 1 &&
|
||||
!graph.NodeProducesGraphOutput(node)) {
|
||||
for (int i = 0; i < static_cast<int>(node.InputDefs().size()); ++i) {
|
||||
const NodeArg& input_arg = *(node.InputDefs()[i]);
|
||||
if (!optimizer_utils::IsScalar(input_arg)) continue;
|
||||
|
|
@ -68,7 +69,7 @@ Status QuickGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
Node& sigmoid_node = *p_sigmoid_node;
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sigmoid_node, "Sigmoid", {6, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(sigmoid_node, GetCompatibleExecutionProviders()) ||
|
||||
sigmoid_node.GetOutputEdgesCount() != 1) {
|
||||
sigmoid_node.GetOutputEdgesCount() != 1 || graph.NodeProducesGraphOutput(sigmoid_node)) {
|
||||
continue;
|
||||
}
|
||||
nodes_to_fuse.emplace_back(sigmoid_node);
|
||||
|
|
|
|||
|
|
@ -4049,6 +4049,72 @@ TEST_F(GraphTransformationTests, QuickGelu) {
|
|||
pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
// Sigmoid's output is a graph output.
|
||||
{
|
||||
constexpr float alpha = 1.702f;
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
|
||||
auto* alpha_arg = builder.MakeInitializer<float>({}, {alpha});
|
||||
auto* mul_out_0 = builder.MakeIntermediate();
|
||||
auto* sigmoid_out = builder.MakeOutput();
|
||||
auto* mul_out_1 = builder.MakeOutput();
|
||||
|
||||
builder.AddNode("Mul", {alpha_arg, input_arg}, {mul_out_0});
|
||||
builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out});
|
||||
builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out_1});
|
||||
};
|
||||
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 2);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sigmoid"] == 1);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 2);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sigmoid"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["com.microsoft.QuickGelu"] == 0);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<QuickGeluFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
// First Mul's output is a graph output.
|
||||
{
|
||||
constexpr float alpha = 1.702f;
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
|
||||
auto* alpha_arg = builder.MakeInitializer<float>({}, {alpha});
|
||||
auto* mul_out_0 = builder.MakeOutput();
|
||||
auto* sigmoid_out = builder.MakeIntermediate();
|
||||
auto* mul_out_1 = builder.MakeOutput();
|
||||
|
||||
builder.AddNode("Mul", {alpha_arg, input_arg}, {mul_out_0});
|
||||
builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out});
|
||||
builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out_1});
|
||||
};
|
||||
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 2);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sigmoid"] == 1);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 2);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sigmoid"] == 1);
|
||||
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["com.microsoft.QuickGelu"] == 0);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<QuickGeluFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
// Sigmoid(x)*x, float
|
||||
{
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue