mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Extend QDQPropagation transformer to handle multiple consumers (#21313)
### Description
- Extends the QDQPropagationTransformer to propagate DQs (forward)
across operators with multiple consumers (previously only supported 1
consumer).
- Adds Slice to the list of operators that the QDQPropagationTransformer
can propagate DQ/Q ops across.
- Supports QDQ propagation for opset 21.
- Correctly copies Q or DQ attributes when creating new nodes.
### Motivation and Context
The QDQPropagationTransformer fixes up QDQ node units for certain "data
movement" ops (e.g., Transpose) by inserting Q -> DQ sequences where
necessary. For example, the sequence `DQ -> Transpose -> Sigmoid` is
transformed to `DQ -> Transpose -> Q -> DQ -> Sigmoid`.
However, this fix-up does not currently support data movement ops with
multiple consumers, as in:
```
DQ -> Transpose --+--> Sigmoid ->
                  |
                  +--> Relu ->
                  |
                  +-> graph_output
```
With the updates in this PR, the above model can be transformed to:
```
DQ -> Transpose -> Q --+--> DQ -> Sigmoid ->
                       |
                       +--> DQ -> Relu ->
                       |
                       +--> DQ -> graph_output
```
This update allows QNN EP to support quantized models created with tools
that do not wrap data movement ops in Q/DQ ops.
---------
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
This commit is contained in:
parent
c203d89958
commit
f4edf9bb58
3 changed files with 417 additions and 90 deletions
|
|
@ -3,8 +3,13 @@
|
|||
|
||||
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <optional>
|
||||
#include <queue>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
|
||||
#include "core/common/inlined_containers_fwd.h"
|
||||
#include "core/graph/extended_graph_edge.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
|
|
@ -17,39 +22,147 @@ namespace onnxruntime {
|
|||
namespace {
|
||||
// Returns true if `node` is one of the operators that the QDQ propagation transformer can move
// Q/DQ ops across. A node qualifies when its op type, opset version, and domain match one of the
// entries listed below (as decided by graph_utils::IsSupportedOptypeVersionAndDomain).
bool CanNodePropagate(const Node& node) {
  return graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {12}) ||
         graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {5, 13, 14, 19, 21}) ||
         graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13, 21}) ||
         graph_utils::IsSupportedOptypeVersionAndDomain(node, "Squeeze", {1, 11, 13, 21}) ||
         graph_utils::IsSupportedOptypeVersionAndDomain(node, "Unsqueeze", {1, 11, 13, 21}) ||
         graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13});
}
|
||||
|
||||
// convert this: src_node -> dst_node
|
||||
// to this: src_node -> Q -> DQ -> dst_node
|
||||
// assumptions:
|
||||
// 1. insertion_edge is valid - node indexes refer to valid nodes, arg name refers to a valid NodeArg, and it
|
||||
// corresponds to an actual graph relationship
|
||||
// 2. scale_initializer_nodearg and zp_initializer_nodearg_ptr (if not null) are constant initializers
|
||||
Status InsertQDQPair(Graph& graph, const ExtendedGraphEdge& insertion_edge,
|
||||
NodeArg& scale_initializer_nodearg, NodeArg* zp_initializer_nodearg_ptr,
|
||||
const std::string& qdq_domain, const logging::Logger& logger) {
|
||||
auto* src_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Source);
|
||||
auto* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
|
||||
// Makes matching attributes for new QuantizeLinear nodes from an existing DequantizeLinear node.
|
||||
NodeAttributes MakeQAttrsFromDQ(const Node& dq_node) {
|
||||
assert(dq_node.SinceVersion() <= 21); // Checked by previous call to QDQ::MatchDQNode().
|
||||
// In opset <= 21, all DQ attributes (i.e., axis and block_size) are also Q attributes.
|
||||
// So, set a copy of the DQ attributes.
|
||||
return dq_node.GetAttributes();
|
||||
}
|
||||
|
||||
ORT_ENFORCE(src_node || dst_node, "At least one graph node must be specified in the propagation edge.");
|
||||
// Makes matching attributes for new DequantizeLinear nodes from an existing QuantizeLinear node.
|
||||
NodeAttributes MakeDQAttrsFromQ(const Node& q_node) {
|
||||
assert(q_node.SinceVersion() <= 21); // Checked by previous call to QDQ::MatchQNode().
|
||||
const NodeAttributes& q_attrs = q_node.GetAttributes();
|
||||
if (q_attrs.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const auto& base_name = insertion_edge.arg_name;
|
||||
// In opset <= 21, only the "axis" and "block_size" attributes for Q are also DQ attributes.
|
||||
NodeAttributes dq_attrs;
|
||||
|
||||
auto axis_attr_it = q_attrs.find("axis");
|
||||
if (axis_attr_it != q_attrs.end()) {
|
||||
dq_attrs.insert({axis_attr_it->first, axis_attr_it->second});
|
||||
}
|
||||
|
||||
auto block_size_attr_it = q_attrs.find("block_size");
|
||||
if (block_size_attr_it != q_attrs.end()) {
|
||||
dq_attrs.insert({block_size_attr_it->first, block_size_attr_it->second});
|
||||
}
|
||||
|
||||
return dq_attrs;
|
||||
}
|
||||
|
||||
// Validates edges into which to insert Q -> DQ ops.
|
||||
// - Must have at least one edge.
|
||||
// - All edges must correspond to the same graph NodeArg (i.e., same source but potentially different destination).
|
||||
// - All edges must be attached to either a source node or a destination node.
|
||||
Status ValidateQDQInsertionEdges(Graph& graph, gsl::span<const ExtendedGraphEdge> insertion_edges) {
|
||||
const size_t num_edges = insertion_edges.size();
|
||||
ORT_RETURN_IF(num_edges == 0, "Expected at least one edge into which to insert QDQ pair.");
|
||||
|
||||
const ExtendedGraphEdge& first_edge = insertion_edges[0];
|
||||
const Node* src_node = first_edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Source);
|
||||
const Node* first_dst_node = first_edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
|
||||
const std::string& node_arg_name = first_edge.arg_name;
|
||||
ORT_RETURN_IF_NOT(graph.GetNodeArg(node_arg_name) != nullptr,
|
||||
"QDQ insertion edge does not have a valid graph NodeArg for ", node_arg_name);
|
||||
ORT_RETURN_IF_NOT(src_node != nullptr || first_dst_node != nullptr,
|
||||
"QDQ insertion edge [0] for NodeArg ", node_arg_name,
|
||||
" must have a source or a destination node");
|
||||
|
||||
for (size_t i = 1; i < num_edges; i++) {
|
||||
const ExtendedGraphEdge& insertion_edge = insertion_edges[i];
|
||||
ORT_RETURN_IF_NOT(insertion_edge.arg_name == node_arg_name,
|
||||
"QDQ insertion edge [", i, "] has NodeArg ", insertion_edge.arg_name,
|
||||
" but expected NodeArg ", node_arg_name);
|
||||
|
||||
const Node* edge_dst_node = insertion_edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
|
||||
ORT_RETURN_IF_NOT(src_node != nullptr || edge_dst_node != nullptr,
|
||||
"QDQ insertion edge [", i, "] for NodeArg ", node_arg_name,
|
||||
" must have a source or a destination node");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Logs information about the edges into which Q/DQ nodes will be inserted in InsertQDQPairs().
|
||||
// Assumes the edges have already been validated.
|
||||
void LogQDQInsertion(const logging::Logger& logger, logging::Severity severity, const CodeLocation& code_location,
|
||||
const Graph& graph, gsl::span<const ExtendedGraphEdge> edges) {
|
||||
auto logging_data_type = logging::DataType::SYSTEM;
|
||||
if (!logger.OutputIsEnabled(severity, logging_data_type)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const Node* src_node = edges[0].GetNodeAtEnd(graph, ExtendedGraphEdge::End::Source);
|
||||
const auto& node_arg_name = edges[0].arg_name;
|
||||
std::string src_label = src_node ? MakeString("node (\"", src_node->Name(), "\", index: ", src_node->Index(), ")")
|
||||
: "input";
|
||||
std::ostringstream dst_labels;
|
||||
const size_t num_edges = edges.size();
|
||||
|
||||
for (size_t i = 0; i < num_edges; ++i) {
|
||||
const ExtendedGraphEdge& edge = edges[i];
|
||||
const Node* dst_node = edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
|
||||
dst_labels << (dst_node ? MakeString("dst node (\"", dst_node->Name(), "\", index: ", dst_node->Index(), ")")
|
||||
: "output")
|
||||
<< (i == num_edges - 1 ? "" : ",");
|
||||
}
|
||||
|
||||
logging::Capture(logger, severity, logging::Category::onnxruntime, logging_data_type, code_location).Stream()
|
||||
<< "Inserted Q/DQ pair between "
|
||||
<< (src_node ? MakeString("src node (\"", src_node->Name(), "\", index: ", src_node->Index(), ")")
|
||||
: "input")
|
||||
<< " and " << dst_labels.str()
|
||||
<< " at NodeArg \"" << node_arg_name << "\".";
|
||||
}
|
||||
|
||||
// convert this: src_node (or graph input) --+--> dst_node_0 (or graph output)
|
||||
// |
|
||||
// +--> dst_node_1
|
||||
// | ...
|
||||
// +--> dst_node_n
|
||||
//
|
||||
// to this: src_node (or graph input) -> Q --+--> DQ -> dst_node_0 (or graph output)
|
||||
// |
|
||||
// +--> DQ -> dst_node_1
|
||||
// | ...
|
||||
// +--> DQ -> dst_node_n
|
||||
// Checks that all insertion edges share the same NodeArg. That is, the edges originate from the same source node
|
||||
// output. If there is no src_node, then all edges should come from the same graph input.
|
||||
// This function returns an error status if edges are invalid.
|
||||
//
|
||||
// Assumes that scale_initializer_nodearg and zp_initializer_nodearg_ptr (if not null) are constant initializers.
|
||||
Status InsertQDQPairs(Graph& graph, gsl::span<const ExtendedGraphEdge> insertion_edges,
|
||||
NodeArg& scale_initializer_nodearg, NodeArg* zp_initializer_nodearg_ptr,
|
||||
const std::string& qdq_domain, const NodeAttributes& q_attrs, const NodeAttributes& dq_attrs,
|
||||
const logging::Logger& logger) {
|
||||
ORT_RETURN_IF_ERROR(ValidateQDQInsertionEdges(graph, insertion_edges));
|
||||
|
||||
const ExtendedGraphEdge& first_edge = insertion_edges[0]; // ValidateQDQInsertionEdges() guarantees at least one edge
|
||||
|
||||
Node* src_node = first_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Source); // nullptr for graph input
|
||||
const auto& base_name = first_edge.arg_name;
|
||||
auto& base_node_arg = *graph.GetNodeArg(base_name);
|
||||
|
||||
LOGS(logger, VERBOSE) << "Inserting Q/DQ pair between "
|
||||
<< (src_node ? MakeString("node (\"", src_node->Name(), "\", index: ", src_node->Index(), ")")
|
||||
: "input")
|
||||
<< " and "
|
||||
<< (dst_node ? MakeString("node (\"", dst_node->Name(), "\", index: ", dst_node->Index(), ")")
|
||||
: "output")
|
||||
<< " at NodeArg \"" << base_name << "\".";
|
||||
LogQDQInsertion(logger, logging::Severity::kVERBOSE, ORT_WHERE, graph, insertion_edges);
|
||||
|
||||
// set up new NodeArgs
|
||||
auto& pre_q_nodearg = insertion_edge.HasGraphInputOrInitializer()
|
||||
auto make_q_or_dq_inputs = [](NodeArg& data, NodeArg& scale, NodeArg* zero_point) {
|
||||
return zero_point ? InlinedVector<NodeArg*>{&data, &scale, zero_point}
|
||||
: InlinedVector<NodeArg*>{&data, &scale};
|
||||
};
|
||||
|
||||
// Create Q node that will be inserted after src_node
|
||||
auto& pre_q_nodearg = first_edge.HasGraphInputOrInitializer()
|
||||
? base_node_arg
|
||||
: graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_name + "_pre_q"),
|
||||
nullptr);
|
||||
|
|
@ -57,17 +170,6 @@ Status InsertQDQPair(Graph& graph, const ExtendedGraphEdge& insertion_edge,
|
|||
auto& q_to_dq_nodearg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_name + "_q_to_dq"),
|
||||
nullptr);
|
||||
|
||||
auto& post_dq_nodearg = insertion_edge.HasGraphOutput()
|
||||
? base_node_arg
|
||||
: graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_name + "_post_dq"),
|
||||
nullptr);
|
||||
|
||||
// set up new Nodes
|
||||
auto make_q_or_dq_inputs = [](NodeArg& data, NodeArg& scale, NodeArg* zero_point) {
|
||||
return zero_point ? std::vector<NodeArg*>{&data, &scale, zero_point}
|
||||
: std::vector<NodeArg*>{&data, &scale};
|
||||
};
|
||||
|
||||
auto& q_node = graph.AddNode(graph.GenerateNodeName(base_name + "_q"),
|
||||
QDQ::QOpName,
|
||||
"Inserted by QDQPropagationTransformer",
|
||||
|
|
@ -76,40 +178,61 @@ Status InsertQDQPair(Graph& graph, const ExtendedGraphEdge& insertion_edge,
|
|||
zp_initializer_nodearg_ptr),
|
||||
// outputs
|
||||
{&q_to_dq_nodearg},
|
||||
nullptr, // attributes
|
||||
&q_attrs, // attributes
|
||||
qdq_domain);
|
||||
|
||||
ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(q_node), "Failed to set op schema for added Q node.");
|
||||
|
||||
auto& dq_node = graph.AddNode(graph.GenerateNodeName(base_name + "_dq"),
|
||||
QDQ::DQOpName,
|
||||
"Inserted by QDQPropagationTransformer",
|
||||
// inputs
|
||||
make_q_or_dq_inputs(q_to_dq_nodearg, scale_initializer_nodearg,
|
||||
zp_initializer_nodearg_ptr),
|
||||
// outputs
|
||||
{&post_dq_nodearg},
|
||||
nullptr, // attributes
|
||||
qdq_domain);
|
||||
|
||||
ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(dq_node), "Failed to set op schema for added DQ node.");
|
||||
|
||||
// set up edges
|
||||
if (src_node && dst_node) {
|
||||
graph.RemoveEdge(src_node->Index(), dst_node->Index(),
|
||||
insertion_edge.src->arg_idx, insertion_edge.dst->arg_idx);
|
||||
}
|
||||
|
||||
if (src_node) {
|
||||
src_node->MutableOutputDefs()[insertion_edge.src->arg_idx] = &pre_q_nodearg;
|
||||
graph.AddEdge(src_node->Index(), q_node.Index(), insertion_edge.src->arg_idx, 0);
|
||||
// Remove original edges between src and dst nodes.
|
||||
for (const auto& insertion_edge : insertion_edges) {
|
||||
auto* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
|
||||
|
||||
if (dst_node) {
|
||||
graph.RemoveEdge(src_node->Index(), dst_node->Index(),
|
||||
insertion_edge.src->arg_idx, insertion_edge.dst->arg_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Add edge from src to Q node.
|
||||
src_node->MutableOutputDefs()[first_edge.src->arg_idx] = &pre_q_nodearg;
|
||||
graph.AddEdge(src_node->Index(), q_node.Index(), first_edge.src->arg_idx, 0);
|
||||
}
|
||||
|
||||
graph.AddEdge(q_node.Index(), dq_node.Index(), 0, 0);
|
||||
// Create a DQ node for each dst node and connect remaining edges.
|
||||
for (size_t edge_idx = 0; edge_idx < insertion_edges.size(); ++edge_idx) {
|
||||
const auto& insertion_edge = insertion_edges[edge_idx];
|
||||
const std::string edge_suffix = edge_idx == 0 ? "" : std::to_string(edge_idx);
|
||||
auto& post_dq_nodearg = insertion_edge.HasGraphOutput()
|
||||
? base_node_arg
|
||||
: graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(MakeString(base_name,
|
||||
"_post_dq",
|
||||
edge_suffix)),
|
||||
nullptr);
|
||||
|
||||
if (dst_node) {
|
||||
dst_node->MutableInputDefs()[insertion_edge.dst->arg_idx] = &post_dq_nodearg;
|
||||
graph.AddEdge(dq_node.Index(), dst_node->Index(), 0, insertion_edge.dst->arg_idx);
|
||||
auto& dq_node = graph.AddNode(graph.GenerateNodeName(MakeString(base_name, "_dq", edge_suffix)),
|
||||
QDQ::DQOpName,
|
||||
"Inserted by QDQPropagationTransformer",
|
||||
// inputs
|
||||
make_q_or_dq_inputs(q_to_dq_nodearg, scale_initializer_nodearg,
|
||||
zp_initializer_nodearg_ptr),
|
||||
// outputs
|
||||
{&post_dq_nodearg},
|
||||
&dq_attrs, // attributes
|
||||
qdq_domain);
|
||||
|
||||
ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(dq_node), "Failed to set op schema for added DQ node.");
|
||||
|
||||
Node* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
|
||||
|
||||
// Add edge from Q to DQ
|
||||
graph.AddEdge(q_node.Index(), dq_node.Index(), 0, 0);
|
||||
|
||||
// Add edge from DQ to dst_node
|
||||
if (dst_node) {
|
||||
dst_node->MutableInputDefs()[insertion_edge.dst->arg_idx] = &post_dq_nodearg;
|
||||
graph.AddEdge(dq_node.Index(), dst_node->Index(), 0, insertion_edge.dst->arg_idx);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
@ -156,37 +279,39 @@ std::optional<ExtendedGraphEdge> GetPreviousPropagationEdge(const Graph& graph,
|
|||
return GetPreviousEdge(graph, *src_node);
|
||||
}
|
||||
|
||||
std::optional<ExtendedGraphEdge> GetNextEdge(const Graph& graph, const Node& node) {
|
||||
// for now we can just consider the first output (index 0)
|
||||
InlinedVector<ExtendedGraphEdge> GetNextEdges(const Graph& graph, const Node& node) {
|
||||
constexpr int node_output_index = 0; // for now we can just consider the first output (index 0)
|
||||
InlinedVector<ExtendedGraphEdge> next_edges;
|
||||
const auto output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node, static_cast<size_t>(node_output_index));
|
||||
|
||||
const auto output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node, 0);
|
||||
if (output_edges.empty()) {
|
||||
// maybe edge to output
|
||||
return ExtendedGraphEdge::TryCreateFromNodeToOutput(graph, node, 0);
|
||||
// edges to next nodes
|
||||
for (const auto& output_edge : output_edges) {
|
||||
next_edges.push_back(ExtendedGraphEdge::CreateFromValidGraphEdge(output_edge));
|
||||
}
|
||||
|
||||
if (!graph.IsOutput(node.OutputDefs()[0]) && output_edges.size() == 1) {
|
||||
// single edge to next node
|
||||
return ExtendedGraphEdge::CreateFromValidGraphEdge(output_edges.front());
|
||||
// maybe edge to graph output
|
||||
auto edge_to_output = ExtendedGraphEdge::TryCreateFromNodeToOutput(graph, node, node_output_index);
|
||||
if (edge_to_output.has_value()) {
|
||||
next_edges.push_back(edge_to_output.value());
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
return next_edges;
|
||||
}
|
||||
|
||||
std::optional<ExtendedGraphEdge> GetNextPropagationEdge(const Graph& graph,
|
||||
const ExtendedGraphEdge& edge) {
|
||||
InlinedVector<ExtendedGraphEdge> GetNextPropagationEdges(const Graph& graph,
|
||||
const ExtendedGraphEdge& edge) {
|
||||
if (edge.HasGraphOutput()) {
|
||||
return std::nullopt;
|
||||
return {};
|
||||
}
|
||||
|
||||
const auto* dst_node = edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
|
||||
ORT_ENFORCE(dst_node != nullptr);
|
||||
|
||||
if (!CanNodePropagate(*dst_node)) {
|
||||
return std::nullopt;
|
||||
return {};
|
||||
}
|
||||
|
||||
return GetNextEdge(graph, *dst_node);
|
||||
return GetNextEdges(graph, *dst_node);
|
||||
}
|
||||
|
||||
class GraphConstantInitializerGetter {
|
||||
|
|
@ -228,21 +353,54 @@ Status PropagateDQForward(Graph& graph, gsl::span<const NodeIndex> node_indices,
|
|||
? dq_node.MutableInputDefs()[QDQ::InputIndex::ZERO_POINT_ID]
|
||||
: nullptr;
|
||||
|
||||
const auto edge_after_dq = GetNextEdge(graph, dq_node);
|
||||
if (!edge_after_dq) {
|
||||
const InlinedVector<ExtendedGraphEdge> edges_after_dq = GetNextEdges(graph, dq_node);
|
||||
if (edges_after_dq.size() != 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto curr_edge = GetNextPropagationEdge(graph, *edge_after_dq);
|
||||
curr_edge.has_value();
|
||||
curr_edge = GetNextPropagationEdge(graph, *curr_edge)) {
|
||||
if (const auto* dst_node = curr_edge->GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
|
||||
dst_node && QDQ::MatchQNode(*dst_node)) {
|
||||
break;
|
||||
// Utility function to check if any edge out of a node (e.g., Transpose) ends in a Q node.
|
||||
auto any_edge_ends_in_q = [](Graph& graph, const InlinedVector<ExtendedGraphEdge>& edges) -> bool {
|
||||
for (const auto& edge : edges) {
|
||||
const auto* edge_dst_node = edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
|
||||
if (edge_dst_node && QDQ::MatchQNode(*edge_dst_node)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// Propagate DQ forward in a BFS traversal of NodeArg edges. A NodeArg "edge group" consists of one or more edges
|
||||
// that all begin at the same source node's output slot and end at a graph output or a destination node.
|
||||
// Ex: The subgraph below shows a NodeArg edge group (containing 3 edges) that begins at a
|
||||
// Transpose, ends at two destination nodes, and produces a graph output.
|
||||
// DQ -> Transpose --+--> Sigmoid -> ...
|
||||
// |
|
||||
// +--> Slice -> ...
|
||||
// |
|
||||
// +--> graph_output
|
||||
std::queue<InlinedVector<ExtendedGraphEdge>> node_arg_edges;
|
||||
node_arg_edges.push(GetNextPropagationEdges(graph, edges_after_dq[0]));
|
||||
|
||||
while (!node_arg_edges.empty()) {
|
||||
const InlinedVector<ExtendedGraphEdge> curr_edge_group = std::move(node_arg_edges.front());
|
||||
node_arg_edges.pop();
|
||||
|
||||
// Skip if edge group is empty. Also, to keep things simple, we do not yet handle edge groups in which
|
||||
// one of the destination nodes is already a QuantizeLinear node. Ex:
|
||||
// DQ -> Transpose --+--> QuantizeLinear -> ...
|
||||
// |
|
||||
// +--> Slice -> ...
|
||||
if (curr_edge_group.empty() || any_edge_ends_in_q(graph, curr_edge_group)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(InsertQDQPair(graph, *curr_edge, dq_scale, dq_zero_point, dq_node.Domain(), logger));
|
||||
ORT_RETURN_IF_ERROR(InsertQDQPairs(graph, curr_edge_group, dq_scale, dq_zero_point, dq_node.Domain(),
|
||||
MakeQAttrsFromDQ(dq_node), dq_node.GetAttributes(), logger));
|
||||
modified = true;
|
||||
|
||||
for (const auto& edge : curr_edge_group) {
|
||||
node_arg_edges.push(GetNextPropagationEdges(graph, edge));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -290,7 +448,8 @@ Status PropagateQBackward(Graph& graph, gsl::span<const NodeIndex> node_indices,
|
|||
break;
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(InsertQDQPair(graph, *curr_edge, q_scale, q_zero_point, q_node.Domain(), logger));
|
||||
ORT_RETURN_IF_ERROR(InsertQDQPairs(graph, InlinedVector<ExtendedGraphEdge>{*curr_edge}, q_scale, q_zero_point,
|
||||
q_node.Domain(), q_node.GetAttributes(), MakeDQAttrsFromQ(q_node), logger));
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -246,14 +246,14 @@ Status TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>&
|
|||
ORT_RETURN_IF_ERROR(pre_graph_checker(graph));
|
||||
}
|
||||
#if SAVE_TEST_GRAPH
|
||||
ORT_RETURN_IF_ERROR(Model::Save(model, "model_original.onnx"));
|
||||
ORT_RETURN_IF_ERROR(Model::Save(model, ToPathString("model_original.onnx")));
|
||||
#endif
|
||||
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, level, logger));
|
||||
if (post_graph_checker) {
|
||||
ORT_RETURN_IF_ERROR(post_graph_checker(graph));
|
||||
}
|
||||
#if SAVE_TEST_GRAPH
|
||||
ORT_RETURN_IF_ERROR(Model::Save(model, "model_optimized.onnx"));
|
||||
ORT_RETURN_IF_ERROR(Model::Save(model, ToPathString("model_optimized.onnx")));
|
||||
#endif
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/optimizer/double_qdq_pairs_remover.h"
|
||||
#include "core/optimizer/qdq_transformer/qdq_final_cleanup.h"
|
||||
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
|
||||
|
|
@ -3084,6 +3085,57 @@ TEST(QDQTransformerTests, QDQPropagation_QBackward) {
|
|||
#endif
|
||||
}
|
||||
|
||||
// Test backwards propagation of a QuantizeLinear node that uses the "output_dtype" attribute
|
||||
// to set the quantization type (i.e., does not have an explicit zero-point input). This tests
|
||||
// the copying of attributes for QDQ propagation.
|
||||
TEST(QDQTransformerTests, QDQPropagation_QBackward_NoZP_OutputDtypeAttribute) {
|
||||
auto test_case = [&](ONNX_NAMESPACE::TensorProto_DataType q_output_type) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<float>({1, 2, 2}, {-2.0f, 0.0f, 1.0f, 2.0f});
|
||||
auto* output_arg = builder.MakeOutput();
|
||||
|
||||
// add Add
|
||||
auto* const_1_input = builder.MakeScalarInitializer<float>(1.0f);
|
||||
auto* add_output = builder.MakeIntermediate();
|
||||
builder.AddNode("Add", {input_arg, const_1_input}, {add_output});
|
||||
|
||||
// add Transpose
|
||||
auto* transpose_output = builder.MakeIntermediate();
|
||||
builder.AddNode("Transpose", {add_output}, {transpose_output});
|
||||
|
||||
// add Q with a "output_dtype" attribute. Omit the zero-point input (defaults to 0).
|
||||
constexpr float qdq_scale = 1.0f;
|
||||
Node& q_node = builder.AddQuantizeLinearNode(transpose_output, qdq_scale, output_arg);
|
||||
q_node.AddAttribute("output_dtype", static_cast<int64_t>(q_output_type));
|
||||
};
|
||||
|
||||
auto check_graph = [&](InferenceSessionWrapper& session) {
|
||||
const QDQOpKeys qdq_keys = GetQDQOpKeys(false);
|
||||
std::vector<std::string> expected_op_types_in_order = {
|
||||
"Add",
|
||||
qdq_keys.quantize_linear,
|
||||
qdq_keys.dequantize_linear,
|
||||
"Transpose",
|
||||
qdq_keys.quantize_linear,
|
||||
};
|
||||
|
||||
const auto op_types_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph(), true);
|
||||
EXPECT_EQ(op_types_in_order, expected_op_types_in_order);
|
||||
};
|
||||
|
||||
TransformerTester(build_test_case,
|
||||
check_graph,
|
||||
TransformerLevel::Default,
|
||||
TransformerLevel::Level1,
|
||||
21); // Opset >= 21 supports the "output_dtype" attribute
|
||||
};
|
||||
|
||||
test_case(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
|
||||
test_case(ONNX_NAMESPACE::TensorProto_DataType_INT8);
|
||||
test_case(ONNX_NAMESPACE::TensorProto_DataType_UINT16);
|
||||
test_case(ONNX_NAMESPACE::TensorProto_DataType_INT16);
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, QDQPropagation_DQForward) {
|
||||
auto test_case = [&](const std::vector<int64_t>& input_shape,
|
||||
size_t maxpool_dim,
|
||||
|
|
@ -3420,6 +3472,122 @@ TEST(QDQTransformerTests, QDQPropagation_DQ_Q) {
|
|||
#endif
|
||||
}
|
||||
|
||||
// Test propagating a DQ forward through a chain of Slice and Transpose operators that have multiple consumers.
|
||||
// original model:
|
||||
// in0 -> DQ -> Slice --+--> slice_out
|
||||
// |
|
||||
// +--> Add -> out0
|
||||
// |
|
||||
// +--> Transpose --+--> Pow -> out1
|
||||
// | |
|
||||
// | +--> Pow -> out2
|
||||
// |
|
||||
// +--> Transpose --+--> Pow -> out3
|
||||
// |
|
||||
// +--> Pow -> out4
|
||||
// expected model:
|
||||
// in0 -> DQ -> Slice -> Q --+--> DQ -> slice_out
|
||||
// |
|
||||
// +--> DQ -> Add -> out0
|
||||
// |
|
||||
// +--> DQ -> TP -> Q --+--> DQ -> Pow -> out1
|
||||
// | |
|
||||
// | +--> DQ -> Pow -> out2
|
||||
// |
|
||||
// +--> DQ -> TP -> Q --+--> DQ -> Pow -> out3
|
||||
// |
|
||||
// +--> DQ -> Pow -> out4
|
||||
TEST(QDQTransformerTests, QDQPropagation_DQForward_SliceMultipleConsumers) {
|
||||
auto run_test_case = [&](bool slice_has_graph_output) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
std::vector<int64_t> input0_shape = {1, 2, 2, 2};
|
||||
std::vector<int64_t> input1_shape = {1, 1, 1, 1};
|
||||
auto* input0_arg = builder.MakeInput<uint8_t>(input0_shape,
|
||||
std::numeric_limits<uint8_t>::min(),
|
||||
std::numeric_limits<uint8_t>::max());
|
||||
auto* input1_arg = builder.MakeInput<float>(input1_shape, {0.0f});
|
||||
auto* output0_arg = builder.MakeOutput();
|
||||
auto* output1_arg = builder.MakeOutput();
|
||||
auto* output2_arg = builder.MakeOutput();
|
||||
auto* output3_arg = builder.MakeOutput();
|
||||
auto* output4_arg = builder.MakeOutput();
|
||||
|
||||
// DQ
|
||||
constexpr float qdq_scale = 1.0f;
|
||||
constexpr uint8_t qdq_zero_point = 128;
|
||||
auto* dq_output = builder.MakeIntermediate();
|
||||
builder.AddDequantizeLinearNode<uint8_t>(input0_arg, qdq_scale, qdq_zero_point, dq_output);
|
||||
|
||||
// Slice
|
||||
auto* slice_output = slice_has_graph_output ? builder.MakeOutput() : builder.MakeIntermediate();
|
||||
auto* slice_starts = builder.Make1DInitializer(std::vector<int64_t>{0, 0, 0, 0});
|
||||
auto* slice_ends = builder.Make1DInitializer(std::vector<int64_t>{1, 1, 1, 1});
|
||||
builder.AddNode("Slice", {dq_output, slice_starts, slice_ends}, {slice_output});
|
||||
|
||||
// Add
|
||||
builder.AddNode("Add", {slice_output, input1_arg}, {output0_arg});
|
||||
|
||||
// Transpose
|
||||
auto* transpose0_output = builder.MakeIntermediate();
|
||||
builder.AddNode("Transpose", {slice_output}, {transpose0_output});
|
||||
|
||||
// Transpose
|
||||
auto* transpose1_output = builder.MakeIntermediate();
|
||||
builder.AddNode("Transpose", {slice_output}, {transpose1_output});
|
||||
|
||||
// Pows
|
||||
auto* pow_exp = builder.MakeScalarInitializer(2.0f);
|
||||
builder.AddNode("Pow", {transpose0_output, pow_exp}, {output1_arg});
|
||||
builder.AddNode("Pow", {transpose0_output, pow_exp}, {output2_arg});
|
||||
builder.AddNode("Pow", {transpose1_output, pow_exp}, {output3_arg});
|
||||
builder.AddNode("Pow", {transpose1_output, pow_exp}, {output4_arg});
|
||||
};
|
||||
|
||||
auto check_graph = [&](InferenceSessionWrapper& session) {
|
||||
const QDQOpKeys qdq_keys = GetQDQOpKeys(false);
|
||||
std::vector<std::string> expected_op_types_in_order;
|
||||
expected_op_types_in_order.reserve(20);
|
||||
expected_op_types_in_order.insert(expected_op_types_in_order.end(),
|
||||
{qdq_keys.dequantize_linear,
|
||||
"Slice",
|
||||
qdq_keys.quantize_linear});
|
||||
|
||||
if (slice_has_graph_output) {
|
||||
// Should have a DQ before the graph output generated by the Slice.
|
||||
expected_op_types_in_order.push_back(qdq_keys.dequantize_linear);
|
||||
}
|
||||
|
||||
expected_op_types_in_order.insert(expected_op_types_in_order.end(),
|
||||
{qdq_keys.dequantize_linear,
|
||||
"Add",
|
||||
qdq_keys.dequantize_linear,
|
||||
"Transpose",
|
||||
qdq_keys.quantize_linear, qdq_keys.dequantize_linear,
|
||||
"Pow",
|
||||
qdq_keys.dequantize_linear,
|
||||
"Pow",
|
||||
qdq_keys.dequantize_linear,
|
||||
"Transpose",
|
||||
qdq_keys.quantize_linear, qdq_keys.dequantize_linear,
|
||||
"Pow",
|
||||
qdq_keys.dequantize_linear,
|
||||
"Pow"});
|
||||
|
||||
const auto op_types_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph(), true);
|
||||
EXPECT_EQ(op_types_in_order, expected_op_types_in_order);
|
||||
};
|
||||
|
||||
TransformerTester(build_test_case,
|
||||
check_graph,
|
||||
TransformerLevel::Default,
|
||||
TransformerLevel::Level1,
|
||||
18, 0.0, 0.0, std::make_unique<QDQPropagationTransformer>());
|
||||
};
|
||||
|
||||
run_test_case(/*slice_has_graph_output*/ false);
|
||||
run_test_case(/*slice_has_graph_output*/ true);
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, QDQ_Selector_Test) {
|
||||
const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/transform/qdq_conv.onnx");
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue