diff --git a/onnxruntime/core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.cc b/onnxruntime/core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.cc index cc0f785479..9d53e28921 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.cc @@ -53,7 +53,7 @@ Status DuplicateDQForOutputEdge(const graph_utils::GraphEdge& original_dq_output MakeString("Added by ", kTransformerName), dq_inputs, {&new_dq_output_nodearg}, - nullptr, // attributes + &original_dq_node.GetAttributes(), original_dq_node.Domain()); // set up edges diff --git a/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc b/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc index 7a67747f7c..89ffb8ec87 100644 --- a/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc +++ b/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc @@ -234,4 +234,44 @@ TEST(EnsureUniqueDQForNodeUnitTests, QDQWithMultiConsumerDQNodes) { EXPECT_EQ(OpCount(op_count_before, "DequantizeLinear") + 4, OpCount(op_count_after, "DequantizeLinear")); } +TEST(EnsureUniqueDQForNodeUnitTests, QDQWithMultiConsumerDQNodesPreservingAttributes) { + constexpr auto model_uri = ORT_TSTR("testdata/qdq_with_multi_consumer_q_dq_axis.onnx"); + + SessionOptions session_options{}; + // test interaction with level 1 transformers + session_options.graph_optimization_level = TransformerLevel::Level1; + + InferenceSessionWrapper session{session_options, GetEnvironment()}; + + ASSERT_STATUS_OK(session.Load(model_uri)); + + const auto op_count_before = CountOpsInGraph(session.GetGraph()); + + ASSERT_STATUS_OK(session.Initialize()); + + const auto op_count_after = CountOpsInGraph(session.GetGraph()); + + EXPECT_EQ(OpCount(op_count_before, "DequantizeLinear") + 8, OpCount(op_count_after, "DequantizeLinear")); + + int64_t given_axis = 0; // all the following 4 DQ nodes and their duplicated one should have axis = 0 + std::string axis_dq_name0 = "Convolution28_Output_0/fusedmuladd_B/DequantizeLinear"; + std::string axis_dq_name1 = "Parameter5/DequantizeLinear"; + std::string axis_dq_name2 = "Convolution110_Output_0/fusedmuladd_B/DequantizeLinear"; + std::string axis_dq_name3 = "Parameter87/DequantizeLinear"; + for (const auto& node : session.GetGraph().Nodes()) { + if (node.OpType() == "DequantizeLinear") { + if (node.Name().find(axis_dq_name0) == 0 || + node.Name().find(axis_dq_name1) == 0 || + node.Name().find(axis_dq_name2) == 0 || + node.Name().find(axis_dq_name3) == 0) { + const auto& attrs = node.GetAttributes(); + ASSERT_TRUE(attrs.find("axis") != attrs.end()); + const auto& axis_attr = attrs.at("axis"); + int64_t axis = axis_attr.i(); + EXPECT_EQ(axis, given_axis); + } + } + } +} + } // namespace onnxruntime::test diff --git a/onnxruntime/test/testdata/qdq_with_multi_consumer_q_dq_axis.onnx b/onnxruntime/test/testdata/qdq_with_multi_consumer_q_dq_axis.onnx new file mode 100644 index 0000000000..4f575ebb28 Binary files /dev/null and b/onnxruntime/test/testdata/qdq_with_multi_consumer_q_dq_axis.onnx differ