Make partitioning utils QDQ aware so it does not break up QDQ node units (#19723)

### Description
<!-- Describe your changes. -->
If the EP handles QDQ node units, we need to make sure we do not split
those into different partitions.

Update the partitioning utils to be QDQ aware. If there are node units
we process the logical nodes they represent instead of individual nodes.
This ensure we process all nodes in a QDQ node unit at the same time so
that they are always in the same partition.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Fix one of the issues in #19590

---------

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
This commit is contained in:
Scott McKay 2024-03-12 10:55:49 +10:00 committed by GitHub
parent cba605e845
commit 978c40d853
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
38 changed files with 886 additions and 489 deletions

View file

@ -37,8 +37,13 @@ if (onnxruntime_BUILD_UNIT_TESTS)
set(gtest_disable_pthreads ON)
endif()
set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
# Needs to update onnxruntime/test/xctest/xcgtest.mm
if (IOS OR ANDROID)
# on mobile platforms the absl flags class dumps the flag names (assumably for binary size), which breaks passing
# any args to gtest executables, such as using --gtest_filter to debug a specific test.
# Processing of compile definitions:
# https://github.com/abseil/abseil-cpp/blob/8dc90ff07402cd027daec520bb77f46e51855889/absl/flags/config.h#L21
# If set, this code throws away the flag and does nothing on registration, which results in no flags being known:
# https://github.com/abseil/abseil-cpp/blob/8dc90ff07402cd027daec520bb77f46e51855889/absl/flags/flag.h#L205-L217
set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE)
else()
set(GTEST_HAS_ABSL ON CACHE BOOL "" FORCE)

View file

@ -70,8 +70,8 @@ list(FILTER coreml_proto_generated_srcs INCLUDE REGEX "\.pb\.(h|cc)$")
source_group(TREE ${CMAKE_CURRENT_BINARY_DIR} PREFIX coreml_proto_generated FILES ${coreml_proto_generated_srcs})
# These are shared utils,
# TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML
file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
# TODO, move this to a separate lib when used by EPs other than NNAPI and CoreML
file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h"
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc"
)

View file

@ -49,12 +49,10 @@
endif()
# These are shared utils,
# TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML
# TODO, move this to a separate lib when used by EPs other than NNAPI and CoreML
list(APPEND onnxruntime_provider_nnapi_cc_src_patterns
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h"
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc"
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h"
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc"
)
file(GLOB onnxruntime_providers_nnapi_cc_srcs CONFIGURE_DEPENDS ${onnxruntime_provider_nnapi_cc_src_patterns})
@ -81,4 +79,4 @@
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR})
endif()
endif()

View file

@ -4,12 +4,10 @@
add_compile_definitions(USE_QNN=1)
# These are shared utils,
# TODO, move this to a separated lib when used by EPs other than QNN, NNAPI and CoreML
file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
# TODO, move to a separate lib when used by EPs other than QNN, NNAPI and CoreML
file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h"
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc"
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h"
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc"
)
file(GLOB_RECURSE
@ -42,4 +40,4 @@
# ignore the warning unknown-pragmas on "pragma region"
if(NOT MSVC)
target_compile_options(onnxruntime_providers_qnn PRIVATE "-Wno-unknown-pragmas")
endif()
endif()

View file

@ -7,9 +7,6 @@
"${ONNXRUNTIME_INCLUDE_DIR}/core/providers/xnnpack/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.cc"
# utils for handling QDQ models
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h"
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc"
)
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_xnnpack_cc_srcs})

View file

@ -0,0 +1,351 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
#include "node_unit.h"
#include "core/graph/graph_viewer.h"
namespace onnxruntime {
namespace {
enum class QLinearOpType : uint8_t {
Unknown, // Unknown or not a linear quantized op
DequantizeLinear,
QuantizeLinear,
QLinearConv,
QLinearMatMul,
QLinearAdd,
QLinearSigmoid,
QLinearAveragePool,
QLinearMul,
QLinearReduceMean,
QLinearConcat,
QLinearGlobalAveragePool,
QLinearLeakyRelu,
};
QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) {
const auto& op_type = node.OpType();
if (op_type == "DequantizeLinear")
return QLinearOpType::DequantizeLinear;
else if (op_type == "QuantizeLinear")
return QLinearOpType::QuantizeLinear;
else if (op_type == "QLinearConv")
return QLinearOpType::QLinearConv;
else if (op_type == "QLinearMatMul")
return QLinearOpType::QLinearMatMul;
else if (op_type == "QLinearAdd")
return QLinearOpType::QLinearAdd;
else if (op_type == "QLinearSigmoid")
return QLinearOpType::QLinearSigmoid;
else if (op_type == "QLinearAveragePool")
return QLinearOpType::QLinearAveragePool;
else if (op_type == "QLinearMul")
return QLinearOpType::QLinearMul;
else if (op_type == "QLinearReduceMean")
return QLinearOpType::QLinearReduceMean;
else if (op_type == "QLinearConcat")
return QLinearOpType::QLinearConcat;
else if (op_type == "QLinearGlobalAveragePool")
return QLinearOpType::QLinearGlobalAveragePool;
else if (op_type == "QLinearLeakyRelu")
return QLinearOpType::QLinearLeakyRelu;
return QLinearOpType::Unknown;
}
// Ops have 1 input
bool IsUnaryQLinearOp(QLinearOpType type) {
return type == QLinearOpType::QLinearSigmoid ||
type == QLinearOpType::QLinearAveragePool ||
type == QLinearOpType::QLinearGlobalAveragePool ||
type == QLinearOpType::QLinearLeakyRelu ||
type == QLinearOpType::QLinearReduceMean;
}
// Ops have 2 inputs
bool IsBinaryQLinearOp(QLinearOpType type) {
return type == QLinearOpType::QLinearConv ||
type == QLinearOpType::QLinearMatMul ||
type == QLinearOpType::QLinearAdd ||
type == QLinearOpType::QLinearMul;
}
// Ops have 1 or more inputs
bool IsVariadicQLinearOp(QLinearOpType type) {
return type == QLinearOpType::QLinearConcat;
}
const std::vector<const Node*> GetQDQIONodes(const GraphViewer& graph_viewer,
const QDQ::NodeGroup& node_group, bool is_input) {
std::vector<const Node*> io_nodes;
const auto& src_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes;
io_nodes.reserve(src_nodes.size());
for (const auto& node_idx : src_nodes) {
io_nodes.push_back(graph_viewer.GetNode(node_idx));
}
return io_nodes;
}
// Get the input or output NodeUnitIODef(s) for the given QDQ NodeGroup
std::vector<NodeUnitIODef> GetQDQIODefs(const Node& target_node, const QDQ::NodeGroup& node_group, bool is_input) {
const auto& dq_or_q_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes;
const auto target_node_io_defs = is_input ? target_node.InputDefs() : target_node.OutputDefs();
const size_t target_node_io_defs_size = target_node_io_defs.size();
// Find all the quantized IO defs and indices (for the input/output of the target node)
std::unordered_map<size_t, NodeUnitIODef> quantized_io_defs;
quantized_io_defs.reserve(target_node_io_defs_size);
auto cur = is_input ? target_node.InputEdgesBegin() : target_node.OutputEdgesBegin();
auto end = is_input ? target_node.InputEdgesEnd() : target_node.OutputEdgesEnd();
for (; cur != end; ++cur) {
const Node& node = cur->GetNode();
// If we can find the node index in the dq or q nodes this is a quantized input/output
if (std::find(dq_or_q_nodes.cbegin(), dq_or_q_nodes.cend(), node.Index()) != dq_or_q_nodes.cend()) {
const auto node_inputs = node.InputDefs();
// quantization scale and zp are always the input[1, 2]
NodeUnitIODef::QuantParam quant_param{*node_inputs[1], node_inputs.size() == 3 ? node_inputs[2] : nullptr};
if (is_input) {
// DQ is input to the target node, use the DstArgIndex
auto idx = cur->GetDstArgIndex();
// This is a DQ node, we are using x, x_scale, x_zp (input[0, 1, 2])
quantized_io_defs.insert({idx, NodeUnitIODef{*node_inputs[0], quant_param}});
} else {
// Q is output of the target node, use the SrcArgIndex
auto idx = cur->GetSrcArgIndex();
// This is a Q node, we are using y (output[0]), y_scale, y_zp (input[1, 2])
const auto node_outputs = node.OutputDefs();
quantized_io_defs.insert({idx, NodeUnitIODef{*node_outputs[0], quant_param}});
}
}
}
// Construct the IODefs for this QDQ NodeGroup
std::vector<NodeUnitIODef> io_defs;
io_defs.reserve(target_node_io_defs_size);
for (size_t i = 0; i < target_node_io_defs_size; i++) {
// If we can find the NodeUnitIODef for this index, this is a quantized input/output
if (quantized_io_defs.find(i) != quantized_io_defs.cend()) {
io_defs.push_back(std::move(quantized_io_defs.at(i)));
} else {
// This is a regular input
io_defs.push_back({*target_node_io_defs[i], std::nullopt});
}
}
return io_defs;
}
} // namespace
Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer,
const Node& target_node,
gsl::span<const Node* const> dq_nodes,
gsl::span<const Node* const> q_nodes) {
// Within a QDQ node group, a target node input is the only consumer of each DQ.
// This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications
// may have happened since. Verify that this is still true.
for (const auto* dq_node : dq_nodes) {
const bool dq_produces_graph_output = graph_viewer.NodeProducesGraphOutput(*dq_node);
ORT_RETURN_IF(dq_produces_graph_output,
"QDQ node group cannot have DQ node that produces a graph output. DQ node: ", dq_node->Name(),
", target node: ", target_node.Name());
const bool dq_has_single_output_edge_to_target =
dq_node->GetOutputEdgesCount() == 1 &&
dq_node->OutputEdgesBegin()->GetNode().Index() == target_node.Index();
ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target,
"QDQ node group cannot have DQ that doesn't have a single output edge to the target node. "
"DQ node: ",
dq_node->Name(), ", target node: ", target_node.Name());
}
// an output from the target node can have either Q consumers or direct consumers. it cannot have both.
// this must be checked on a per output basis.
// e.g. TopK produces values and indices. The indices output won't be quantized, so even if we replace the TopK QDQ
// node group with a quantized TopK, an int64_t indices value will be produced and can provide a graph output.
if (!q_nodes.empty()) {
auto cur_edge = target_node.OutputEdgesBegin();
auto end_edge = target_node.OutputEdgesEnd();
std::vector<const Node*> output_consumers(target_node.OutputDefs().size(), nullptr);
for (; cur_edge != end_edge; ++cur_edge) {
auto output_idx = cur_edge->GetSrcArgIndex();
const Node& this_consumer = cur_edge->GetNode();
const Node* existing_consumer = output_consumers[output_idx];
if (existing_consumer != nullptr) {
// another edge for this output. either both are Q or both are not.
bool valid = true;
if (existing_consumer->OpType() == "QuantizeLinear") {
valid = this_consumer.OpType() == "QuantizeLinear";
} else {
valid = this_consumer.OpType() != "QuantizeLinear";
}
ORT_RETURN_IF_NOT(valid,
"QDQ node group cannot have an output from the target node being consumed by a Q node and "
"a non-Q node. target node: ",
target_node.Name());
} else {
output_consumers[output_idx] = &this_consumer;
}
}
const auto& graph_outputs = graph_viewer.GetOutputs();
for (size_t idx = 0, end = output_consumers.size(); idx < end; ++idx) {
// any output with a Q cannot be a graph output as it will disappear if the QDQ node unit is converted to
// a quantized op.
if (output_consumers[idx] != nullptr && output_consumers[idx]->OpType() == "QuantizeLinear") {
const auto& output_name = target_node.OutputDefs()[idx]->Name();
bool is_graph_output = std::any_of(graph_outputs.begin(), graph_outputs.end(),
[&output_name](const NodeArg* node_arg) {
return node_arg->Name() == output_name;
});
ORT_RETURN_IF(is_graph_output,
"QDQ node group cannot have an output from the target node that is consumed by a Q node and "
"a graph output. target node: ",
target_node.Name(), " output idx:", idx);
}
}
}
return Status::OK();
}
NodeUnit::NodeUnit(const Node& node)
: target_node_(node),
type_(Type::SingleNode),
input_edge_count_(node.GetInputEdgesCount()) {
InitForSingleNode();
}
NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group)
: dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)},
target_node_(*graph_viewer.GetNode(node_group.target_node)),
q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)},
type_(Type::QDQGroup),
inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)},
outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} {
ORT_THROW_IF_ERROR(QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, target_node_, dq_nodes_, q_nodes_));
input_edge_count_ = std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0),
[](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); });
// add edges for inputs that are not from DQ nodes. there is one edge to each DQ node.
// other inputs could come from initializers or graph inputs (no edges) or other nodes (edge).
input_edge_count_ += target_node_.GetInputEdgesCount() - dq_nodes_.size();
// create output edges. each target node output either goes to Q node/s or non-Q node/s.
// ValidateNodeGroupQDQNodes ensures this.
auto cur_edge = target_node_.OutputEdgesBegin();
auto end_edge = target_node_.OutputEdgesEnd();
for (; cur_edge != end_edge; ++cur_edge) {
const Node& node = cur_edge->GetNode();
// if node is in q_nodes we hide the Q node.
if (std::find(q_nodes_.cbegin(), q_nodes_.cend(), &node) != q_nodes_.cend()) {
auto src_idx = cur_edge->GetSrcArgIndex();
auto q_cur_edge = node.OutputEdgesBegin();
auto q_end_edge = node.OutputEdgesEnd();
for (; q_cur_edge != q_end_edge; ++q_cur_edge) {
output_edges_.insert(Node::EdgeEnd{q_cur_edge->GetNode(), src_idx, q_cur_edge->GetDstArgIndex()});
}
} else {
// non-Q node, or Q node that isn't in the QDQ node group (unexpected but may be possible). add as-is.
output_edges_.insert(*cur_edge);
}
}
}
const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); }
const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpType(); }
const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); }
int NodeUnit::SinceVersion() const noexcept { return target_node_.SinceVersion(); }
NodeIndex NodeUnit::Index() const noexcept { return target_node_.Index(); }
const Path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); }
ProviderType NodeUnit::GetExecutionProviderType() const noexcept { return target_node_.GetExecutionProviderType(); }
void NodeUnit::InitForSingleNode() {
const auto& input_defs = target_node_.InputDefs();
const auto& output_defs = target_node_.OutputDefs();
auto qlinear_type = GetQLinearOpType(target_node_);
if (qlinear_type == QLinearOpType::Unknown || IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support
// Not a Qlinear op, add all inputs / outputs
auto add_all_io = [](std::vector<NodeUnitIODef>& defs,
const ConstPointerContainer<std::vector<NodeArg*>>& node_defs) {
defs.reserve(node_defs.size());
for (const auto def : node_defs) {
defs.push_back(NodeUnitIODef{*def, std::nullopt});
}
};
add_all_io(inputs_, input_defs);
add_all_io(outputs_, output_defs);
} else if (IsUnaryQLinearOp(qlinear_type)) {
// Unary QLinear Op has 5 inputs
// x, x_scale, x_zp, y_scale, y_zp (optional)
inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}});
outputs_.push_back(NodeUnitIODef{*output_defs[0],
NodeUnitIODef::QuantParam{*input_defs[3],
input_defs.size() > 4 ? input_defs[4] : nullptr}});
} else if (IsBinaryQLinearOp(qlinear_type)) {
// Binary QLinear Op has 9 inputs
// x1, x1_scale, x1_zp, x2/w, x2_scale, x2_zp, y_scale , y_zp, B
inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}});
inputs_.push_back(NodeUnitIODef{*input_defs[3], NodeUnitIODef::QuantParam{*input_defs[4], input_defs[5]}});
if (input_defs.size() == 9) { // has Bias
inputs_.push_back(NodeUnitIODef{*input_defs[8], std::nullopt}); // for Bias the scale and zp are optional
}
outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[6], input_defs[7]}});
} else if (qlinear_type == QLinearOpType::DequantizeLinear) {
// DequantizeLinear has 3 inputs
// x, x_scale, x_zp
// output is not quantized
inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3
? input_defs[2]
: nullptr}});
outputs_.push_back(NodeUnitIODef{*output_defs[0], std::nullopt});
} else if (qlinear_type == QLinearOpType::QuantizeLinear) {
// QuantizeLinear the input is not quantized and has 3 inputs
// x, y_scale, y_zp (optional)
// The output is quantized
inputs_.push_back(NodeUnitIODef{*input_defs[0], std::nullopt});
outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3
? input_defs[2]
: nullptr}});
} else {
ORT_THROW("The QLinear op [", static_cast<uint8_t>(qlinear_type), "] is not supported");
}
}
Node::EdgeConstIterator NodeUnit::OutputEdgesBegin() const {
return (type_ == Type::SingleNode) ? target_node_.OutputEdgesBegin() : output_edges_.begin();
}
Node::EdgeConstIterator NodeUnit::OutputEdgesEnd() const {
return (type_ == Type::SingleNode) ? target_node_.OutputEdgesEnd() : output_edges_.end();
}
std::vector<const Node*> NodeUnit::GetAllNodesInGroup() const noexcept {
std::vector<const Node*> all_nodes = dq_nodes_;
all_nodes.push_back(&target_node_);
all_nodes.insert(all_nodes.end(), q_nodes_.begin(), q_nodes_.end());
return all_nodes;
}
} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

View file

@ -3,6 +3,9 @@
#pragma once
// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
#include <string>
#include <optional>
#include <vector>
@ -18,8 +21,21 @@ class NodeArg;
class Path;
namespace QDQ {
struct NodeGroup;
}
// Struct to represent a DequantizeLinear -> Op -> QuantizeLinear node group
struct NodeGroup {
std::vector<NodeIndex> dq_nodes;
std::vector<NodeIndex> q_nodes;
NodeIndex target_node;
// Validator to check if the set of nodes can form a valid QDQ NodeGroup.
// Checks target node is only consumer of each DQ, and that the outputs remain valid if the QDQ node group was to
// be converted into a single node with a quantized operator.
static Status CanCreateNodeGroup(const GraphViewer& graph_viewer,
const Node& target_node,
gsl::span<const Node* const> dq_nodes,
gsl::span<const Node* const> q_nodes);
};
} // namespace QDQ
// Definition of one input or output
// If the optional quant_param is present, then this is a quantized input,
@ -69,26 +85,33 @@ class NodeUnit {
const std::vector<const Node*>& GetQNodes() const noexcept { return q_nodes_; }
std::vector<const Node*> GetAllNodesInGroup() const noexcept;
Node::EdgeConstIterator OutputEdgesBegin(size_t index) const;
Node::EdgeConstIterator OutputEdgesEnd(size_t index) const;
/// Number of input edges to the logical node. For a QDQ node this is the count of input edges to the DQ nodes
/// plus any other edges to the target node for inputs that are not via a DQ node.
size_t InputEdgeCount() const { return input_edge_count_; }
// output edges. src index is for outputs of the target node. dest index and node is for consumer of node unit
// output. any Q nodes are hidden.
Node::EdgeConstIterator OutputEdgesBegin() const;
Node::EdgeConstIterator OutputEdgesEnd() const;
private:
const std::vector<const Node*> q_nodes_; // q-nodes for this NodeUnit
const std::vector<const Node*> dq_nodes_; // dq nodes for this NodeUnit, not all inputs
// Initialization for a NodeUnit that contains a single node
void InitForSingleNode();
const std::vector<const Node*> dq_nodes_; // dq nodes for this NodeUnit, not necessarily all inputs
const Node& target_node_;
const std::vector<const Node*> q_nodes_; // q-nodes for this NodeUnit. not necessarily all outputs
const Type type_;
std::vector<NodeUnitIODef> inputs_;
std::vector<NodeUnitIODef> outputs_;
// Initializing for a single Node
void InitForSingleNode();
size_t input_edge_count_; // total number of input edges
// output edges, hiding any Q nodes involved. src_idx will be value from target node. only used for QDQ node group.
Node::EdgeSet output_edges_;
};
// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup)
// And return a map to quick query the NodeUnit which contains the given Node,
// Note, the value of the map is owned by the vector of std::unique_ptr<NodeUnit>
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
GetAllNodeUnits(const GraphViewer& graph_viewer);
} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

View file

@ -58,8 +58,8 @@ bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Nod
return false;
}
if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes);
!dq_validation_status.IsOK()) {
if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes);
!qdq_validation_status.IsOK()) {
return false;
}
@ -153,8 +153,8 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return false;
}
if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes);
!dq_validation_status.IsOK()) {
if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes);
!qdq_validation_status.IsOK()) {
return false;
}
@ -544,8 +544,8 @@ bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return false;
}
if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes);
!dq_validation_status.IsOK()) {
if (const auto qdq_validation_status = QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes);
!qdq_validation_status.IsOK()) {
return false;
}

View file

@ -5,6 +5,7 @@
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
#include "core/framework/node_unit.h"
#include "core/optimizer/selectors_actions/selector_action_transformer.h"
namespace onnxruntime {
@ -13,13 +14,6 @@ class Node;
namespace QDQ {
// Struct to represent a DQ->Op->Q node group
struct NodeGroup {
std::vector<NodeIndex> dq_nodes;
std::vector<NodeIndex> q_nodes;
NodeIndex target_node;
};
class NodeGroupSelector {
public:
// This is a QDQ Selectors only function, will return QDQ::NodeGroup instead of NodesToOptimizeIndices

View file

@ -13,6 +13,7 @@
#include <core/providers/common.h>
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
namespace onnxruntime {
namespace QDQ {
@ -43,6 +44,7 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() {
{"Tile", {}}};
}
// These produce int64 indices output, which can't be quantized, so there's no downstream Q node.
static const OpVersionsAndSelector::OpVersionsMap GetDropDQOpVersionsMap() {
return {{"ArgMax", {}},
{"ArgMin", {}}};
@ -324,28 +326,48 @@ std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& grap
return qdq_selections;
}
Status ValidateNodeGroupDQNodes(const GraphViewer& graph_viewer,
const Node& target_node,
gsl::span<const Node* const> dq_nodes) {
// Within a QDQ node group, a target node input is the only consumer of each DQ.
// This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications
// may have happened since. Verify that this is still true.
for (const auto* dq_node : dq_nodes) {
const bool dq_produces_graph_output = graph_viewer.NodeProducesGraphOutput(*dq_node);
ORT_RETURN_IF(dq_produces_graph_output,
"QDQ node group cannot have DQ node that produces a graph output. DQ node: ", dq_node->Name(),
", target node: ", target_node.Name());
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
GetAllNodeUnits(const GraphViewer& graph_viewer) {
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
const bool dq_has_single_output_edge_to_target =
dq_node->GetOutputEdgesCount() == 1 &&
dq_node->OutputEdgesBegin()->GetNode().Index() == target_node.Index();
ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target,
"QDQ node group cannot have DQ that doesn't have a single output edge to the target node. "
"DQ node: ",
dq_node->Name(), ", target node: ", target_node.Name());
const auto add_node_unit_to_map = [&](const std::vector<NodeIndex>& node_indices, const NodeUnit* node_unit) {
for (const auto& node_idx : node_indices) {
const auto* node = graph_viewer.GetNode(node_idx);
node_unit_map.insert({node, node_unit});
}
};
// Get QDQ NodeUnits first
QDQ::SelectorManager selector_mgr;
const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer);
for (const auto& qdq_selection : qdq_selections) {
auto qdq_unit = std::make_unique<NodeUnit>(graph_viewer, qdq_selection);
// Fill the node to node_unit map for all nodes in the QDQ Group
add_node_unit_to_map(qdq_selection.dq_nodes, qdq_unit.get());
add_node_unit_to_map(qdq_selection.q_nodes, qdq_unit.get());
add_node_unit_to_map({qdq_selection.target_node}, qdq_unit.get());
node_unit_holder.push_back(std::move(qdq_unit));
}
return Status::OK();
// Get the left over SingleNode NodeUnits
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (const auto node_idx : node_indices) {
const auto* node(graph_viewer.GetNode(node_idx));
// This is already part of a QDQ NodeUnit
if (node_unit_map.find(node) != node_unit_map.cend())
continue;
auto node_unit = std::make_unique<NodeUnit>(*node);
node_unit_map[node] = node_unit.get();
node_unit_holder.push_back(std::move(node_unit));
}
return std::make_pair(std::move(node_unit_holder), std::move(node_unit_map));
}
} // namespace QDQ

View file

@ -7,6 +7,7 @@
#include "core/common/common.h"
#include "core/common/gsl.h"
#include "core/common/inlined_containers.h"
#include "core/framework/node_unit.h"
#include "core/graph/basic_types.h"
#if !defined(ORT_MINIMAL_BUILD)
@ -78,11 +79,16 @@ class SelectorManager {
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SelectorManager);
};
// Checks whether the provided DQ nodes are valid for forming a QDQ node group with the provided target node.
// Returns successful status if so, failed status with reason otherwise.
Status ValidateNodeGroupDQNodes(const GraphViewer& graph_viewer,
const Node& target_node,
gsl::span<const Node* const> dq_nodes);
// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup)
// And return a map to quick query the NodeUnit which contains the given Node,
// Note, the value of the map is owned by the vector of std::unique_ptr<NodeUnit>
//
// TODO: The overall QDQ setup needs refactoring to separate out generic functionality from optimizer specific
// functionality.
// We currently have a bit of a mess with generic things like this to get all the node units being in the optimizer
// library whereas it should be able to be used by an EP with no dependency on optimizers.
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
GetAllNodeUnits(const GraphViewer& graph_viewer);
} // namespace QDQ
} // namespace onnxruntime

View file

@ -21,7 +21,6 @@
#include "core/framework/kernel_registry.h"
#include "core/graph/function_utils.h"
#include "core/graph/indexed_sub_graph.h"
#include "core/providers/shared/node_unit/node_unit.h"
#include "data_transfer.h"
namespace onnxruntime {

View file

@ -11,6 +11,7 @@
#include "core/common/logging/logging.h"
#include "core/common/safeint.h"
#include "core/framework/node_unit.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/graph.h"
@ -18,7 +19,6 @@
#include "core/providers/common.h"
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h"
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h"
#include "core/providers/shared/node_unit/node_unit.h"
#include "core/providers/shared/utils/utils.h"
namespace onnxruntime {

View file

@ -4,7 +4,7 @@
#pragma once
#include "core/common/common.h"
#include "core/providers/shared/node_unit/node_unit.h"
#include "core/framework/node_unit.h"
#include "core/providers/nnapi/nnapi_builtin/builders/model_builder.h"
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h"
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h"

View file

@ -11,17 +11,19 @@
#include "core/common/safeint.h"
#include "core/common/status.h"
#include "core/framework/execution_provider.h"
#include "core/framework/node_unit.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/graph_viewer.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/providers/common.h"
#include "core/providers/nnapi/nnapi_builtin/nnapi_api_helper.h"
#include "core/providers/shared/node_unit/node_unit.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/nnapi/nnapi_builtin/builders/helper.h"
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h"
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h"
#include "core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h"
#include "core/optimizer/initializer.h"
#include "core/providers/shared/utils/utils.h"
using namespace android::nn::wrapper;
@ -119,7 +121,7 @@ const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const {
}
void ModelBuilder::PreprocessNodeUnits() {
std::tie(node_unit_holder_, node_unit_map_) = GetAllNodeUnits(graph_viewer_);
std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_);
}
// Help to get all quantized operators' input and the NodeUnit(s) using the input
@ -664,7 +666,7 @@ int32_t ModelBuilder::FindActivation(const NodeUnit& node_unit) {
int32_t fuse_code = ANEURALNETWORKS_FUSED_NONE;
bool fuse_code_assigned_from_activation = false;
for (auto it = node_unit.OutputEdgesBegin(0), end = node_unit.OutputEdgesEnd(0); it != end; ++it) {
for (auto it = node_unit.OutputEdgesBegin(), end = node_unit.OutputEdgesEnd(); it != end; ++it) {
const auto& dst_node = it->GetNode();
const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()];

View file

@ -21,7 +21,6 @@
#include "core/optimizer/initializer.h"
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/shared/node_unit/node_unit.h"
#include "core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h"
namespace onnxruntime::nnapi::op_builder_helpers {

View file

@ -7,12 +7,12 @@
#include <vector>
#include "core/common/common.h"
#include "core/framework/node_unit.h"
#include "core/providers/common.h"
#include "core/providers/nnapi/nnapi_builtin/builders/helper.h"
#include "core/providers/nnapi/nnapi_builtin/builders/model_builder.h"
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h"
#include "core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksWrapper.h"
#include "core/providers/shared/node_unit/node_unit.h"
namespace onnxruntime::nnapi::op_builder_helpers {

View file

@ -7,7 +7,10 @@
#include "core/common/logging/logging.h"
#include "core/common/string_utils.h"
#include "core/framework/compute_capability.h"
#include "core/framework/node_unit.h"
#include "core/graph/graph_viewer.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
#include "core/platform/env.h"
#include "core/providers/common.h"
#include "core/providers/nnapi/nnapi_builtin/builders/helper.h"
@ -17,7 +20,6 @@
#include "core/providers/nnapi/nnapi_builtin/nnapi_api_helper.h"
#include "core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h"
#include "core/providers/partitioning_utils.h"
#include "core/providers/shared/node_unit/node_unit.h"
#include "core/session/onnxruntime_cxx_api.h"
namespace onnxruntime {
@ -119,7 +121,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer);
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
// This holds the result of whether a NodeUnit is supported or not,
// to prevent nodes in a NodeUnit to be checked for multiple times
@ -181,7 +183,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
};
result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed,
gen_metadef_name, NNAPI, kNnapiExecutionProvider);
gen_metadef_name, NNAPI, kNnapiExecutionProvider, &node_unit_map);
// Generally, NNAPI supports sub-graphs with at least one non-constant initializer input and one output.
// So far, we have a few cases that sub-graph has zero valid inputs, like `CastLike`

View file

@ -1,6 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
#include "core/providers/partitioning_utils.h"
#include <algorithm>
@ -10,6 +13,7 @@
#include "core/framework/compute_capability.h"
#include "core/framework/execution_provider.h"
#include "core/framework/node_unit.h"
#include "core/graph/graph_viewer.h"
#include "core/providers/common.h"
@ -76,6 +80,11 @@ When selecting the next node to process, we first take:
The remaining unsupported nodes mark the border of the current group so they will be processed later when we consider
the next group.
If node_unit_map is provided, we process NodeUnit instances (a logical 'Node' that can be a single node or a
QDQ node group) instead of individual Node instances. As an EP must take complete NodeUnit instances (i.e. it
must not break up a QDQ node group by taking a subset of nodes in it), this granularity of processing is valid.
It is required to ensure we do not break up a QDQ node unit during partitioning.
@param graph_viewer GraphViewer that IExecutionProvider::GetCapability is called with.
@param is_node_supported_fn Callback to check whether a node is supported.
@param on_group_closed_fn Callback to indicate a completed partition node group.
@ -88,6 +97,7 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
const IsNodeSupportedFn& is_node_supported_fn,
const OnGroupClosedFn& on_group_closed_fn,
const std::string& execution_provider_type,
const std::unordered_map<const Node*, const NodeUnit*>* node_unit_map,
bool debug_output) {
#ifdef NDEBUG
ORT_UNUSED_PARAMETER(debug_output);
@ -111,7 +121,18 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
// initialize in-degrees and find root nodes
for (const auto& node_index : graph_viewer.GetNodesInTopologicalOrder()) {
const auto& node = *graph_viewer.GetNode(node_index);
const auto node_input_edge_count = node.GetInputEdgesCount();
auto node_input_edge_count = node.GetInputEdgesCount();
if (node_unit_map != nullptr) {
const auto& node_unit = node_unit_map->at(&node);
if (&node_unit->GetNode() != &node) {
// only process the target node
continue;
}
node_input_edge_count = node_unit->InputEdgeCount();
}
in_degree.insert({node.Index(), node_input_edge_count});
if (node_input_edge_count == 0) {
nodes_to_process.push_back(&node);
@ -151,6 +172,8 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
}
};
size_t num_nodes_processed = 0;
while (!nodes_to_process.empty() || !nodes_to_process_with_next_group.empty()) {
if (nodes_to_process.empty()) {
// we have processed all the nodes that we can while building this partition node group, start a new one
@ -162,9 +185,13 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
const Node& node = *nodes_to_process.front();
nodes_to_process.pop_front();
const NodeUnit* node_unit = node_unit_map ? node_unit_map->at(&node) : nullptr;
const bool is_qdq_node_unit = node_unit && node_unit->UnitType() == NodeUnit::Type::QDQGroup;
// a node that is already assigned to an EP other than current EP is unsupported
const bool is_node_supported =
(node.GetExecutionProviderType().empty() || node.GetExecutionProviderType() == execution_provider_type) && is_node_supported_fn(node);
const bool is_node_supported = (node.GetExecutionProviderType().empty() ||
node.GetExecutionProviderType() == execution_provider_type) &&
is_node_supported_fn(node);
if (!is_node_supported && Contains(supported_group_border, &node)) {
// an unsupported node on the border will be processed after the current partition node group
@ -173,34 +200,62 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
}
if (is_node_supported) {
// add node to the partition node group
supported_group.push_back(&node);
if (is_qdq_node_unit) {
// add DQ -> node -> Q for the node unit. must be in topological order
for (const auto& dq : node_unit->GetDQNodes()) {
supported_group.push_back(dq);
}
// remove node from the border and add its outputs to the border
supported_group.push_back(&node);
for (const auto& q : node_unit->GetQNodes()) {
supported_group.push_back(q);
}
} else {
supported_group.push_back(&node);
}
// remove node from the border
supported_group_border.erase(&node);
std::for_each(
node.OutputNodesBegin(), node.OutputNodesEnd(),
[&supported_group_border](const Node& output) {
supported_group_border.insert(&output);
});
}
// adjust in-degrees of the node outputs and add any new nodes to process
std::for_each(
node.OutputNodesBegin(), node.OutputNodesEnd(),
[&](const Node& output) {
auto& output_node_in_degree = in_degree[output.Index()];
--output_node_in_degree;
// For each downstream node:
// 1: add the downstream node to the border if the current node is supported
// 2: adjust in-degrees of the nodes consuming the current node's outputs, and add any new nodes to process
const auto process_downstream_node = [&](const Node& downstream_node) {
if (is_node_supported) {
supported_group_border.insert(&downstream_node);
}
if (output_node_in_degree == 0) {
nodes_to_process.push_back(&output);
}
});
auto& downstream_node_in_degree = in_degree[downstream_node.Index()];
--downstream_node_in_degree;
if (downstream_node_in_degree == 0) {
nodes_to_process.push_back(&downstream_node);
}
};
if (node_unit_map) {
std::for_each(node_unit->OutputEdgesBegin(), node_unit->OutputEdgesEnd(),
[&](const Node::EdgeEnd& edge_end) {
const Node& n = edge_end.GetNode();
const NodeUnit& downstream_node_unit = *node_unit_map->at(&n);
const Node& output = downstream_node_unit.GetNode();
process_downstream_node(output);
});
} else {
std::for_each(node.OutputNodesBegin(), node.OutputNodesEnd(), process_downstream_node);
}
++num_nodes_processed;
}
close_group();
ORT_ENFORCE(num_nodes_processed == in_degree.size(),
"Processed ", num_nodes_processed, " nodes. Expected to process ", in_degree.size());
return supported_groups;
}
} // namespace
@ -318,11 +373,13 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
const GenerateMetadefNameFn& generate_metadef_name_fn,
const std::string& execution_provider_name,
const std::string& execution_provider_type,
const std::unordered_map<const Node*, const NodeUnit*>* node_unit_map,
bool debug_output) {
const auto groups = CreateSupportedPartitionNodeGroups(graph_viewer,
is_node_supported_fn,
on_partition_closed_fn,
execution_provider_type,
node_unit_map,
debug_output);
std::vector<std::unique_ptr<ComputeCapability>> partitions{};
@ -346,6 +403,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
const GenerateMetadefNameFn& generate_metadef_name_fn,
const std::string& execution_provider_name,
const std::string& execution_provider_type,
const std::unordered_map<const Node*, const NodeUnit*>* node_unit_map,
bool debug_output) {
const auto excluded_nodes = CreateExcludedNodeSet(graph_viewer, stop_ops);
const bool check_excluded_nodes = !excluded_nodes.empty();
@ -360,8 +418,11 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
generate_metadef_name_fn,
execution_provider_name,
execution_provider_type,
node_unit_map,
debug_output);
}
} // namespace utils
} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

View file

@ -3,6 +3,9 @@
#pragma once
// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
#include <functional>
#include <memory>
#include <unordered_set>
@ -14,8 +17,9 @@
namespace onnxruntime {
struct ComputeCapability;
class GraphViewer;
class NodeArg;
class Node;
class NodeArg;
class NodeUnit;
namespace utils {
@ -56,6 +60,8 @@ Create the supported partitions for the execution provider.
@param generate_metadef_name_fn Callback to create the name for the MetaDef.
@param execution_provider_name Name of execution provider creating the ComputeCapability instance.
@param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance.
@param node_unit_map Map of each Node in the graph_viewer to its NodeUnit. Provide if EP handles QDQ format models.
Should be created by EP calling GetAllNodeUnits.
@param debug_output Print diagnostic output about the partitions and reasons for partition breaks.
No-op in a release build.
@ -68,6 +74,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
const GenerateMetadefNameFn& generate_metadef_name_fn,
const std::string& execution_provider_name,
const std::string& execution_provider_type,
const std::unordered_map<const Node*, const NodeUnit*>* node_unit_map = nullptr,
bool debug_output = false);
/**
@ -79,6 +86,8 @@ Create the supported partitions for the execution provider.
@param generate_metadef_name Functor to create the name for the MetaDef.
@param execution_provider_name Name of execution provider creating the ComputeCapability instance.
@param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance.
@param node_unit_map Map of each Node in the graph_viewer to its NodeUnit. Provide if EP handles QDQ format models.
Should be created by EP calling GetAllNodeUnits.
@param debug_output Print diagnostic output about the partitions and reasons for partition breaks.
No-op in a release build.
@ -91,6 +100,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
const GenerateMetadefNameFn& generate_metadef_name,
const std::string& execution_provider_name,
const std::string& execution_provider_type,
const std::unordered_map<const Node*, const NodeUnit*>* node_unit_map = nullptr,
bool debug_output = false);
/**
@ -125,3 +135,5 @@ InlinedHashSet<const Node*> CreateExcludedNodeSet(const GraphViewer& graph_viewe
const std::unordered_set<std::string>& stop_ops);
} // namespace utils
} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

View file

@ -4,7 +4,7 @@
#pragma once
#include "core/graph/graph_viewer.h"
#include "core/providers/shared/node_unit/node_unit.h"
#include "core/framework/node_unit.h"
#include "core/providers/shared/utils/utils.h"
namespace onnxruntime {

View file

@ -9,6 +9,8 @@
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/shared/utils/utils.h"
#include "core/framework/utils.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
#include "core/providers/qnn/builder/qnn_utils.h"
namespace onnxruntime {
@ -95,7 +97,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
// valid throughout the lifetime of the ModelBuilder
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer);
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
// This name must be same with the EPContext node name
const auto& graph_name = fused_node.Name();

View file

@ -6,13 +6,13 @@
#include <vector>
#include "core/common/status.h"
#include "core/framework/node_unit.h"
#include "core/graph/graph_viewer.h"
#include "core/platform/ort_mutex.h"
#include "core/providers/qnn/builder/qnn_def.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/qnn_backend_manager.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/providers/shared/node_unit/node_unit.h"
namespace onnxruntime {
namespace qnn {

View file

@ -11,8 +11,8 @@
#include "QnnInterface.h"
#include "qnn_def.h"
#include "core/common/logging/logging.h"
#include "core/framework/node_unit.h"
#include "core/graph/graph_viewer.h"
#include "core/providers/shared/node_unit/node_unit.h"
#include "core/providers/shared/utils/utils.h"
namespace onnxruntime {

View file

@ -10,6 +10,8 @@
#include "core/session/onnxruntime_run_options_config_keys.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/framework/kernel_registry.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
#include "core/platform/env.h"
#include "core/providers/common.h"
#include "core/providers/partitioning_utils.h"
@ -494,7 +496,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer);
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, node_unit_holder.size(),
is_qnn_ctx_model, logger);
@ -534,44 +536,39 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
size_t num_of_supported_nodes = 0;
// Create partitions from supported nodes.
{
std::vector<std::unique_ptr<ComputeCapability>> partitions = utils::CreateSupportedPartitions(graph_viewer,
supported_nodes, {},
gen_metadef_name, QNN,
kQnnExecutionProvider,
true);
std::vector<std::unique_ptr<ComputeCapability>> partitions = utils::CreateSupportedPartitions(
graph_viewer, supported_nodes, {}, gen_metadef_name, QNN, kQnnExecutionProvider, &node_unit_map, true);
// Filter out partitions that consist of a single QuantizeLinear or DequantizeLinear node.
// We also count the number of supported nodes in all valid partitions.
for (auto& partition : partitions) {
bool is_valid_partition = true;
size_t nodes_in_partition = 0;
// Filter out partitions that consist of a single QuantizeLinear or DequantizeLinear node.
// We also count the number of supported nodes in all valid partitions.
for (auto& partition : partitions) {
bool is_valid_partition = true;
size_t nodes_in_partition = 0;
if (partition && partition->sub_graph) {
nodes_in_partition = partition->sub_graph->nodes.size();
if (partition && partition->sub_graph) {
nodes_in_partition = partition->sub_graph->nodes.size();
if (nodes_in_partition == 1 && !is_qnn_ctx_model) {
const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]);
if (nodes_in_partition == 1 && !is_qnn_ctx_model) {
const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]);
if (!node) {
LOGS(logger, ERROR) << "QNN EP: Invalid node in partition of one node.";
is_valid_partition = false;
} else if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") {
LOGS(logger, WARNING) << "QNN EP does not support a single Quantize/Dequantize node in a partition.";
is_valid_partition = false;
}
if (!node) {
LOGS(logger, ERROR) << "QNN EP: Invalid node in partition of one node.";
is_valid_partition = false;
} else if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") {
LOGS(logger, WARNING) << "QNN EP does not support a single Quantize/Dequantize node in a partition.";
is_valid_partition = false;
}
} else {
LOGS(logger, ERROR) << "QNN EP: Invalid partition.";
is_valid_partition = false;
}
} else {
LOGS(logger, ERROR) << "QNN EP: Invalid partition.";
is_valid_partition = false;
}
if (is_valid_partition) {
result.push_back(std::move(partition));
num_of_supported_nodes += nodes_in_partition;
}
} // for
}
if (is_valid_partition) {
result.push_back(std::move(partition));
num_of_supported_nodes += nodes_in_partition;
}
} // for
const size_t num_of_partitions = result.size();
const auto summary_msg = MakeString("Number of partitions supported by QNN EP: ", num_of_partitions,

View file

@ -1,319 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "node_unit.h"
#include "core/graph/graph_viewer.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
namespace onnxruntime {
namespace {
enum class QLinearOpType : uint8_t {
Unknown, // Unknown or not a linear quantized op
DequantizeLinear,
QuantizeLinear,
QLinearConv,
QLinearMatMul,
QLinearAdd,
QLinearSigmoid,
QLinearAveragePool,
QLinearMul,
QLinearReduceMean,
QLinearConcat,
QLinearGlobalAveragePool,
QLinearLeakyRelu,
};
QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) {
const auto& op_type = node.OpType();
if (op_type == "DequantizeLinear")
return QLinearOpType::DequantizeLinear;
else if (op_type == "QuantizeLinear")
return QLinearOpType::QuantizeLinear;
else if (op_type == "QLinearConv")
return QLinearOpType::QLinearConv;
else if (op_type == "QLinearMatMul")
return QLinearOpType::QLinearMatMul;
else if (op_type == "QLinearAdd")
return QLinearOpType::QLinearAdd;
else if (op_type == "QLinearSigmoid")
return QLinearOpType::QLinearSigmoid;
else if (op_type == "QLinearAveragePool")
return QLinearOpType::QLinearAveragePool;
else if (op_type == "QLinearMul")
return QLinearOpType::QLinearMul;
else if (op_type == "QLinearReduceMean")
return QLinearOpType::QLinearReduceMean;
else if (op_type == "QLinearConcat")
return QLinearOpType::QLinearConcat;
else if (op_type == "QLinearGlobalAveragePool")
return QLinearOpType::QLinearGlobalAveragePool;
else if (op_type == "QLinearLeakyRelu")
return QLinearOpType::QLinearLeakyRelu;
return QLinearOpType::Unknown;
}
// Ops have 1 input
bool IsUnaryQLinearOp(QLinearOpType type) {
return type == QLinearOpType::QLinearSigmoid ||
type == QLinearOpType::QLinearAveragePool ||
type == QLinearOpType::QLinearGlobalAveragePool ||
type == QLinearOpType::QLinearLeakyRelu ||
type == QLinearOpType::QLinearReduceMean;
}
// Ops have 2 inputs
bool IsBinaryQLinearOp(QLinearOpType type) {
return type == QLinearOpType::QLinearConv ||
type == QLinearOpType::QLinearMatMul ||
type == QLinearOpType::QLinearAdd ||
type == QLinearOpType::QLinearMul;
}
// Ops have 1 or more inputs
bool IsVariadicQLinearOp(QLinearOpType type) {
return type == QLinearOpType::QLinearConcat;
}
const std::vector<const Node*> GetQDQIONodes(const GraphViewer& graph_viewer,
const QDQ::NodeGroup& node_group, bool is_input) {
std::vector<const Node*> io_nodes;
const auto& src_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes;
io_nodes.reserve(src_nodes.size());
for (const auto& node_idx : src_nodes) {
io_nodes.push_back(graph_viewer.GetNode(node_idx));
}
return io_nodes;
}
// Get the input or output NodeUnitIODef(s) for the given QDQ NodeGroup
std::vector<NodeUnitIODef> GetQDQIODefs(const Node& target_node, const QDQ::NodeGroup& node_group,
bool is_input) {
const auto& dq_or_q_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes;
const auto target_node_io_defs = is_input ? target_node.InputDefs() : target_node.OutputDefs();
const size_t target_node_io_defs_size = target_node_io_defs.size();
// Find all the quantized IO defs and indices (for the input to the target node)
std::unordered_map<size_t, NodeUnitIODef> quantized_io_defs;
quantized_io_defs.reserve(target_node_io_defs_size);
auto cur = is_input ? target_node.InputEdgesBegin() : target_node.OutputEdgesBegin();
auto end = is_input ? target_node.InputEdgesEnd() : target_node.OutputEdgesEnd();
for (; cur != end; ++cur) {
const Node& node = cur->GetNode();
// If we can find the node index in the dq or q nodes, then this is a quantize node (can be DQ or Q depends on is_input)
if (std::find(dq_or_q_nodes.cbegin(), dq_or_q_nodes.cend(), node.Index()) != dq_or_q_nodes.cend()) {
const auto node_inputs = node.InputDefs();
// quantization scale and zp are always the input[1, 2]
NodeUnitIODef::QuantParam quant_param{
*node_inputs[1],
node_inputs.size() == 3 ? node_inputs[2] : nullptr};
if (is_input) {
// DQ is input to the target node, use the DstArgIndex
auto idx = cur->GetDstArgIndex();
// This is a DQ node, we are using x, x_scale, x_zp (input[0, 1, 2])
quantized_io_defs.insert({idx, NodeUnitIODef{*node_inputs[0], quant_param}});
} else {
// Q is output of the target node, use the SrcArgIndex
auto idx = cur->GetSrcArgIndex();
// This is a Q node, we are using y (output[0]), y_scale, y_zp (input[1, 2])
const auto node_outputs = node.OutputDefs();
quantized_io_defs.insert({idx, NodeUnitIODef{*node_outputs[0], quant_param}});
}
}
}
// Construct the IODefs for this QDQ NodeGroup
std::vector<NodeUnitIODef> io_defs;
io_defs.reserve(target_node_io_defs_size);
for (size_t i = 0; i < target_node_io_defs_size; i++) {
// If we can find the NodeUnitIODef for this index, this is a quantized input
if (quantized_io_defs.find(i) != quantized_io_defs.cend()) {
io_defs.push_back(std::move(quantized_io_defs.at(i)));
} else {
// This is a regular input
io_defs.push_back({*target_node_io_defs[i], std::nullopt});
}
}
return io_defs;
}
} // namespace
NodeUnit::NodeUnit(const Node& node)
: target_node_(node),
type_(Type::SingleNode) {
InitForSingleNode();
}
NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group)
: q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)},
dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)},
target_node_(*graph_viewer.GetNode(node_group.target_node)),
type_(Type::QDQGroup),
inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)},
outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} {
ORT_THROW_IF_ERROR(QDQ::ValidateNodeGroupDQNodes(graph_viewer, target_node_, dq_nodes_));
}
const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); }
const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpType(); }
const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); }
int NodeUnit::SinceVersion() const noexcept { return target_node_.SinceVersion(); }
NodeIndex NodeUnit::Index() const noexcept { return target_node_.Index(); }
const Path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); }
ProviderType NodeUnit::GetExecutionProviderType() const noexcept { return target_node_.GetExecutionProviderType(); }
void NodeUnit::InitForSingleNode() {
const auto& input_defs = target_node_.InputDefs();
const auto& output_defs = target_node_.OutputDefs();
auto qlinear_type = GetQLinearOpType(target_node_);
if (qlinear_type == QLinearOpType::Unknown ||
IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support
// Not a Qlinear op, add all inputs / outputs
auto add_all_io = [](std::vector<NodeUnitIODef>& defs,
const ConstPointerContainer<std::vector<NodeArg*>>& node_defs) {
defs.reserve(node_defs.size());
for (const auto def : node_defs) {
defs.push_back(NodeUnitIODef{*def, std::nullopt});
}
};
add_all_io(inputs_, input_defs);
add_all_io(outputs_, output_defs);
} else if (IsUnaryQLinearOp(qlinear_type)) {
// Unary QLinear Op has 5 inputs
// x, x_scale, x_zp, y_scale, y_zp (optional)
inputs_.push_back(NodeUnitIODef{
*input_defs[0],
NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}});
outputs_.push_back(NodeUnitIODef{
*output_defs[0],
NodeUnitIODef::QuantParam{*input_defs[3],
input_defs.size() > 4
? input_defs[4]
: nullptr}});
} else if (IsBinaryQLinearOp(qlinear_type)) {
// Binary QLinear Op has 9 inputs
// x1, x1_scale, x1_zp, x2/w, x2_scale, x2_zp, y_scale , y_zp, B
inputs_.push_back(NodeUnitIODef{
*input_defs[0],
NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}});
inputs_.push_back(NodeUnitIODef{
*input_defs[3],
NodeUnitIODef::QuantParam{*input_defs[4], input_defs[5]}});
if (input_defs.size() == 9) { // has Bias
inputs_.push_back(NodeUnitIODef{
*input_defs[8],
std::nullopt}); // for Bias the scale and zp are optional
}
outputs_.push_back(NodeUnitIODef{
*output_defs[0],
NodeUnitIODef::QuantParam{*input_defs[6], input_defs[7]}});
} else if (qlinear_type == QLinearOpType::DequantizeLinear) {
// DequantizeLinear has 3 inputs
// x, x_scale, x_zp
// output is not quantized
inputs_.push_back(NodeUnitIODef{
*input_defs[0],
NodeUnitIODef::QuantParam{*input_defs[1],
input_defs.size() == 3
? input_defs[2]
: nullptr}});
outputs_.push_back(NodeUnitIODef{*output_defs[0], std::nullopt});
} else if (qlinear_type == QLinearOpType::QuantizeLinear) {
// QuantizeLinear the input is not quantized and has 3 inputs
// x, y_scale, y_zp (optional)
// The output is quantized
inputs_.push_back(NodeUnitIODef{*input_defs[0], std::nullopt});
outputs_.push_back(NodeUnitIODef{
*output_defs[0],
NodeUnitIODef::QuantParam{*input_defs[1],
input_defs.size() == 3
? input_defs[2]
: nullptr}});
} else {
ORT_THROW("The QLinear op [", static_cast<uint8_t>(qlinear_type), "] is not supported");
}
}
Node::EdgeConstIterator NodeUnit::OutputEdgesBegin(size_t index) const {
if (type_ == Type::SingleNode) {
ORT_ENFORCE(index == 0, "invalid output node index");
return target_node_.OutputEdgesBegin();
} else {
ORT_ENFORCE(index < q_nodes_.size(), "invalid output node index");
return q_nodes_[index]->OutputEdgesBegin();
}
}
Node::EdgeConstIterator NodeUnit::OutputEdgesEnd(size_t index) const {
if (type_ == Type::SingleNode) {
ORT_ENFORCE(index == 0, "invalid output node index");
return target_node_.OutputEdgesEnd();
} else {
ORT_ENFORCE(index < q_nodes_.size(), "invalid output node index");
return q_nodes_[index]->OutputEdgesEnd();
}
}
std::vector<const Node*> NodeUnit::GetAllNodesInGroup() const noexcept {
std::vector<const Node*> all_nodes = dq_nodes_;
all_nodes.push_back(&target_node_);
all_nodes.insert(all_nodes.end(), q_nodes_.begin(), q_nodes_.end());
return all_nodes;
}
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
GetAllNodeUnits(const GraphViewer& graph_viewer) {
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
const auto add_node_unit_to_map = [&](const std::vector<NodeIndex>& node_indices, const NodeUnit* node_unit) {
for (const auto& node_idx : node_indices) {
const auto* node = graph_viewer.GetNode(node_idx);
node_unit_map.insert({node, node_unit});
}
};
// Get QDQ NodeUnits first
QDQ::SelectorManager selector_mgr;
const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer);
for (const auto& qdq_selection : qdq_selections) {
auto qdq_unit = std::make_unique<NodeUnit>(graph_viewer, qdq_selection);
// Fill the node to node_unit map for all nodes in the QDQ Group
add_node_unit_to_map(qdq_selection.dq_nodes, qdq_unit.get());
add_node_unit_to_map(qdq_selection.q_nodes, qdq_unit.get());
add_node_unit_to_map({qdq_selection.target_node}, qdq_unit.get());
node_unit_holder.push_back(std::move(qdq_unit));
}
// Get the left over SingleNode NodeUnits
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (const auto node_idx : node_indices) {
const auto* node(graph_viewer.GetNode(node_idx));
// This is already part of a QDQ NodeUnit
if (node_unit_map.find(node) != node_unit_map.cend())
continue;
auto node_unit = std::make_unique<NodeUnit>(*node);
node_unit_map[node] = node_unit.get();
node_unit_holder.push_back(std::move(node_unit));
}
return std::make_pair(std::move(node_unit_holder), std::move(node_unit_map));
}
} // namespace onnxruntime

View file

@ -4,12 +4,12 @@
#include "utils.h"
#include <core/common/safeint.h>
#include <core/framework/tensorprotoutils.h>
#include <core/graph/graph.h>
#include <core/providers/common.h>
#include "core/providers/shared/node_unit/node_unit.h"
#include "core/common/safeint.h"
#include "core/framework/node_unit.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/graph.h"
#include "core/optimizer/initializer.h"
#include "core/providers/common.h"
namespace onnxruntime {

View file

@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include "core/framework/tensorprotoutils.h"
#include "utils.h"
#include "core/providers/utils.h"
namespace onnxruntime {
namespace utils {
@ -23,6 +23,5 @@ common::Status OutputOptionalWithoutDataHelper(const ONNX_NAMESPACE::TypeProto&
return Status::OK();
}
#endif
} // namespace utils
} // namespace onnxruntime

View file

@ -6,12 +6,12 @@
#include <unordered_map>
#include "core/common/common.h"
#include "core/framework/node_unit.h"
#include "core/framework/op_node_proto_helper.h"
#include "core/graph/graph_utils.h"
#include "core/graph/graph_viewer.h"
#include "core/providers/common.h"
#include "core/providers/cpu/nn/pool_attributes.h"
#include "core/providers/shared/node_unit/node_unit.h"
#include "core/providers/xnnpack/detail/utils.h"
// each operator provides a helper to check if supported

View file

@ -6,15 +6,15 @@
#include <vector>
#include "core/common/common.h"
#include "core/common/safeint.h"
#include "core/framework/node_unit.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/indexed_sub_graph.h"
#include "core/graph/node_attr_utils.h"
#include "core/providers/shared/node_unit/node_unit.h"
#include "onnx/defs/attr_proto_util.h"
#include "core/common/safeint.h"
#include "core/optimizer/initializer.h"
#include "onnx/defs/attr_proto_util.h"
namespace onnxruntime {
namespace xnnpack {

View file

@ -10,10 +10,10 @@
#include <string>
#include <utility>
#include "core/framework/node_unit.h"
#include "core/framework/op_kernel.h"
#include "core/graph/indexed_sub_graph.h"
#include "core/providers/common.h"
#include "core/providers/shared/node_unit/node_unit.h"
#include "xnnpack.h"

View file

@ -6,17 +6,17 @@
#include <unordered_set>
#include <utility>
#include "core/graph/function_utils.h"
#include "xnnpack_execution_provider.h"
#include "detail/utils.h"
#include "detail/node_support_checker.h"
#include "core/framework/compute_capability.h"
#include "core/framework/kernel_registry.h"
#include "core/providers/shared/node_unit/node_unit.h"
#include "core/framework/node_unit.h"
#include "core/graph/function_utils.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "xnnpack_init.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
#include "core/providers/xnnpack/xnnpack_execution_provider.h"
#include "core/providers/xnnpack/detail/utils.h"
#include "core/providers/xnnpack/detail/node_support_checker.h"
#include "core/providers/xnnpack/xnnpack_init.h"
namespace onnxruntime {
@ -268,7 +268,7 @@ std::vector<std::unique_ptr<ComputeCapability>> XnnpackExecutionProvider::GetCap
// Get all the NodeUnits in the GraphViewer so we can check if something is in a QDQ node group
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph);
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph);
// This holds the result of whether a NodeUnit is supported or not,
// to prevent nodes in a NodeUnit being checked for multiple times

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "test_fp16.h"
#include <iomanip>
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "core/framework/compute_capability.h"
#include "core/framework/node_unit.h"
#include "core/graph/model.h"
#include "core/graph/onnx_protobuf.h"
#include "core/mlas/inc/mlas.h"
@ -9,7 +10,6 @@
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
#include "core/optimizer/utils.h"
#include "core/providers/partitioning_utils.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/environment.h"
@ -30,10 +30,6 @@
#pragma warning(disable : 4127)
#endif // #if defined(_MSC_VER)
#ifdef USE_NNAPI
#include "core/providers/shared/node_unit/node_unit.h"
#endif // #ifdef USE_NNAPI
struct QDQOpKeys {
const char* quantize_linear;
const char* dequantize_linear;
@ -3243,14 +3239,14 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
ASSERT_EQ(std::vector<NodeIndex>({4}), qdq_group.q_nodes);
}
// The function GetAllNodeUnits is enabled for NNAPI EP only for now
#ifdef USE_NNAPI
// The function GetAllNodeUnits is used by NNAPI, XNNPACK and QNN
#if defined(USE_NNAPI) || defined(USE_QNN) || defined(USE_XNNPACK)
{
// Get all the NodeUnits in the graph_viewer
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(whole_graph_viewer);
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer);
// We should get a single QDQ Node unit in the result
ASSERT_EQ(1, node_unit_holder.size());
@ -3288,7 +3284,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
verify_io_def(qdq_node_unit.Inputs()[2], *whole_graph_viewer.GetNode(2)); // DQ_bias
verify_io_def(qdq_node_unit.Outputs()[0], *whole_graph_viewer.GetNode(4)); // Q_output
}
#endif // #ifdef USE_NNAPI
#endif // defined(USE_NNAPI) || defined(USE_QNN) || defined(USE_XNNPACK)
// Create a graph viewer covers part of the graph
// Make sure the qdq conv selector will fail for the partial graph

View file

@ -220,6 +220,7 @@ InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer&
auto compile_capabilities = utils::CreateSupportedPartitions(graph_viewer, supported_compiled_nodes, stop_ops_,
generate_metadef_name, ep_name_,
onnxruntime::utils::kInternalTestingExecutionProvider,
/*QDQ NodeUnit map*/ nullptr,
debug_output_);
if (!static_capabilities.empty()) {

View file

@ -0,0 +1,174 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "core/common/common.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"
#include "core/framework/node_unit.h"
#include "core/framework/compute_capability.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
#include "core/providers/partitioning_utils.h"
#include "test/optimizer/graph_transform_test_builder.h"
#include "test/optimizer/qdq_test_utils.h"
#include "test/util/include/asserts.h"
#include "test/util/include/test_utils.h"
#include "test/util/include/test/test_environment.h"
namespace onnxruntime {
namespace test {
// Test handling of a DQ node that is connected to an initializer at the start of the graph, but not used
// in a QDQ node group until after an unsupported node in the graph. If we do not process QDQ node units
// correctly this DQ will incorrectly be in the first partition, with the rest of the QDQ node group in
// the second partition.
TEST(PartitioningUtilsTest, TestQDQHandling) {
constexpr const ORTCHAR_T* model_uri = ORT_TSTR("testdata/ort_github_issue_19590.onnx");
auto& logger = DefaultLoggingManager().DefaultLogger();
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, logger));
Graph& graph = p_model->MainGraph();
GraphViewer graph_viewer = GraphViewer(graph);
// we want everything but the Cast in the test model to be supported
const auto is_node_supported = [&](const Node& node) -> bool {
return node.OpType() != "Cast";
};
const auto on_group_closed = [&](const std::vector<const Node*>& /*group*/) -> bool {
return true;
};
const auto gen_metadef_name = [&]() {
static int metadef_id = 0;
return "TestMetaDef_" + std::to_string(metadef_id++);
};
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
auto result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed,
gen_metadef_name, "TEST", kCpuExecutionProvider, &node_unit_map,
true);
// we should have 2 supported partitions, split by the Cast node.
// the first should have the Mul and NOT the DQ for the initializer if everything worked correctly.
ASSERT_EQ(result.size(), size_t(2)) << "Expected 2 partitions";
ASSERT_EQ(result[0]->sub_graph->nodes.size(), size_t(1)) << "First partition should only have the Mul and not a DQ";
ASSERT_EQ(result[1]->sub_graph->nodes.size(), size_t(5)); // everything else except the unsupported Cast
}
/// Check that CreateSupportedPartitions processes all nodes without error.
static void CheckAllNodesProcessed(const std::function<void(ModelTestBuilder&)>& build_model) {
auto& logger = DefaultLoggingManager().DefaultLogger();
const std::unordered_map<std::string, int> domain_to_version = {{"", 15}};
Model model("PartitioningUtils_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version, {}, logger);
Graph& graph = model.MainGraph();
ModelTestBuilder helper(graph);
build_model(helper);
helper.SetGraphOutputs();
ASSERT_STATUS_OK(model.MainGraph().Resolve());
GraphViewer graph_viewer = GraphViewer(graph);
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
const auto is_node_supported = [&](const Node& /*node*/) -> bool {
return true;
};
const auto on_group_closed = [&](const std::vector<const Node*>& /*group*/) -> bool {
return true;
};
const auto gen_metadef_name = [&]() {
static int metadef_id = 0;
return "TestMetaDef_" + std::to_string(metadef_id++);
};
auto result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed,
gen_metadef_name, "TEST", kCpuExecutionProvider, &node_unit_map,
true);
// the 'real' test is that CreateSupportedPartitions doesn't throw due to a mismatch with expected vs processed nodes
// as all ops are supported there should only ever be 1 partition
ASSERT_EQ(result.size(), size_t(1)) << "Expected 1 partition";
}
TEST(PartitioningUtilsTest, TestHandlingQDQNodeUnitWithNoQNodes) {
// build graph with QDQ node unit for logical operator (Equal) that has no Q node and a downstream node (Cast).
auto build_model = [](ModelTestBuilder& builder) {
constexpr uint8_t zero_point = 0;
constexpr float qdq_scale = 0.0038f;
const std::vector<int64_t> input_shape = {1, 3, 8, 8};
auto* input0 = builder.MakeInput<float>(input_shape, -1.0f, 1.0f);
auto* input1 = builder.MakeInput<float>(input_shape, -1.0f, 1.0f);
auto* output = builder.MakeOutput();
// input -> Q -> DQ -> Op
auto* qdq0_output = AddQDQNodePair<uint8_t>(builder, input0, qdq_scale, zero_point);
auto* qdq1_output = AddQDQNodePair<uint8_t>(builder, input1, qdq_scale, zero_point);
// Equal ->
auto* equal_output = builder.MakeIntermediate();
builder.AddNode("Equal", {qdq0_output, qdq1_output}, {equal_output});
// -> Cast -> output
Node& cast_node = builder.AddNode("Cast", {equal_output}, {output});
cast_node.AddAttribute("to",
static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT));
};
CheckAllNodesProcessed(build_model);
}
// TopK produces 2 outputs, one of which is used in a QDQ node group (Q of values output)
// and the other (indices output) is not. A downstream node consuming the indices output has an edge from the target
// node and not a Q node.
// To process this correctly, the QDQ NodeUnit must return output edges for both the Q node/s of the values output,
// and the downstream node (Cast in this case) of the indices output.
TEST(PartitioningUtilsTest, TestQDQNodeGroupWithOutputFromTargetNode) {
const auto build_model = [](ModelTestBuilder& builder) {
constexpr uint8_t zero_point = 0;
constexpr float qdq_scale = 0.0038f;
const std::vector<int64_t> input_shape = {1, 3, 8, 8};
auto* input0 = builder.MakeInput<float>(input_shape, -1.0f, 1.0f);
// input -> Q -> DQ ->
auto* qdq0_output = AddQDQNodePair<uint8_t>(builder, input0, qdq_scale, zero_point);
// K input
NodeArg* k_input = builder.MakeInput<int64_t>({1}, {10});
// TopK op
NodeArg* values_output = builder.MakeIntermediate();
NodeArg* indices_output = builder.MakeIntermediate();
builder.AddNode("TopK", {qdq0_output, k_input}, {values_output, indices_output});
// values -> Q -> DQ -> graph output
AddQDQNodePairWithOutputAsGraphOutput<uint8_t>(builder, values_output, qdq_scale, zero_point);
// indices -> Cast -> graph output
auto* i_output = builder.MakeOutput();
Node& cast_node = builder.AddNode("Cast", {indices_output}, {i_output});
const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32;
cast_node.AddAttribute("to", static_cast<int64_t>(dst_type));
};
CheckAllNodesProcessed(build_model);
}
} // namespace test
} // namespace onnxruntime

Binary file not shown.

View file

@ -0,0 +1,77 @@
import onnx
from onnx import TensorProto, helper
# graph with a QDQ MatMul node unit where one input is and initializer -> DQ and the other is on a path that
# contains a supported node followed by an unsupported node followed by the DQ -> MatMul.
# The DQ of the initializer is prior to the unsupported node. If the partitioning utils do not process the QDQ node
# unit together, the DQ for the initializer and the first supported node will be in the first partition, which
# incorrectly breaks up the QDQ node unit.
graph_proto = helper.make_graph(
[
# DQ of initializer for MatMul B input
helper.make_node(
"DequantizeLinear",
inputs=["matmul_b_uint8", "scale0"],
outputs=["dq_matmul_b"],
name="dq_matmul_b",
),
# Treat as supported
helper.make_node(
"Mul",
inputs=["input:0", "scale_input"],
outputs=["mul:0"],
name="mul0",
),
# Treat as unsupported
helper.make_node("Cast", inputs=["mul:0"], outputs=["mul_uint8"], name="cast0", to=2),
# DQ of MatMul A input
helper.make_node(
"DequantizeLinear",
inputs=["mul_uint8", "scale1"],
outputs=["dq_matmul_a"],
name="dq_matmul_a",
),
# MatMul
helper.make_node(
"MatMul",
inputs=[
"dq_matmul_a",
"dq_matmul_b",
],
outputs=["matmul_ab"],
name="matmul_ab",
),
# Q
helper.make_node(
"QuantizeLinear",
inputs=["matmul_ab", "scale2"],
outputs=["q_matmul_ab"],
name="q_matmul_ab",
),
# DQ for model output
helper.make_node(
"DequantizeLinear",
inputs=["q_matmul_ab", "scale2"],
outputs=["out:0"],
name="dq_graph_output",
),
],
"Main_graph",
[
helper.make_tensor_value_info("input:0", TensorProto.FLOAT, [3, 2]),
],
[
helper.make_tensor_value_info("out:0", TensorProto.FLOAT, [3, 2]),
],
[
helper.make_tensor("scale0", TensorProto.FLOAT, [1], [20.0]),
helper.make_tensor("scale1", TensorProto.FLOAT, [1], [30.0]),
helper.make_tensor("scale2", TensorProto.FLOAT, [1], [40.0]),
helper.make_tensor("matmul_b_uint8", TensorProto.UINT8, [2, 2], [1, 2, 3, 4]),
helper.make_tensor("scale_input", TensorProto.FLOAT, [2], [3.0, 4.0]),
],
)
model = helper.make_model(graph_proto)
onnx.checker.check_model(model, True)
onnx.save(model, "ort_github_issue_19590.onnx")