mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
Make partitioning utils QDQ aware so it does not break up QDQ node units (#19723)
### Description <!-- Describe your changes. --> If the EP handles QDQ node units, we need to make sure we do not split those into different partitions. Update the partitioning utils to be QDQ aware. If there are node units we process the logical nodes they represent instead of individual nodes. This ensure we process all nodes in a QDQ node unit at the same time so that they are always in the same partition. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Fix one of the issues in #19590 --------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
This commit is contained in:
parent
cba605e845
commit
978c40d853
38 changed files with 886 additions and 489 deletions
|
|
@ -37,8 +37,13 @@ if (onnxruntime_BUILD_UNIT_TESTS)
|
|||
set(gtest_disable_pthreads ON)
|
||||
endif()
|
||||
set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
|
||||
if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
|
||||
# Needs to update onnxruntime/test/xctest/xcgtest.mm
|
||||
if (IOS OR ANDROID)
|
||||
# on mobile platforms the absl flags class dumps the flag names (assumably for binary size), which breaks passing
|
||||
# any args to gtest executables, such as using --gtest_filter to debug a specific test.
|
||||
# Processing of compile definitions:
|
||||
# https://github.com/abseil/abseil-cpp/blob/8dc90ff07402cd027daec520bb77f46e51855889/absl/flags/config.h#L21
|
||||
# If set, this code throws away the flag and does nothing on registration, which results in no flags being known:
|
||||
# https://github.com/abseil/abseil-cpp/blob/8dc90ff07402cd027daec520bb77f46e51855889/absl/flags/flag.h#L205-L217
|
||||
set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE)
|
||||
else()
|
||||
set(GTEST_HAS_ABSL ON CACHE BOOL "" FORCE)
|
||||
|
|
|
|||
|
|
@ -70,8 +70,8 @@ list(FILTER coreml_proto_generated_srcs INCLUDE REGEX "\.pb\.(h|cc)$")
|
|||
source_group(TREE ${CMAKE_CURRENT_BINARY_DIR} PREFIX coreml_proto_generated FILES ${coreml_proto_generated_srcs})
|
||||
|
||||
# These are shared utils,
|
||||
# TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML
|
||||
file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
|
||||
# TODO, move this to a separate lib when used by EPs other than NNAPI and CoreML
|
||||
file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -49,12 +49,10 @@
|
|||
endif()
|
||||
|
||||
# These are shared utils,
|
||||
# TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML
|
||||
# TODO, move this to a separate lib when used by EPs other than NNAPI and CoreML
|
||||
list(APPEND onnxruntime_provider_nnapi_cc_src_patterns
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc"
|
||||
)
|
||||
|
||||
file(GLOB onnxruntime_providers_nnapi_cc_srcs CONFIGURE_DEPENDS ${onnxruntime_provider_nnapi_cc_src_patterns})
|
||||
|
|
@ -81,4 +79,4 @@
|
|||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||
FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR})
|
||||
endif()
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -4,12 +4,10 @@
|
|||
add_compile_definitions(USE_QNN=1)
|
||||
|
||||
# These are shared utils,
|
||||
# TODO, move this to a separated lib when used by EPs other than QNN, NNAPI and CoreML
|
||||
file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
|
||||
# TODO, move to a separate lib when used by EPs other than QNN, NNAPI and CoreML
|
||||
file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc"
|
||||
)
|
||||
|
||||
file(GLOB_RECURSE
|
||||
|
|
@ -42,4 +40,4 @@
|
|||
# ignore the warning unknown-pragmas on "pragma region"
|
||||
if(NOT MSVC)
|
||||
target_compile_options(onnxruntime_providers_qnn PRIVATE "-Wno-unknown-pragmas")
|
||||
endif()
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -7,9 +7,6 @@
|
|||
"${ONNXRUNTIME_INCLUDE_DIR}/core/providers/xnnpack/*.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.cc"
|
||||
# utils for handling QDQ models
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc"
|
||||
)
|
||||
|
||||
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_xnnpack_cc_srcs})
|
||||
|
|
|
|||
351
onnxruntime/core/framework/node_unit.cc
Normal file
351
onnxruntime/core/framework/node_unit.cc
Normal file
|
|
@ -0,0 +1,351 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
||||
#include "node_unit.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace {
|
||||
|
||||
enum class QLinearOpType : uint8_t {
|
||||
Unknown, // Unknown or not a linear quantized op
|
||||
DequantizeLinear,
|
||||
QuantizeLinear,
|
||||
QLinearConv,
|
||||
QLinearMatMul,
|
||||
QLinearAdd,
|
||||
QLinearSigmoid,
|
||||
QLinearAveragePool,
|
||||
QLinearMul,
|
||||
QLinearReduceMean,
|
||||
QLinearConcat,
|
||||
QLinearGlobalAveragePool,
|
||||
QLinearLeakyRelu,
|
||||
};
|
||||
|
||||
QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) {
|
||||
const auto& op_type = node.OpType();
|
||||
if (op_type == "DequantizeLinear")
|
||||
return QLinearOpType::DequantizeLinear;
|
||||
else if (op_type == "QuantizeLinear")
|
||||
return QLinearOpType::QuantizeLinear;
|
||||
else if (op_type == "QLinearConv")
|
||||
return QLinearOpType::QLinearConv;
|
||||
else if (op_type == "QLinearMatMul")
|
||||
return QLinearOpType::QLinearMatMul;
|
||||
else if (op_type == "QLinearAdd")
|
||||
return QLinearOpType::QLinearAdd;
|
||||
else if (op_type == "QLinearSigmoid")
|
||||
return QLinearOpType::QLinearSigmoid;
|
||||
else if (op_type == "QLinearAveragePool")
|
||||
return QLinearOpType::QLinearAveragePool;
|
||||
else if (op_type == "QLinearMul")
|
||||
return QLinearOpType::QLinearMul;
|
||||
else if (op_type == "QLinearReduceMean")
|
||||
return QLinearOpType::QLinearReduceMean;
|
||||
else if (op_type == "QLinearConcat")
|
||||
return QLinearOpType::QLinearConcat;
|
||||
else if (op_type == "QLinearGlobalAveragePool")
|
||||
return QLinearOpType::QLinearGlobalAveragePool;
|
||||
else if (op_type == "QLinearLeakyRelu")
|
||||
return QLinearOpType::QLinearLeakyRelu;
|
||||
|
||||
return QLinearOpType::Unknown;
|
||||
}
|
||||
|
||||
// Ops have 1 input
|
||||
bool IsUnaryQLinearOp(QLinearOpType type) {
|
||||
return type == QLinearOpType::QLinearSigmoid ||
|
||||
type == QLinearOpType::QLinearAveragePool ||
|
||||
type == QLinearOpType::QLinearGlobalAveragePool ||
|
||||
type == QLinearOpType::QLinearLeakyRelu ||
|
||||
type == QLinearOpType::QLinearReduceMean;
|
||||
}
|
||||
|
||||
// Ops have 2 inputs
|
||||
bool IsBinaryQLinearOp(QLinearOpType type) {
|
||||
return type == QLinearOpType::QLinearConv ||
|
||||
type == QLinearOpType::QLinearMatMul ||
|
||||
type == QLinearOpType::QLinearAdd ||
|
||||
type == QLinearOpType::QLinearMul;
|
||||
}
|
||||
|
||||
// Ops have 1 or more inputs
|
||||
bool IsVariadicQLinearOp(QLinearOpType type) {
|
||||
return type == QLinearOpType::QLinearConcat;
|
||||
}
|
||||
|
||||
const std::vector<const Node*> GetQDQIONodes(const GraphViewer& graph_viewer,
|
||||
const QDQ::NodeGroup& node_group, bool is_input) {
|
||||
std::vector<const Node*> io_nodes;
|
||||
const auto& src_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes;
|
||||
io_nodes.reserve(src_nodes.size());
|
||||
for (const auto& node_idx : src_nodes) {
|
||||
io_nodes.push_back(graph_viewer.GetNode(node_idx));
|
||||
}
|
||||
|
||||
return io_nodes;
|
||||
}
|
||||
|
||||
// Get the input or output NodeUnitIODef(s) for the given QDQ NodeGroup
|
||||
std::vector<NodeUnitIODef> GetQDQIODefs(const Node& target_node, const QDQ::NodeGroup& node_group, bool is_input) {
|
||||
const auto& dq_or_q_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes;
|
||||
const auto target_node_io_defs = is_input ? target_node.InputDefs() : target_node.OutputDefs();
|
||||
const size_t target_node_io_defs_size = target_node_io_defs.size();
|
||||
|
||||
// Find all the quantized IO defs and indices (for the input/output of the target node)
|
||||
std::unordered_map<size_t, NodeUnitIODef> quantized_io_defs;
|
||||
quantized_io_defs.reserve(target_node_io_defs_size);
|
||||
|
||||
auto cur = is_input ? target_node.InputEdgesBegin() : target_node.OutputEdgesBegin();
|
||||
auto end = is_input ? target_node.InputEdgesEnd() : target_node.OutputEdgesEnd();
|
||||
|
||||
for (; cur != end; ++cur) {
|
||||
const Node& node = cur->GetNode();
|
||||
|
||||
// If we can find the node index in the dq or q nodes this is a quantized input/output
|
||||
if (std::find(dq_or_q_nodes.cbegin(), dq_or_q_nodes.cend(), node.Index()) != dq_or_q_nodes.cend()) {
|
||||
const auto node_inputs = node.InputDefs();
|
||||
// quantization scale and zp are always the input[1, 2]
|
||||
NodeUnitIODef::QuantParam quant_param{*node_inputs[1], node_inputs.size() == 3 ? node_inputs[2] : nullptr};
|
||||
|
||||
if (is_input) {
|
||||
// DQ is input to the target node, use the DstArgIndex
|
||||
auto idx = cur->GetDstArgIndex();
|
||||
// This is a DQ node, we are using x, x_scale, x_zp (input[0, 1, 2])
|
||||
quantized_io_defs.insert({idx, NodeUnitIODef{*node_inputs[0], quant_param}});
|
||||
} else {
|
||||
// Q is output of the target node, use the SrcArgIndex
|
||||
auto idx = cur->GetSrcArgIndex();
|
||||
// This is a Q node, we are using y (output[0]), y_scale, y_zp (input[1, 2])
|
||||
const auto node_outputs = node.OutputDefs();
|
||||
quantized_io_defs.insert({idx, NodeUnitIODef{*node_outputs[0], quant_param}});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Construct the IODefs for this QDQ NodeGroup
|
||||
std::vector<NodeUnitIODef> io_defs;
|
||||
io_defs.reserve(target_node_io_defs_size);
|
||||
for (size_t i = 0; i < target_node_io_defs_size; i++) {
|
||||
// If we can find the NodeUnitIODef for this index, this is a quantized input/output
|
||||
if (quantized_io_defs.find(i) != quantized_io_defs.cend()) {
|
||||
io_defs.push_back(std::move(quantized_io_defs.at(i)));
|
||||
} else {
|
||||
// This is a regular input
|
||||
io_defs.push_back({*target_node_io_defs[i], std::nullopt});
|
||||
}
|
||||
}
|
||||
|
||||
return io_defs;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer,
|
||||
const Node& target_node,
|
||||
gsl::span<const Node* const> dq_nodes,
|
||||
gsl::span<const Node* const> q_nodes) {
|
||||
// Within a QDQ node group, a target node input is the only consumer of each DQ.
|
||||
// This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications
|
||||
// may have happened since. Verify that this is still true.
|
||||
for (const auto* dq_node : dq_nodes) {
|
||||
const bool dq_produces_graph_output = graph_viewer.NodeProducesGraphOutput(*dq_node);
|
||||
ORT_RETURN_IF(dq_produces_graph_output,
|
||||
"QDQ node group cannot have DQ node that produces a graph output. DQ node: ", dq_node->Name(),
|
||||
", target node: ", target_node.Name());
|
||||
|
||||
const bool dq_has_single_output_edge_to_target =
|
||||
dq_node->GetOutputEdgesCount() == 1 &&
|
||||
dq_node->OutputEdgesBegin()->GetNode().Index() == target_node.Index();
|
||||
ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target,
|
||||
"QDQ node group cannot have DQ that doesn't have a single output edge to the target node. "
|
||||
"DQ node: ",
|
||||
dq_node->Name(), ", target node: ", target_node.Name());
|
||||
}
|
||||
|
||||
// an output from the target node can have either Q consumers or direct consumers. it cannot have both.
|
||||
// this must be checked on a per output basis.
|
||||
// e.g. TopK produces values and indices. The indices output won't be quantized, so even if we replace the TopK QDQ
|
||||
// node group with a quantized TopK, an int64_t indices value will be produced and can provide a graph output.
|
||||
if (!q_nodes.empty()) {
|
||||
auto cur_edge = target_node.OutputEdgesBegin();
|
||||
auto end_edge = target_node.OutputEdgesEnd();
|
||||
std::vector<const Node*> output_consumers(target_node.OutputDefs().size(), nullptr);
|
||||
|
||||
for (; cur_edge != end_edge; ++cur_edge) {
|
||||
auto output_idx = cur_edge->GetSrcArgIndex();
|
||||
const Node& this_consumer = cur_edge->GetNode();
|
||||
const Node* existing_consumer = output_consumers[output_idx];
|
||||
|
||||
if (existing_consumer != nullptr) {
|
||||
// another edge for this output. either both are Q or both are not.
|
||||
bool valid = true;
|
||||
if (existing_consumer->OpType() == "QuantizeLinear") {
|
||||
valid = this_consumer.OpType() == "QuantizeLinear";
|
||||
} else {
|
||||
valid = this_consumer.OpType() != "QuantizeLinear";
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_NOT(valid,
|
||||
"QDQ node group cannot have an output from the target node being consumed by a Q node and "
|
||||
"a non-Q node. target node: ",
|
||||
target_node.Name());
|
||||
} else {
|
||||
output_consumers[output_idx] = &this_consumer;
|
||||
}
|
||||
}
|
||||
|
||||
const auto& graph_outputs = graph_viewer.GetOutputs();
|
||||
for (size_t idx = 0, end = output_consumers.size(); idx < end; ++idx) {
|
||||
// any output with a Q cannot be a graph output as it will disappear if the QDQ node unit is converted to
|
||||
// a quantized op.
|
||||
if (output_consumers[idx] != nullptr && output_consumers[idx]->OpType() == "QuantizeLinear") {
|
||||
const auto& output_name = target_node.OutputDefs()[idx]->Name();
|
||||
bool is_graph_output = std::any_of(graph_outputs.begin(), graph_outputs.end(),
|
||||
[&output_name](const NodeArg* node_arg) {
|
||||
return node_arg->Name() == output_name;
|
||||
});
|
||||
ORT_RETURN_IF(is_graph_output,
|
||||
"QDQ node group cannot have an output from the target node that is consumed by a Q node and "
|
||||
"a graph output. target node: ",
|
||||
target_node.Name(), " output idx:", idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
NodeUnit::NodeUnit(const Node& node)
|
||||
: target_node_(node),
|
||||
type_(Type::SingleNode),
|
||||
input_edge_count_(node.GetInputEdgesCount()) {
|
||||
InitForSingleNode();
|
||||
}
|
||||
|
||||
NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group)
|
||||
: dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)},
|
||||
target_node_(*graph_viewer.GetNode(node_group.target_node)),
|
||||
q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)},
|
||||
type_(Type::QDQGroup),
|
||||
inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)},
|
||||
outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} {
|
||||
ORT_THROW_IF_ERROR(QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, target_node_, dq_nodes_, q_nodes_));
|
||||
|
||||
input_edge_count_ = std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0),
|
||||
[](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); });
|
||||
|
||||
// add edges for inputs that are not from DQ nodes. there is one edge to each DQ node.
|
||||
// other inputs could come from initializers or graph inputs (no edges) or other nodes (edge).
|
||||
input_edge_count_ += target_node_.GetInputEdgesCount() - dq_nodes_.size();
|
||||
|
||||
// create output edges. each target node output either goes to Q node/s or non-Q node/s.
|
||||
// ValidateNodeGroupQDQNodes ensures this.
|
||||
auto cur_edge = target_node_.OutputEdgesBegin();
|
||||
auto end_edge = target_node_.OutputEdgesEnd();
|
||||
for (; cur_edge != end_edge; ++cur_edge) {
|
||||
const Node& node = cur_edge->GetNode();
|
||||
|
||||
// if node is in q_nodes we hide the Q node.
|
||||
if (std::find(q_nodes_.cbegin(), q_nodes_.cend(), &node) != q_nodes_.cend()) {
|
||||
auto src_idx = cur_edge->GetSrcArgIndex();
|
||||
auto q_cur_edge = node.OutputEdgesBegin();
|
||||
auto q_end_edge = node.OutputEdgesEnd();
|
||||
for (; q_cur_edge != q_end_edge; ++q_cur_edge) {
|
||||
output_edges_.insert(Node::EdgeEnd{q_cur_edge->GetNode(), src_idx, q_cur_edge->GetDstArgIndex()});
|
||||
}
|
||||
} else {
|
||||
// non-Q node, or Q node that isn't in the QDQ node group (unexpected but may be possible). add as-is.
|
||||
output_edges_.insert(*cur_edge);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); }
|
||||
const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpType(); }
|
||||
const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); }
|
||||
int NodeUnit::SinceVersion() const noexcept { return target_node_.SinceVersion(); }
|
||||
NodeIndex NodeUnit::Index() const noexcept { return target_node_.Index(); }
|
||||
const Path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); }
|
||||
ProviderType NodeUnit::GetExecutionProviderType() const noexcept { return target_node_.GetExecutionProviderType(); }
|
||||
|
||||
void NodeUnit::InitForSingleNode() {
|
||||
const auto& input_defs = target_node_.InputDefs();
|
||||
const auto& output_defs = target_node_.OutputDefs();
|
||||
auto qlinear_type = GetQLinearOpType(target_node_);
|
||||
if (qlinear_type == QLinearOpType::Unknown || IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support
|
||||
// Not a Qlinear op, add all inputs / outputs
|
||||
auto add_all_io = [](std::vector<NodeUnitIODef>& defs,
|
||||
const ConstPointerContainer<std::vector<NodeArg*>>& node_defs) {
|
||||
defs.reserve(node_defs.size());
|
||||
|
||||
for (const auto def : node_defs) {
|
||||
defs.push_back(NodeUnitIODef{*def, std::nullopt});
|
||||
}
|
||||
};
|
||||
|
||||
add_all_io(inputs_, input_defs);
|
||||
add_all_io(outputs_, output_defs);
|
||||
} else if (IsUnaryQLinearOp(qlinear_type)) {
|
||||
// Unary QLinear Op has 5 inputs
|
||||
// x, x_scale, x_zp, y_scale, y_zp (optional)
|
||||
inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}});
|
||||
outputs_.push_back(NodeUnitIODef{*output_defs[0],
|
||||
NodeUnitIODef::QuantParam{*input_defs[3],
|
||||
input_defs.size() > 4 ? input_defs[4] : nullptr}});
|
||||
|
||||
} else if (IsBinaryQLinearOp(qlinear_type)) {
|
||||
// Binary QLinear Op has 9 inputs
|
||||
// x1, x1_scale, x1_zp, x2/w, x2_scale, x2_zp, y_scale , y_zp, B
|
||||
inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}});
|
||||
inputs_.push_back(NodeUnitIODef{*input_defs[3], NodeUnitIODef::QuantParam{*input_defs[4], input_defs[5]}});
|
||||
|
||||
if (input_defs.size() == 9) { // has Bias
|
||||
inputs_.push_back(NodeUnitIODef{*input_defs[8], std::nullopt}); // for Bias the scale and zp are optional
|
||||
}
|
||||
|
||||
outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[6], input_defs[7]}});
|
||||
|
||||
} else if (qlinear_type == QLinearOpType::DequantizeLinear) {
|
||||
// DequantizeLinear has 3 inputs
|
||||
// x, x_scale, x_zp
|
||||
// output is not quantized
|
||||
inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3
|
||||
? input_defs[2]
|
||||
: nullptr}});
|
||||
outputs_.push_back(NodeUnitIODef{*output_defs[0], std::nullopt});
|
||||
|
||||
} else if (qlinear_type == QLinearOpType::QuantizeLinear) {
|
||||
// QuantizeLinear the input is not quantized and has 3 inputs
|
||||
// x, y_scale, y_zp (optional)
|
||||
// The output is quantized
|
||||
inputs_.push_back(NodeUnitIODef{*input_defs[0], std::nullopt});
|
||||
outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3
|
||||
? input_defs[2]
|
||||
: nullptr}});
|
||||
} else {
|
||||
ORT_THROW("The QLinear op [", static_cast<uint8_t>(qlinear_type), "] is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
Node::EdgeConstIterator NodeUnit::OutputEdgesBegin() const {
|
||||
return (type_ == Type::SingleNode) ? target_node_.OutputEdgesBegin() : output_edges_.begin();
|
||||
}
|
||||
|
||||
Node::EdgeConstIterator NodeUnit::OutputEdgesEnd() const {
|
||||
return (type_ == Type::SingleNode) ? target_node_.OutputEdgesEnd() : output_edges_.end();
|
||||
}
|
||||
|
||||
std::vector<const Node*> NodeUnit::GetAllNodesInGroup() const noexcept {
|
||||
std::vector<const Node*> all_nodes = dq_nodes_;
|
||||
all_nodes.push_back(&target_node_);
|
||||
all_nodes.insert(all_nodes.end(), q_nodes_.begin(), q_nodes_.end());
|
||||
return all_nodes;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
|
@ -3,6 +3,9 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
||||
#include <string>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
|
@ -18,8 +21,21 @@ class NodeArg;
|
|||
class Path;
|
||||
|
||||
namespace QDQ {
|
||||
struct NodeGroup;
|
||||
}
|
||||
// Struct to represent a DequantizeLinear -> Op -> QuantizeLinear node group
|
||||
struct NodeGroup {
|
||||
std::vector<NodeIndex> dq_nodes;
|
||||
std::vector<NodeIndex> q_nodes;
|
||||
NodeIndex target_node;
|
||||
|
||||
// Validator to check if the set of nodes can form a valid QDQ NodeGroup.
|
||||
// Checks target node is only consumer of each DQ, and that the outputs remain valid if the QDQ node group was to
|
||||
// be converted into a single node with a quantized operator.
|
||||
static Status CanCreateNodeGroup(const GraphViewer& graph_viewer,
|
||||
const Node& target_node,
|
||||
gsl::span<const Node* const> dq_nodes,
|
||||
gsl::span<const Node* const> q_nodes);
|
||||
};
|
||||
} // namespace QDQ
|
||||
|
||||
// Definition of one input or output
|
||||
// If the optional quant_param is present, then this is a quantized input,
|
||||
|
|
@ -69,26 +85,33 @@ class NodeUnit {
|
|||
const std::vector<const Node*>& GetQNodes() const noexcept { return q_nodes_; }
|
||||
std::vector<const Node*> GetAllNodesInGroup() const noexcept;
|
||||
|
||||
Node::EdgeConstIterator OutputEdgesBegin(size_t index) const;
|
||||
Node::EdgeConstIterator OutputEdgesEnd(size_t index) const;
|
||||
/// Number of input edges to the logical node. For a QDQ node this is the count of input edges to the DQ nodes
|
||||
/// plus any other edges to the target node for inputs that are not via a DQ node.
|
||||
size_t InputEdgeCount() const { return input_edge_count_; }
|
||||
|
||||
// output edges. src index is for outputs of the target node. dest index and node is for consumer of node unit
|
||||
// output. any Q nodes are hidden.
|
||||
Node::EdgeConstIterator OutputEdgesBegin() const;
|
||||
Node::EdgeConstIterator OutputEdgesEnd() const;
|
||||
|
||||
private:
|
||||
const std::vector<const Node*> q_nodes_; // q-nodes for this NodeUnit
|
||||
const std::vector<const Node*> dq_nodes_; // dq nodes for this NodeUnit, not all inputs
|
||||
// Initialization for a NodeUnit that contains a single node
|
||||
void InitForSingleNode();
|
||||
|
||||
const std::vector<const Node*> dq_nodes_; // dq nodes for this NodeUnit, not necessarily all inputs
|
||||
const Node& target_node_;
|
||||
const std::vector<const Node*> q_nodes_; // q-nodes for this NodeUnit. not necessarily all outputs
|
||||
const Type type_;
|
||||
|
||||
std::vector<NodeUnitIODef> inputs_;
|
||||
std::vector<NodeUnitIODef> outputs_;
|
||||
|
||||
// Initializing for a single Node
|
||||
void InitForSingleNode();
|
||||
size_t input_edge_count_; // total number of input edges
|
||||
|
||||
// output edges, hiding any Q nodes involved. src_idx will be value from target node. only used for QDQ node group.
|
||||
Node::EdgeSet output_edges_;
|
||||
};
|
||||
|
||||
// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup)
|
||||
// And return a map to quick query the NodeUnit which contains the given Node,
|
||||
// Note, the value of the map is owned by the vector of std::unique_ptr<NodeUnit>
|
||||
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
|
||||
GetAllNodeUnits(const GraphViewer& graph_viewer);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
|
@ -58,8 +58,8 @@ bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Nod
|
|||
return false;
|
||||
}
|
||||
|
||||
if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes);
|
||||
!dq_validation_status.IsOK()) {
|
||||
if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes);
|
||||
!qdq_validation_status.IsOK()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -153,8 +153,8 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
|
|||
return false;
|
||||
}
|
||||
|
||||
if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes);
|
||||
!dq_validation_status.IsOK()) {
|
||||
if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes);
|
||||
!qdq_validation_status.IsOK()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -544,8 +544,8 @@ bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer,
|
|||
return false;
|
||||
}
|
||||
|
||||
if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes);
|
||||
!dq_validation_status.IsOK()) {
|
||||
if (const auto qdq_validation_status = QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes);
|
||||
!qdq_validation_status.IsOK()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/optimizer/selectors_actions/selector_action_transformer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -13,13 +14,6 @@ class Node;
|
|||
|
||||
namespace QDQ {
|
||||
|
||||
// Struct to represent a DQ->Op->Q node group
|
||||
struct NodeGroup {
|
||||
std::vector<NodeIndex> dq_nodes;
|
||||
std::vector<NodeIndex> q_nodes;
|
||||
NodeIndex target_node;
|
||||
};
|
||||
|
||||
class NodeGroupSelector {
|
||||
public:
|
||||
// This is a QDQ Selectors only function, will return QDQ::NodeGroup instead of NodesToOptimizeIndices
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
#include <core/providers/common.h>
|
||||
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace QDQ {
|
||||
|
|
@ -43,6 +44,7 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() {
|
|||
{"Tile", {}}};
|
||||
}
|
||||
|
||||
// These produce int64 indices output, which can't be quantized, so there's no downstream Q node.
|
||||
static const OpVersionsAndSelector::OpVersionsMap GetDropDQOpVersionsMap() {
|
||||
return {{"ArgMax", {}},
|
||||
{"ArgMin", {}}};
|
||||
|
|
@ -324,28 +326,48 @@ std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& grap
|
|||
return qdq_selections;
|
||||
}
|
||||
|
||||
Status ValidateNodeGroupDQNodes(const GraphViewer& graph_viewer,
|
||||
const Node& target_node,
|
||||
gsl::span<const Node* const> dq_nodes) {
|
||||
// Within a QDQ node group, a target node input is the only consumer of each DQ.
|
||||
// This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications
|
||||
// may have happened since. Verify that this is still true.
|
||||
for (const auto* dq_node : dq_nodes) {
|
||||
const bool dq_produces_graph_output = graph_viewer.NodeProducesGraphOutput(*dq_node);
|
||||
ORT_RETURN_IF(dq_produces_graph_output,
|
||||
"QDQ node group cannot have DQ node that produces a graph output. DQ node: ", dq_node->Name(),
|
||||
", target node: ", target_node.Name());
|
||||
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
|
||||
GetAllNodeUnits(const GraphViewer& graph_viewer) {
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
|
||||
const bool dq_has_single_output_edge_to_target =
|
||||
dq_node->GetOutputEdgesCount() == 1 &&
|
||||
dq_node->OutputEdgesBegin()->GetNode().Index() == target_node.Index();
|
||||
ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target,
|
||||
"QDQ node group cannot have DQ that doesn't have a single output edge to the target node. "
|
||||
"DQ node: ",
|
||||
dq_node->Name(), ", target node: ", target_node.Name());
|
||||
const auto add_node_unit_to_map = [&](const std::vector<NodeIndex>& node_indices, const NodeUnit* node_unit) {
|
||||
for (const auto& node_idx : node_indices) {
|
||||
const auto* node = graph_viewer.GetNode(node_idx);
|
||||
node_unit_map.insert({node, node_unit});
|
||||
}
|
||||
};
|
||||
|
||||
// Get QDQ NodeUnits first
|
||||
QDQ::SelectorManager selector_mgr;
|
||||
const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer);
|
||||
|
||||
for (const auto& qdq_selection : qdq_selections) {
|
||||
auto qdq_unit = std::make_unique<NodeUnit>(graph_viewer, qdq_selection);
|
||||
|
||||
// Fill the node to node_unit map for all nodes in the QDQ Group
|
||||
add_node_unit_to_map(qdq_selection.dq_nodes, qdq_unit.get());
|
||||
add_node_unit_to_map(qdq_selection.q_nodes, qdq_unit.get());
|
||||
add_node_unit_to_map({qdq_selection.target_node}, qdq_unit.get());
|
||||
|
||||
node_unit_holder.push_back(std::move(qdq_unit));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
// Get the left over SingleNode NodeUnits
|
||||
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
|
||||
for (const auto node_idx : node_indices) {
|
||||
const auto* node(graph_viewer.GetNode(node_idx));
|
||||
|
||||
// This is already part of a QDQ NodeUnit
|
||||
if (node_unit_map.find(node) != node_unit_map.cend())
|
||||
continue;
|
||||
|
||||
auto node_unit = std::make_unique<NodeUnit>(*node);
|
||||
node_unit_map[node] = node_unit.get();
|
||||
node_unit_holder.push_back(std::move(node_unit));
|
||||
}
|
||||
|
||||
return std::make_pair(std::move(node_unit_holder), std::move(node_unit_map));
|
||||
}
|
||||
|
||||
} // namespace QDQ
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#include "core/common/common.h"
|
||||
#include "core/common/gsl.h"
|
||||
#include "core/common/inlined_containers.h"
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/graph/basic_types.h"
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
|
@ -78,11 +79,16 @@ class SelectorManager {
|
|||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SelectorManager);
|
||||
};
|
||||
|
||||
// Checks whether the provided DQ nodes are valid for forming a QDQ node group with the provided target node.
|
||||
// Returns successful status if so, failed status with reason otherwise.
|
||||
Status ValidateNodeGroupDQNodes(const GraphViewer& graph_viewer,
|
||||
const Node& target_node,
|
||||
gsl::span<const Node* const> dq_nodes);
|
||||
// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup)
|
||||
// And return a map to quick query the NodeUnit which contains the given Node,
|
||||
// Note, the value of the map is owned by the vector of std::unique_ptr<NodeUnit>
|
||||
//
|
||||
// TODO: The overall QDQ setup needs refactoring to separate out generic functionality from optimizer specific
|
||||
// functionality.
|
||||
// We currently have a bit of a mess with generic things like this to get all the node units being in the optimizer
|
||||
// library whereas it should be able to be used by an EP with no dependency on optimizers.
|
||||
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
|
||||
GetAllNodeUnits(const GraphViewer& graph_viewer);
|
||||
|
||||
} // namespace QDQ
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@
|
|||
#include "core/framework/kernel_registry.h"
|
||||
#include "core/graph/function_utils.h"
|
||||
#include "core/graph/indexed_sub_graph.h"
|
||||
#include "core/providers/shared/node_unit/node_unit.h"
|
||||
#include "data_transfer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/graph/graph.h"
|
||||
|
|
@ -18,7 +19,6 @@
|
|||
#include "core/providers/common.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h"
|
||||
#include "core/providers/shared/node_unit/node_unit.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/shared/node_unit/node_unit.h"
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/builders/model_builder.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h"
|
||||
|
|
|
|||
|
|
@ -11,17 +11,19 @@
|
|||
#include "core/common/safeint.h"
|
||||
#include "core/common/status.h"
|
||||
#include "core/framework/execution_provider.h"
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/nnapi_api_helper.h"
|
||||
#include "core/providers/shared/node_unit/node_unit.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/builders/helper.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
|
||||
using namespace android::nn::wrapper;
|
||||
|
||||
|
|
@ -119,7 +121,7 @@ const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const {
|
|||
}
|
||||
|
||||
void ModelBuilder::PreprocessNodeUnits() {
|
||||
std::tie(node_unit_holder_, node_unit_map_) = GetAllNodeUnits(graph_viewer_);
|
||||
std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_);
|
||||
}
|
||||
|
||||
// Help to get all quantized operators' input and the NodeUnit(s) using the input
|
||||
|
|
@ -664,7 +666,7 @@ int32_t ModelBuilder::FindActivation(const NodeUnit& node_unit) {
|
|||
|
||||
int32_t fuse_code = ANEURALNETWORKS_FUSED_NONE;
|
||||
bool fuse_code_assigned_from_activation = false;
|
||||
for (auto it = node_unit.OutputEdgesBegin(0), end = node_unit.OutputEdgesEnd(0); it != end; ++it) {
|
||||
for (auto it = node_unit.OutputEdgesBegin(), end = node_unit.OutputEdgesEnd(); it != end; ++it) {
|
||||
const auto& dst_node = it->GetNode();
|
||||
const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()];
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@
|
|||
#include "core/optimizer/initializer.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include "core/providers/shared/node_unit/node_unit.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h"
|
||||
|
||||
namespace onnxruntime::nnapi::op_builder_helpers {
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@
|
|||
#include <vector>
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/builders/helper.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/builders/model_builder.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksWrapper.h"
|
||||
#include "core/providers/shared/node_unit/node_unit.h"
|
||||
|
||||
namespace onnxruntime::nnapi::op_builder_helpers {
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,10 @@
|
|||
#include "core/common/logging/logging.h"
|
||||
#include "core/common/string_utils.h"
|
||||
#include "core/framework/compute_capability.h"
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
|
||||
#include "core/platform/env.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/builders/helper.h"
|
||||
|
|
@ -17,7 +20,6 @@
|
|||
#include "core/providers/nnapi/nnapi_builtin/nnapi_api_helper.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h"
|
||||
#include "core/providers/partitioning_utils.h"
|
||||
#include "core/providers/shared/node_unit/node_unit.h"
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -119,7 +121,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
|
||||
std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
|
||||
|
||||
// This holds the result of whether a NodeUnit is supported or not,
|
||||
// to prevent nodes in a NodeUnit to be checked for multiple times
|
||||
|
|
@ -181,7 +183,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
};
|
||||
|
||||
result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed,
|
||||
gen_metadef_name, NNAPI, kNnapiExecutionProvider);
|
||||
gen_metadef_name, NNAPI, kNnapiExecutionProvider, &node_unit_map);
|
||||
|
||||
// Generally, NNAPI supports sub-graphs with at least one non-constant initializer input and one output.
|
||||
// So far, we have a few cases that sub-graph has zero valid inputs, like `CastLike`
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
||||
#include "core/providers/partitioning_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
|
@ -10,6 +13,7 @@
|
|||
|
||||
#include "core/framework/compute_capability.h"
|
||||
#include "core/framework/execution_provider.h"
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/providers/common.h"
|
||||
|
||||
|
|
@ -76,6 +80,11 @@ When selecting the next node to process, we first take:
|
|||
The remaining unsupported nodes mark the border of the current group so they will be processed later when we consider
|
||||
the next group.
|
||||
|
||||
If node_unit_map is provided, we process NodeUnit instances (a logical 'Node' that can be a single node or a
|
||||
QDQ node group) instead of individual Node instances. As an EP must take complete NodeUnit instances (i.e. it
|
||||
must not break up a QDQ node group by taking a subset of nodes in it), this granularity of processing is valid.
|
||||
It is required to ensure we do not break up a QDQ node unit during partitioning.
|
||||
|
||||
@param graph_viewer GraphViewer that IExecutionProvider::GetCapability is called with.
|
||||
@param is_node_supported_fn Callback to check whether a node is supported.
|
||||
@param on_group_closed_fn Callback to indicate a completed partition node group.
|
||||
|
|
@ -88,6 +97,7 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
|
|||
const IsNodeSupportedFn& is_node_supported_fn,
|
||||
const OnGroupClosedFn& on_group_closed_fn,
|
||||
const std::string& execution_provider_type,
|
||||
const std::unordered_map<const Node*, const NodeUnit*>* node_unit_map,
|
||||
bool debug_output) {
|
||||
#ifdef NDEBUG
|
||||
ORT_UNUSED_PARAMETER(debug_output);
|
||||
|
|
@ -111,7 +121,18 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
|
|||
// initialize in-degrees and find root nodes
|
||||
for (const auto& node_index : graph_viewer.GetNodesInTopologicalOrder()) {
|
||||
const auto& node = *graph_viewer.GetNode(node_index);
|
||||
const auto node_input_edge_count = node.GetInputEdgesCount();
|
||||
auto node_input_edge_count = node.GetInputEdgesCount();
|
||||
|
||||
if (node_unit_map != nullptr) {
|
||||
const auto& node_unit = node_unit_map->at(&node);
|
||||
if (&node_unit->GetNode() != &node) {
|
||||
// only process the target node
|
||||
continue;
|
||||
}
|
||||
|
||||
node_input_edge_count = node_unit->InputEdgeCount();
|
||||
}
|
||||
|
||||
in_degree.insert({node.Index(), node_input_edge_count});
|
||||
if (node_input_edge_count == 0) {
|
||||
nodes_to_process.push_back(&node);
|
||||
|
|
@ -151,6 +172,8 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
|
|||
}
|
||||
};
|
||||
|
||||
size_t num_nodes_processed = 0;
|
||||
|
||||
while (!nodes_to_process.empty() || !nodes_to_process_with_next_group.empty()) {
|
||||
if (nodes_to_process.empty()) {
|
||||
// we have processed all the nodes that we can while building this partition node group, start a new one
|
||||
|
|
@ -162,9 +185,13 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
|
|||
const Node& node = *nodes_to_process.front();
|
||||
nodes_to_process.pop_front();
|
||||
|
||||
const NodeUnit* node_unit = node_unit_map ? node_unit_map->at(&node) : nullptr;
|
||||
const bool is_qdq_node_unit = node_unit && node_unit->UnitType() == NodeUnit::Type::QDQGroup;
|
||||
|
||||
// a node that is already assigned to an EP other than current EP is unsupported
|
||||
const bool is_node_supported =
|
||||
(node.GetExecutionProviderType().empty() || node.GetExecutionProviderType() == execution_provider_type) && is_node_supported_fn(node);
|
||||
const bool is_node_supported = (node.GetExecutionProviderType().empty() ||
|
||||
node.GetExecutionProviderType() == execution_provider_type) &&
|
||||
is_node_supported_fn(node);
|
||||
|
||||
if (!is_node_supported && Contains(supported_group_border, &node)) {
|
||||
// an unsupported node on the border will be processed after the current partition node group
|
||||
|
|
@ -173,34 +200,62 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
|
|||
}
|
||||
|
||||
if (is_node_supported) {
|
||||
// add node to the partition node group
|
||||
supported_group.push_back(&node);
|
||||
if (is_qdq_node_unit) {
|
||||
// add DQ -> node -> Q for the node unit. must be in topological order
|
||||
for (const auto& dq : node_unit->GetDQNodes()) {
|
||||
supported_group.push_back(dq);
|
||||
}
|
||||
|
||||
// remove node from the border and add its outputs to the border
|
||||
supported_group.push_back(&node);
|
||||
|
||||
for (const auto& q : node_unit->GetQNodes()) {
|
||||
supported_group.push_back(q);
|
||||
}
|
||||
} else {
|
||||
supported_group.push_back(&node);
|
||||
}
|
||||
|
||||
// remove node from the border
|
||||
supported_group_border.erase(&node);
|
||||
|
||||
std::for_each(
|
||||
node.OutputNodesBegin(), node.OutputNodesEnd(),
|
||||
[&supported_group_border](const Node& output) {
|
||||
supported_group_border.insert(&output);
|
||||
});
|
||||
}
|
||||
|
||||
// adjust in-degrees of the node outputs and add any new nodes to process
|
||||
std::for_each(
|
||||
node.OutputNodesBegin(), node.OutputNodesEnd(),
|
||||
[&](const Node& output) {
|
||||
auto& output_node_in_degree = in_degree[output.Index()];
|
||||
--output_node_in_degree;
|
||||
// For each downstream node:
|
||||
// 1: add the downstream node to the border if the current node is supported
|
||||
// 2: adjust in-degrees of the nodes consuming the current node's outputs, and add any new nodes to process
|
||||
const auto process_downstream_node = [&](const Node& downstream_node) {
|
||||
if (is_node_supported) {
|
||||
supported_group_border.insert(&downstream_node);
|
||||
}
|
||||
|
||||
if (output_node_in_degree == 0) {
|
||||
nodes_to_process.push_back(&output);
|
||||
}
|
||||
});
|
||||
auto& downstream_node_in_degree = in_degree[downstream_node.Index()];
|
||||
--downstream_node_in_degree;
|
||||
|
||||
if (downstream_node_in_degree == 0) {
|
||||
nodes_to_process.push_back(&downstream_node);
|
||||
}
|
||||
};
|
||||
|
||||
if (node_unit_map) {
|
||||
std::for_each(node_unit->OutputEdgesBegin(), node_unit->OutputEdgesEnd(),
|
||||
[&](const Node::EdgeEnd& edge_end) {
|
||||
const Node& n = edge_end.GetNode();
|
||||
const NodeUnit& downstream_node_unit = *node_unit_map->at(&n);
|
||||
const Node& output = downstream_node_unit.GetNode();
|
||||
|
||||
process_downstream_node(output);
|
||||
});
|
||||
} else {
|
||||
std::for_each(node.OutputNodesBegin(), node.OutputNodesEnd(), process_downstream_node);
|
||||
}
|
||||
|
||||
++num_nodes_processed;
|
||||
}
|
||||
|
||||
close_group();
|
||||
|
||||
ORT_ENFORCE(num_nodes_processed == in_degree.size(),
|
||||
"Processed ", num_nodes_processed, " nodes. Expected to process ", in_degree.size());
|
||||
|
||||
return supported_groups;
|
||||
}
|
||||
} // namespace
|
||||
|
|
@ -318,11 +373,13 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
|
|||
const GenerateMetadefNameFn& generate_metadef_name_fn,
|
||||
const std::string& execution_provider_name,
|
||||
const std::string& execution_provider_type,
|
||||
const std::unordered_map<const Node*, const NodeUnit*>* node_unit_map,
|
||||
bool debug_output) {
|
||||
const auto groups = CreateSupportedPartitionNodeGroups(graph_viewer,
|
||||
is_node_supported_fn,
|
||||
on_partition_closed_fn,
|
||||
execution_provider_type,
|
||||
node_unit_map,
|
||||
debug_output);
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>> partitions{};
|
||||
|
|
@ -346,6 +403,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
|
|||
const GenerateMetadefNameFn& generate_metadef_name_fn,
|
||||
const std::string& execution_provider_name,
|
||||
const std::string& execution_provider_type,
|
||||
const std::unordered_map<const Node*, const NodeUnit*>* node_unit_map,
|
||||
bool debug_output) {
|
||||
const auto excluded_nodes = CreateExcludedNodeSet(graph_viewer, stop_ops);
|
||||
const bool check_excluded_nodes = !excluded_nodes.empty();
|
||||
|
|
@ -360,8 +418,11 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
|
|||
generate_metadef_name_fn,
|
||||
execution_provider_name,
|
||||
execution_provider_type,
|
||||
node_unit_map,
|
||||
debug_output);
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
|
|
@ -14,8 +17,9 @@
|
|||
namespace onnxruntime {
|
||||
struct ComputeCapability;
|
||||
class GraphViewer;
|
||||
class NodeArg;
|
||||
class Node;
|
||||
class NodeArg;
|
||||
class NodeUnit;
|
||||
|
||||
namespace utils {
|
||||
|
||||
|
|
@ -56,6 +60,8 @@ Create the supported partitions for the execution provider.
|
|||
@param generate_metadef_name_fn Callback to create the name for the MetaDef.
|
||||
@param execution_provider_name Name of execution provider creating the ComputeCapability instance.
|
||||
@param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance.
|
||||
@param node_unit_map Map of each Node in the graph_viewer to its NodeUnit. Provide if EP handles QDQ format models.
|
||||
Should be created by EP calling GetAllNodeUnits.
|
||||
@param debug_output Print diagnostic output about the partitions and reasons for partition breaks.
|
||||
No-op in a release build.
|
||||
|
||||
|
|
@ -68,6 +74,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
|
|||
const GenerateMetadefNameFn& generate_metadef_name_fn,
|
||||
const std::string& execution_provider_name,
|
||||
const std::string& execution_provider_type,
|
||||
const std::unordered_map<const Node*, const NodeUnit*>* node_unit_map = nullptr,
|
||||
bool debug_output = false);
|
||||
|
||||
/**
|
||||
|
|
@ -79,6 +86,8 @@ Create the supported partitions for the execution provider.
|
|||
@param generate_metadef_name Functor to create the name for the MetaDef.
|
||||
@param execution_provider_name Name of execution provider creating the ComputeCapability instance.
|
||||
@param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance.
|
||||
@param node_unit_map Map of each Node in the graph_viewer to its NodeUnit. Provide if EP handles QDQ format models.
|
||||
Should be created by EP calling GetAllNodeUnits.
|
||||
@param debug_output Print diagnostic output about the partitions and reasons for partition breaks.
|
||||
No-op in a release build.
|
||||
|
||||
|
|
@ -91,6 +100,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
|
|||
const GenerateMetadefNameFn& generate_metadef_name,
|
||||
const std::string& execution_provider_name,
|
||||
const std::string& execution_provider_type,
|
||||
const std::unordered_map<const Node*, const NodeUnit*>* node_unit_map = nullptr,
|
||||
bool debug_output = false);
|
||||
|
||||
/**
|
||||
|
|
@ -125,3 +135,5 @@ InlinedHashSet<const Node*> CreateExcludedNodeSet(const GraphViewer& graph_viewe
|
|||
const std::unordered_set<std::string>& stop_ops);
|
||||
} // namespace utils
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/providers/shared/node_unit/node_unit.h"
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@
|
|||
#include "core/providers/qnn/builder/op_builder_factory.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
|
||||
#include "core/providers/qnn/builder/qnn_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -95,7 +97,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
|
|||
// valid throughout the lifetime of the ModelBuilder
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
|
||||
|
||||
// This name must be same with the EPContext node name
|
||||
const auto& graph_name = fused_node.Name();
|
||||
|
|
|
|||
|
|
@ -6,13 +6,13 @@
|
|||
#include <vector>
|
||||
|
||||
#include "core/common/status.h"
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/platform/ort_mutex.h"
|
||||
#include "core/providers/qnn/builder/qnn_def.h"
|
||||
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
|
||||
#include "core/providers/qnn/builder/qnn_backend_manager.h"
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "core/providers/shared/node_unit/node_unit.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace qnn {
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@
|
|||
#include "QnnInterface.h"
|
||||
#include "qnn_def.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/providers/shared/node_unit/node_unit.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@
|
|||
#include "core/session/onnxruntime_run_options_config_keys.h"
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "core/framework/kernel_registry.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
|
||||
#include "core/platform/env.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/partitioning_utils.h"
|
||||
|
|
@ -494,7 +496,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
|
|||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
|
||||
std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
|
||||
|
||||
const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, node_unit_holder.size(),
|
||||
is_qnn_ctx_model, logger);
|
||||
|
|
@ -534,44 +536,39 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
|
|||
size_t num_of_supported_nodes = 0;
|
||||
|
||||
// Create partitions from supported nodes.
|
||||
{
|
||||
std::vector<std::unique_ptr<ComputeCapability>> partitions = utils::CreateSupportedPartitions(graph_viewer,
|
||||
supported_nodes, {},
|
||||
gen_metadef_name, QNN,
|
||||
kQnnExecutionProvider,
|
||||
true);
|
||||
std::vector<std::unique_ptr<ComputeCapability>> partitions = utils::CreateSupportedPartitions(
|
||||
graph_viewer, supported_nodes, {}, gen_metadef_name, QNN, kQnnExecutionProvider, &node_unit_map, true);
|
||||
|
||||
// Filter out partitions that consist of a single QuantizeLinear or DequantizeLinear node.
|
||||
// We also count the number of supported nodes in all valid partitions.
|
||||
for (auto& partition : partitions) {
|
||||
bool is_valid_partition = true;
|
||||
size_t nodes_in_partition = 0;
|
||||
// Filter out partitions that consist of a single QuantizeLinear or DequantizeLinear node.
|
||||
// We also count the number of supported nodes in all valid partitions.
|
||||
for (auto& partition : partitions) {
|
||||
bool is_valid_partition = true;
|
||||
size_t nodes_in_partition = 0;
|
||||
|
||||
if (partition && partition->sub_graph) {
|
||||
nodes_in_partition = partition->sub_graph->nodes.size();
|
||||
if (partition && partition->sub_graph) {
|
||||
nodes_in_partition = partition->sub_graph->nodes.size();
|
||||
|
||||
if (nodes_in_partition == 1 && !is_qnn_ctx_model) {
|
||||
const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]);
|
||||
if (nodes_in_partition == 1 && !is_qnn_ctx_model) {
|
||||
const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]);
|
||||
|
||||
if (!node) {
|
||||
LOGS(logger, ERROR) << "QNN EP: Invalid node in partition of one node.";
|
||||
is_valid_partition = false;
|
||||
} else if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") {
|
||||
LOGS(logger, WARNING) << "QNN EP does not support a single Quantize/Dequantize node in a partition.";
|
||||
is_valid_partition = false;
|
||||
}
|
||||
if (!node) {
|
||||
LOGS(logger, ERROR) << "QNN EP: Invalid node in partition of one node.";
|
||||
is_valid_partition = false;
|
||||
} else if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") {
|
||||
LOGS(logger, WARNING) << "QNN EP does not support a single Quantize/Dequantize node in a partition.";
|
||||
is_valid_partition = false;
|
||||
}
|
||||
} else {
|
||||
LOGS(logger, ERROR) << "QNN EP: Invalid partition.";
|
||||
is_valid_partition = false;
|
||||
}
|
||||
} else {
|
||||
LOGS(logger, ERROR) << "QNN EP: Invalid partition.";
|
||||
is_valid_partition = false;
|
||||
}
|
||||
|
||||
if (is_valid_partition) {
|
||||
result.push_back(std::move(partition));
|
||||
num_of_supported_nodes += nodes_in_partition;
|
||||
}
|
||||
} // for
|
||||
}
|
||||
if (is_valid_partition) {
|
||||
result.push_back(std::move(partition));
|
||||
num_of_supported_nodes += nodes_in_partition;
|
||||
}
|
||||
} // for
|
||||
|
||||
const size_t num_of_partitions = result.size();
|
||||
const auto summary_msg = MakeString("Number of partitions supported by QNN EP: ", num_of_partitions,
|
||||
|
|
|
|||
|
|
@ -1,319 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "node_unit.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace {
|
||||
|
||||
enum class QLinearOpType : uint8_t {
|
||||
Unknown, // Unknown or not a linear quantized op
|
||||
DequantizeLinear,
|
||||
QuantizeLinear,
|
||||
QLinearConv,
|
||||
QLinearMatMul,
|
||||
QLinearAdd,
|
||||
QLinearSigmoid,
|
||||
QLinearAveragePool,
|
||||
QLinearMul,
|
||||
QLinearReduceMean,
|
||||
QLinearConcat,
|
||||
QLinearGlobalAveragePool,
|
||||
QLinearLeakyRelu,
|
||||
};
|
||||
|
||||
QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) {
|
||||
const auto& op_type = node.OpType();
|
||||
if (op_type == "DequantizeLinear")
|
||||
return QLinearOpType::DequantizeLinear;
|
||||
else if (op_type == "QuantizeLinear")
|
||||
return QLinearOpType::QuantizeLinear;
|
||||
else if (op_type == "QLinearConv")
|
||||
return QLinearOpType::QLinearConv;
|
||||
else if (op_type == "QLinearMatMul")
|
||||
return QLinearOpType::QLinearMatMul;
|
||||
else if (op_type == "QLinearAdd")
|
||||
return QLinearOpType::QLinearAdd;
|
||||
else if (op_type == "QLinearSigmoid")
|
||||
return QLinearOpType::QLinearSigmoid;
|
||||
else if (op_type == "QLinearAveragePool")
|
||||
return QLinearOpType::QLinearAveragePool;
|
||||
else if (op_type == "QLinearMul")
|
||||
return QLinearOpType::QLinearMul;
|
||||
else if (op_type == "QLinearReduceMean")
|
||||
return QLinearOpType::QLinearReduceMean;
|
||||
else if (op_type == "QLinearConcat")
|
||||
return QLinearOpType::QLinearConcat;
|
||||
else if (op_type == "QLinearGlobalAveragePool")
|
||||
return QLinearOpType::QLinearGlobalAveragePool;
|
||||
else if (op_type == "QLinearLeakyRelu")
|
||||
return QLinearOpType::QLinearLeakyRelu;
|
||||
|
||||
return QLinearOpType::Unknown;
|
||||
}
|
||||
|
||||
// Ops have 1 input
|
||||
bool IsUnaryQLinearOp(QLinearOpType type) {
|
||||
return type == QLinearOpType::QLinearSigmoid ||
|
||||
type == QLinearOpType::QLinearAveragePool ||
|
||||
type == QLinearOpType::QLinearGlobalAveragePool ||
|
||||
type == QLinearOpType::QLinearLeakyRelu ||
|
||||
type == QLinearOpType::QLinearReduceMean;
|
||||
}
|
||||
|
||||
// Ops have 2 inputs
|
||||
bool IsBinaryQLinearOp(QLinearOpType type) {
|
||||
return type == QLinearOpType::QLinearConv ||
|
||||
type == QLinearOpType::QLinearMatMul ||
|
||||
type == QLinearOpType::QLinearAdd ||
|
||||
type == QLinearOpType::QLinearMul;
|
||||
}
|
||||
|
||||
// Ops have 1 or more inputs
|
||||
bool IsVariadicQLinearOp(QLinearOpType type) {
|
||||
return type == QLinearOpType::QLinearConcat;
|
||||
}
|
||||
|
||||
const std::vector<const Node*> GetQDQIONodes(const GraphViewer& graph_viewer,
|
||||
const QDQ::NodeGroup& node_group, bool is_input) {
|
||||
std::vector<const Node*> io_nodes;
|
||||
const auto& src_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes;
|
||||
io_nodes.reserve(src_nodes.size());
|
||||
for (const auto& node_idx : src_nodes) {
|
||||
io_nodes.push_back(graph_viewer.GetNode(node_idx));
|
||||
}
|
||||
return io_nodes;
|
||||
}
|
||||
|
||||
// Get the input or output NodeUnitIODef(s) for the given QDQ NodeGroup
|
||||
std::vector<NodeUnitIODef> GetQDQIODefs(const Node& target_node, const QDQ::NodeGroup& node_group,
|
||||
bool is_input) {
|
||||
const auto& dq_or_q_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes;
|
||||
const auto target_node_io_defs = is_input ? target_node.InputDefs() : target_node.OutputDefs();
|
||||
const size_t target_node_io_defs_size = target_node_io_defs.size();
|
||||
|
||||
// Find all the quantized IO defs and indices (for the input to the target node)
|
||||
std::unordered_map<size_t, NodeUnitIODef> quantized_io_defs;
|
||||
quantized_io_defs.reserve(target_node_io_defs_size);
|
||||
|
||||
auto cur = is_input ? target_node.InputEdgesBegin() : target_node.OutputEdgesBegin();
|
||||
auto end = is_input ? target_node.InputEdgesEnd() : target_node.OutputEdgesEnd();
|
||||
for (; cur != end; ++cur) {
|
||||
const Node& node = cur->GetNode();
|
||||
|
||||
// If we can find the node index in the dq or q nodes, then this is a quantize node (can be DQ or Q depends on is_input)
|
||||
if (std::find(dq_or_q_nodes.cbegin(), dq_or_q_nodes.cend(), node.Index()) != dq_or_q_nodes.cend()) {
|
||||
const auto node_inputs = node.InputDefs();
|
||||
// quantization scale and zp are always the input[1, 2]
|
||||
NodeUnitIODef::QuantParam quant_param{
|
||||
*node_inputs[1],
|
||||
node_inputs.size() == 3 ? node_inputs[2] : nullptr};
|
||||
if (is_input) {
|
||||
// DQ is input to the target node, use the DstArgIndex
|
||||
auto idx = cur->GetDstArgIndex();
|
||||
// This is a DQ node, we are using x, x_scale, x_zp (input[0, 1, 2])
|
||||
quantized_io_defs.insert({idx, NodeUnitIODef{*node_inputs[0], quant_param}});
|
||||
} else {
|
||||
// Q is output of the target node, use the SrcArgIndex
|
||||
auto idx = cur->GetSrcArgIndex();
|
||||
// This is a Q node, we are using y (output[0]), y_scale, y_zp (input[1, 2])
|
||||
const auto node_outputs = node.OutputDefs();
|
||||
quantized_io_defs.insert({idx, NodeUnitIODef{*node_outputs[0], quant_param}});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Construct the IODefs for this QDQ NodeGroup
|
||||
std::vector<NodeUnitIODef> io_defs;
|
||||
io_defs.reserve(target_node_io_defs_size);
|
||||
for (size_t i = 0; i < target_node_io_defs_size; i++) {
|
||||
// If we can find the NodeUnitIODef for this index, this is a quantized input
|
||||
if (quantized_io_defs.find(i) != quantized_io_defs.cend()) {
|
||||
io_defs.push_back(std::move(quantized_io_defs.at(i)));
|
||||
} else {
|
||||
// This is a regular input
|
||||
io_defs.push_back({*target_node_io_defs[i], std::nullopt});
|
||||
}
|
||||
}
|
||||
|
||||
return io_defs;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
NodeUnit::NodeUnit(const Node& node)
|
||||
: target_node_(node),
|
||||
type_(Type::SingleNode) {
|
||||
InitForSingleNode();
|
||||
}
|
||||
|
||||
NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group)
|
||||
: q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)},
|
||||
dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)},
|
||||
target_node_(*graph_viewer.GetNode(node_group.target_node)),
|
||||
type_(Type::QDQGroup),
|
||||
inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)},
|
||||
outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} {
|
||||
ORT_THROW_IF_ERROR(QDQ::ValidateNodeGroupDQNodes(graph_viewer, target_node_, dq_nodes_));
|
||||
}
|
||||
|
||||
const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); }
|
||||
const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpType(); }
|
||||
const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); }
|
||||
int NodeUnit::SinceVersion() const noexcept { return target_node_.SinceVersion(); }
|
||||
NodeIndex NodeUnit::Index() const noexcept { return target_node_.Index(); }
|
||||
const Path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); }
|
||||
ProviderType NodeUnit::GetExecutionProviderType() const noexcept { return target_node_.GetExecutionProviderType(); }
|
||||
|
||||
void NodeUnit::InitForSingleNode() {
|
||||
const auto& input_defs = target_node_.InputDefs();
|
||||
const auto& output_defs = target_node_.OutputDefs();
|
||||
auto qlinear_type = GetQLinearOpType(target_node_);
|
||||
if (qlinear_type == QLinearOpType::Unknown ||
|
||||
IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support
|
||||
// Not a Qlinear op, add all inputs / outputs
|
||||
auto add_all_io = [](std::vector<NodeUnitIODef>& defs,
|
||||
const ConstPointerContainer<std::vector<NodeArg*>>& node_defs) {
|
||||
defs.reserve(node_defs.size());
|
||||
|
||||
for (const auto def : node_defs) {
|
||||
defs.push_back(NodeUnitIODef{*def, std::nullopt});
|
||||
}
|
||||
};
|
||||
add_all_io(inputs_, input_defs);
|
||||
add_all_io(outputs_, output_defs);
|
||||
} else if (IsUnaryQLinearOp(qlinear_type)) {
|
||||
// Unary QLinear Op has 5 inputs
|
||||
// x, x_scale, x_zp, y_scale, y_zp (optional)
|
||||
inputs_.push_back(NodeUnitIODef{
|
||||
*input_defs[0],
|
||||
NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}});
|
||||
|
||||
outputs_.push_back(NodeUnitIODef{
|
||||
*output_defs[0],
|
||||
NodeUnitIODef::QuantParam{*input_defs[3],
|
||||
input_defs.size() > 4
|
||||
? input_defs[4]
|
||||
: nullptr}});
|
||||
} else if (IsBinaryQLinearOp(qlinear_type)) {
|
||||
// Binary QLinear Op has 9 inputs
|
||||
// x1, x1_scale, x1_zp, x2/w, x2_scale, x2_zp, y_scale , y_zp, B
|
||||
inputs_.push_back(NodeUnitIODef{
|
||||
*input_defs[0],
|
||||
NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}});
|
||||
inputs_.push_back(NodeUnitIODef{
|
||||
*input_defs[3],
|
||||
NodeUnitIODef::QuantParam{*input_defs[4], input_defs[5]}});
|
||||
|
||||
if (input_defs.size() == 9) { // has Bias
|
||||
inputs_.push_back(NodeUnitIODef{
|
||||
*input_defs[8],
|
||||
std::nullopt}); // for Bias the scale and zp are optional
|
||||
}
|
||||
|
||||
outputs_.push_back(NodeUnitIODef{
|
||||
*output_defs[0],
|
||||
NodeUnitIODef::QuantParam{*input_defs[6], input_defs[7]}});
|
||||
} else if (qlinear_type == QLinearOpType::DequantizeLinear) {
|
||||
// DequantizeLinear has 3 inputs
|
||||
// x, x_scale, x_zp
|
||||
// output is not quantized
|
||||
inputs_.push_back(NodeUnitIODef{
|
||||
*input_defs[0],
|
||||
NodeUnitIODef::QuantParam{*input_defs[1],
|
||||
input_defs.size() == 3
|
||||
? input_defs[2]
|
||||
: nullptr}});
|
||||
outputs_.push_back(NodeUnitIODef{*output_defs[0], std::nullopt});
|
||||
} else if (qlinear_type == QLinearOpType::QuantizeLinear) {
|
||||
// QuantizeLinear the input is not quantized and has 3 inputs
|
||||
// x, y_scale, y_zp (optional)
|
||||
// The output is quantized
|
||||
inputs_.push_back(NodeUnitIODef{*input_defs[0], std::nullopt});
|
||||
outputs_.push_back(NodeUnitIODef{
|
||||
*output_defs[0],
|
||||
NodeUnitIODef::QuantParam{*input_defs[1],
|
||||
input_defs.size() == 3
|
||||
? input_defs[2]
|
||||
: nullptr}});
|
||||
} else {
|
||||
ORT_THROW("The QLinear op [", static_cast<uint8_t>(qlinear_type), "] is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
Node::EdgeConstIterator NodeUnit::OutputEdgesBegin(size_t index) const {
|
||||
if (type_ == Type::SingleNode) {
|
||||
ORT_ENFORCE(index == 0, "invalid output node index");
|
||||
return target_node_.OutputEdgesBegin();
|
||||
} else {
|
||||
ORT_ENFORCE(index < q_nodes_.size(), "invalid output node index");
|
||||
return q_nodes_[index]->OutputEdgesBegin();
|
||||
}
|
||||
}
|
||||
|
||||
Node::EdgeConstIterator NodeUnit::OutputEdgesEnd(size_t index) const {
|
||||
if (type_ == Type::SingleNode) {
|
||||
ORT_ENFORCE(index == 0, "invalid output node index");
|
||||
return target_node_.OutputEdgesEnd();
|
||||
} else {
|
||||
ORT_ENFORCE(index < q_nodes_.size(), "invalid output node index");
|
||||
return q_nodes_[index]->OutputEdgesEnd();
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const Node*> NodeUnit::GetAllNodesInGroup() const noexcept {
|
||||
std::vector<const Node*> all_nodes = dq_nodes_;
|
||||
all_nodes.push_back(&target_node_);
|
||||
all_nodes.insert(all_nodes.end(), q_nodes_.begin(), q_nodes_.end());
|
||||
return all_nodes;
|
||||
}
|
||||
|
||||
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
|
||||
GetAllNodeUnits(const GraphViewer& graph_viewer) {
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
|
||||
const auto add_node_unit_to_map = [&](const std::vector<NodeIndex>& node_indices, const NodeUnit* node_unit) {
|
||||
for (const auto& node_idx : node_indices) {
|
||||
const auto* node = graph_viewer.GetNode(node_idx);
|
||||
node_unit_map.insert({node, node_unit});
|
||||
}
|
||||
};
|
||||
|
||||
// Get QDQ NodeUnits first
|
||||
QDQ::SelectorManager selector_mgr;
|
||||
const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer);
|
||||
|
||||
for (const auto& qdq_selection : qdq_selections) {
|
||||
auto qdq_unit = std::make_unique<NodeUnit>(graph_viewer, qdq_selection);
|
||||
|
||||
// Fill the node to node_unit map for all nodes in the QDQ Group
|
||||
add_node_unit_to_map(qdq_selection.dq_nodes, qdq_unit.get());
|
||||
add_node_unit_to_map(qdq_selection.q_nodes, qdq_unit.get());
|
||||
add_node_unit_to_map({qdq_selection.target_node}, qdq_unit.get());
|
||||
|
||||
node_unit_holder.push_back(std::move(qdq_unit));
|
||||
}
|
||||
|
||||
// Get the left over SingleNode NodeUnits
|
||||
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
|
||||
for (const auto node_idx : node_indices) {
|
||||
const auto* node(graph_viewer.GetNode(node_idx));
|
||||
|
||||
// This is already part of a QDQ NodeUnit
|
||||
if (node_unit_map.find(node) != node_unit_map.cend())
|
||||
continue;
|
||||
|
||||
auto node_unit = std::make_unique<NodeUnit>(*node);
|
||||
node_unit_map[node] = node_unit.get();
|
||||
node_unit_holder.push_back(std::move(node_unit));
|
||||
}
|
||||
|
||||
return std::make_pair(std::move(node_unit_holder), std::move(node_unit_map));
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -4,12 +4,12 @@
|
|||
|
||||
#include "utils.h"
|
||||
|
||||
#include <core/common/safeint.h>
|
||||
#include <core/framework/tensorprotoutils.h>
|
||||
#include <core/graph/graph.h>
|
||||
#include <core/providers/common.h>
|
||||
#include "core/providers/shared/node_unit/node_unit.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/graph/graph.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/providers/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "utils.h"
|
||||
#include "core/providers/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace utils {
|
||||
|
|
@ -23,6 +23,5 @@ common::Status OutputOptionalWithoutDataHelper(const ONNX_NAMESPACE::TypeProto&
|
|||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace utils
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -6,12 +6,12 @@
|
|||
#include <unordered_map>
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/framework/op_node_proto_helper.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cpu/nn/pool_attributes.h"
|
||||
#include "core/providers/shared/node_unit/node_unit.h"
|
||||
#include "core/providers/xnnpack/detail/utils.h"
|
||||
|
||||
// each operator provides a helper to check if supported
|
||||
|
|
|
|||
|
|
@ -6,15 +6,15 @@
|
|||
#include <vector>
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/graph/indexed_sub_graph.h"
|
||||
#include "core/graph/node_attr_utils.h"
|
||||
|
||||
#include "core/providers/shared/node_unit/node_unit.h"
|
||||
#include "onnx/defs/attr_proto_util.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
|
||||
#include "onnx/defs/attr_proto_util.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace xnnpack {
|
||||
|
||||
|
|
|
|||
|
|
@ -10,10 +10,10 @@
|
|||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/graph/indexed_sub_graph.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/shared/node_unit/node_unit.h"
|
||||
|
||||
#include "xnnpack.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -6,17 +6,17 @@
|
|||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
#include "core/graph/function_utils.h"
|
||||
#include "xnnpack_execution_provider.h"
|
||||
#include "detail/utils.h"
|
||||
#include "detail/node_support_checker.h"
|
||||
|
||||
#include "core/framework/compute_capability.h"
|
||||
#include "core/framework/kernel_registry.h"
|
||||
#include "core/providers/shared/node_unit/node_unit.h"
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/graph/function_utils.h"
|
||||
#include "core/session/onnxruntime_session_options_config_keys.h"
|
||||
|
||||
#include "xnnpack_init.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
|
||||
#include "core/providers/xnnpack/xnnpack_execution_provider.h"
|
||||
#include "core/providers/xnnpack/detail/utils.h"
|
||||
#include "core/providers/xnnpack/detail/node_support_checker.h"
|
||||
#include "core/providers/xnnpack/xnnpack_init.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -268,7 +268,7 @@ std::vector<std::unique_ptr<ComputeCapability>> XnnpackExecutionProvider::GetCap
|
|||
// Get all the NodeUnits in the GraphViewer so we can check if something is in a QDQ node group
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph);
|
||||
|
||||
// This holds the result of whether a NodeUnit is supported or not,
|
||||
// to prevent nodes in a NodeUnit being checked for multiple times
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "test_fp16.h"
|
||||
#include <iomanip>
|
||||
|
||||
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/compute_capability.h"
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
|
@ -9,7 +10,6 @@
|
|||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/providers/partitioning_utils.h"
|
||||
#include "core/session/onnxruntime_session_options_config_keys.h"
|
||||
#include "core/session/environment.h"
|
||||
|
|
@ -30,10 +30,6 @@
|
|||
#pragma warning(disable : 4127)
|
||||
#endif // #if defined(_MSC_VER)
|
||||
|
||||
#ifdef USE_NNAPI
|
||||
#include "core/providers/shared/node_unit/node_unit.h"
|
||||
#endif // #ifdef USE_NNAPI
|
||||
|
||||
struct QDQOpKeys {
|
||||
const char* quantize_linear;
|
||||
const char* dequantize_linear;
|
||||
|
|
@ -3243,14 +3239,14 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
|
|||
ASSERT_EQ(std::vector<NodeIndex>({4}), qdq_group.q_nodes);
|
||||
}
|
||||
|
||||
// The function GetAllNodeUnits is enabled for NNAPI EP only for now
|
||||
#ifdef USE_NNAPI
|
||||
// The function GetAllNodeUnits is used by NNAPI, XNNPACK and QNN
|
||||
#if defined(USE_NNAPI) || defined(USE_QNN) || defined(USE_XNNPACK)
|
||||
{
|
||||
// Get all the NodeUnits in the graph_viewer
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
|
||||
std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(whole_graph_viewer);
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer);
|
||||
|
||||
// We should get a single QDQ Node unit in the result
|
||||
ASSERT_EQ(1, node_unit_holder.size());
|
||||
|
|
@ -3288,7 +3284,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
|
|||
verify_io_def(qdq_node_unit.Inputs()[2], *whole_graph_viewer.GetNode(2)); // DQ_bias
|
||||
verify_io_def(qdq_node_unit.Outputs()[0], *whole_graph_viewer.GetNode(4)); // Q_output
|
||||
}
|
||||
#endif // #ifdef USE_NNAPI
|
||||
#endif // defined(USE_NNAPI) || defined(USE_QNN) || defined(USE_XNNPACK)
|
||||
|
||||
// Create a graph viewer covers part of the graph
|
||||
// Make sure the qdq conv selector will fail for the partial graph
|
||||
|
|
|
|||
|
|
@ -220,6 +220,7 @@ InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer&
|
|||
auto compile_capabilities = utils::CreateSupportedPartitions(graph_viewer, supported_compiled_nodes, stop_ops_,
|
||||
generate_metadef_name, ep_name_,
|
||||
onnxruntime::utils::kInternalTestingExecutionProvider,
|
||||
/*QDQ NodeUnit map*/ nullptr,
|
||||
debug_output_);
|
||||
|
||||
if (!static_capabilities.empty()) {
|
||||
|
|
|
|||
174
onnxruntime/test/providers/partitioning_utils_test.cc
Normal file
174
onnxruntime/test/providers/partitioning_utils_test.cc
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/framework/node_unit.h"
|
||||
#include "core/framework/compute_capability.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
|
||||
#include "core/providers/partitioning_utils.h"
|
||||
|
||||
#include "test/optimizer/graph_transform_test_builder.h"
|
||||
#include "test/optimizer/qdq_test_utils.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
#include "test/util/include/test_utils.h"
|
||||
#include "test/util/include/test/test_environment.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
// Test handling of a DQ node that is connected to an initializer at the start of the graph, but not used
|
||||
// in a QDQ node group until after an unsupported node in the graph. If we do not process QDQ node units
|
||||
// correctly this DQ will incorrectly be in the first partition, with the rest of the QDQ node group in
|
||||
// the second partition.
|
||||
TEST(PartitioningUtilsTest, TestQDQHandling) {
|
||||
constexpr const ORTCHAR_T* model_uri = ORT_TSTR("testdata/ort_github_issue_19590.onnx");
|
||||
auto& logger = DefaultLoggingManager().DefaultLogger();
|
||||
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, logger));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
GraphViewer graph_viewer = GraphViewer(graph);
|
||||
|
||||
// we want everything but the Cast in the test model to be supported
|
||||
const auto is_node_supported = [&](const Node& node) -> bool {
|
||||
return node.OpType() != "Cast";
|
||||
};
|
||||
|
||||
const auto on_group_closed = [&](const std::vector<const Node*>& /*group*/) -> bool {
|
||||
return true;
|
||||
};
|
||||
|
||||
const auto gen_metadef_name = [&]() {
|
||||
static int metadef_id = 0;
|
||||
return "TestMetaDef_" + std::to_string(metadef_id++);
|
||||
};
|
||||
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
|
||||
|
||||
auto result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed,
|
||||
gen_metadef_name, "TEST", kCpuExecutionProvider, &node_unit_map,
|
||||
true);
|
||||
|
||||
// we should have 2 supported partitions, split by the Cast node.
|
||||
// the first should have the Mul and NOT the DQ for the initializer if everything worked correctly.
|
||||
ASSERT_EQ(result.size(), size_t(2)) << "Expected 2 partitions";
|
||||
ASSERT_EQ(result[0]->sub_graph->nodes.size(), size_t(1)) << "First partition should only have the Mul and not a DQ";
|
||||
ASSERT_EQ(result[1]->sub_graph->nodes.size(), size_t(5)); // everything else except the unsupported Cast
|
||||
}
|
||||
|
||||
/// Check that CreateSupportedPartitions processes all nodes without error.
|
||||
static void CheckAllNodesProcessed(const std::function<void(ModelTestBuilder&)>& build_model) {
|
||||
auto& logger = DefaultLoggingManager().DefaultLogger();
|
||||
const std::unordered_map<std::string, int> domain_to_version = {{"", 15}};
|
||||
|
||||
Model model("PartitioningUtils_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
domain_to_version, {}, logger);
|
||||
|
||||
Graph& graph = model.MainGraph();
|
||||
ModelTestBuilder helper(graph);
|
||||
build_model(helper);
|
||||
helper.SetGraphOutputs();
|
||||
ASSERT_STATUS_OK(model.MainGraph().Resolve());
|
||||
|
||||
GraphViewer graph_viewer = GraphViewer(graph);
|
||||
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
|
||||
|
||||
const auto is_node_supported = [&](const Node& /*node*/) -> bool {
|
||||
return true;
|
||||
};
|
||||
|
||||
const auto on_group_closed = [&](const std::vector<const Node*>& /*group*/) -> bool {
|
||||
return true;
|
||||
};
|
||||
|
||||
const auto gen_metadef_name = [&]() {
|
||||
static int metadef_id = 0;
|
||||
return "TestMetaDef_" + std::to_string(metadef_id++);
|
||||
};
|
||||
|
||||
auto result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed,
|
||||
gen_metadef_name, "TEST", kCpuExecutionProvider, &node_unit_map,
|
||||
true);
|
||||
|
||||
// the 'real' test is that CreateSupportedPartitions doesn't throw due to a mismatch with expected vs processed nodes
|
||||
// as all ops are supported there should only ever be 1 partition
|
||||
ASSERT_EQ(result.size(), size_t(1)) << "Expected 1 partition";
|
||||
}
|
||||
|
||||
TEST(PartitioningUtilsTest, TestHandlingQDQNodeUnitWithNoQNodes) {
|
||||
// build graph with QDQ node unit for logical operator (Equal) that has no Q node and a downstream node (Cast).
|
||||
auto build_model = [](ModelTestBuilder& builder) {
|
||||
constexpr uint8_t zero_point = 0;
|
||||
constexpr float qdq_scale = 0.0038f;
|
||||
const std::vector<int64_t> input_shape = {1, 3, 8, 8};
|
||||
|
||||
auto* input0 = builder.MakeInput<float>(input_shape, -1.0f, 1.0f);
|
||||
auto* input1 = builder.MakeInput<float>(input_shape, -1.0f, 1.0f);
|
||||
auto* output = builder.MakeOutput();
|
||||
|
||||
// input -> Q -> DQ -> Op
|
||||
auto* qdq0_output = AddQDQNodePair<uint8_t>(builder, input0, qdq_scale, zero_point);
|
||||
auto* qdq1_output = AddQDQNodePair<uint8_t>(builder, input1, qdq_scale, zero_point);
|
||||
|
||||
// Equal ->
|
||||
auto* equal_output = builder.MakeIntermediate();
|
||||
builder.AddNode("Equal", {qdq0_output, qdq1_output}, {equal_output});
|
||||
|
||||
// -> Cast -> output
|
||||
Node& cast_node = builder.AddNode("Cast", {equal_output}, {output});
|
||||
cast_node.AddAttribute("to",
|
||||
static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT));
|
||||
};
|
||||
|
||||
CheckAllNodesProcessed(build_model);
|
||||
}
|
||||
|
||||
// TopK produces 2 outputs, one of which is used in a QDQ node group (Q of values output)
|
||||
// and the other (indices output) is not. A downstream node consuming the indices output has an edge from the target
|
||||
// node and not a Q node.
|
||||
// To process this correctly, the QDQ NodeUnit must return output edges for both the Q node/s of the values output,
|
||||
// and the downstream node (Cast in this case) of the indices output.
|
||||
TEST(PartitioningUtilsTest, TestQDQNodeGroupWithOutputFromTargetNode) {
|
||||
const auto build_model = [](ModelTestBuilder& builder) {
|
||||
constexpr uint8_t zero_point = 0;
|
||||
constexpr float qdq_scale = 0.0038f;
|
||||
const std::vector<int64_t> input_shape = {1, 3, 8, 8};
|
||||
|
||||
auto* input0 = builder.MakeInput<float>(input_shape, -1.0f, 1.0f);
|
||||
|
||||
// input -> Q -> DQ ->
|
||||
auto* qdq0_output = AddQDQNodePair<uint8_t>(builder, input0, qdq_scale, zero_point);
|
||||
|
||||
// K input
|
||||
NodeArg* k_input = builder.MakeInput<int64_t>({1}, {10});
|
||||
|
||||
// TopK op
|
||||
NodeArg* values_output = builder.MakeIntermediate();
|
||||
NodeArg* indices_output = builder.MakeIntermediate();
|
||||
builder.AddNode("TopK", {qdq0_output, k_input}, {values_output, indices_output});
|
||||
|
||||
// values -> Q -> DQ -> graph output
|
||||
AddQDQNodePairWithOutputAsGraphOutput<uint8_t>(builder, values_output, qdq_scale, zero_point);
|
||||
|
||||
// indices -> Cast -> graph output
|
||||
auto* i_output = builder.MakeOutput();
|
||||
Node& cast_node = builder.AddNode("Cast", {indices_output}, {i_output});
|
||||
const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32;
|
||||
cast_node.AddAttribute("to", static_cast<int64_t>(dst_type));
|
||||
};
|
||||
|
||||
CheckAllNodesProcessed(build_model);
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
BIN
onnxruntime/test/testdata/ort_github_issue_19590.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/ort_github_issue_19590.onnx
vendored
Normal file
Binary file not shown.
77
onnxruntime/test/testdata/ort_github_issue_19590.py
vendored
Normal file
77
onnxruntime/test/testdata/ort_github_issue_19590.py
vendored
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import onnx
|
||||
from onnx import TensorProto, helper
|
||||
|
||||
# graph with a QDQ MatMul node unit where one input is and initializer -> DQ and the other is on a path that
|
||||
# contains a supported node followed by an unsupported node followed by the DQ -> MatMul.
|
||||
# The DQ of the initializer is prior to the unsupported node. If the partitioning utils do not process the QDQ node
|
||||
# unit together, the DQ for the initializer and the first supported node will be in the first partition, which
|
||||
# incorrectly breaks up the QDQ node unit.
|
||||
graph_proto = helper.make_graph(
|
||||
[
|
||||
# DQ of initializer for MatMul B input
|
||||
helper.make_node(
|
||||
"DequantizeLinear",
|
||||
inputs=["matmul_b_uint8", "scale0"],
|
||||
outputs=["dq_matmul_b"],
|
||||
name="dq_matmul_b",
|
||||
),
|
||||
# Treat as supported
|
||||
helper.make_node(
|
||||
"Mul",
|
||||
inputs=["input:0", "scale_input"],
|
||||
outputs=["mul:0"],
|
||||
name="mul0",
|
||||
),
|
||||
# Treat as unsupported
|
||||
helper.make_node("Cast", inputs=["mul:0"], outputs=["mul_uint8"], name="cast0", to=2),
|
||||
# DQ of MatMul A input
|
||||
helper.make_node(
|
||||
"DequantizeLinear",
|
||||
inputs=["mul_uint8", "scale1"],
|
||||
outputs=["dq_matmul_a"],
|
||||
name="dq_matmul_a",
|
||||
),
|
||||
# MatMul
|
||||
helper.make_node(
|
||||
"MatMul",
|
||||
inputs=[
|
||||
"dq_matmul_a",
|
||||
"dq_matmul_b",
|
||||
],
|
||||
outputs=["matmul_ab"],
|
||||
name="matmul_ab",
|
||||
),
|
||||
# Q
|
||||
helper.make_node(
|
||||
"QuantizeLinear",
|
||||
inputs=["matmul_ab", "scale2"],
|
||||
outputs=["q_matmul_ab"],
|
||||
name="q_matmul_ab",
|
||||
),
|
||||
# DQ for model output
|
||||
helper.make_node(
|
||||
"DequantizeLinear",
|
||||
inputs=["q_matmul_ab", "scale2"],
|
||||
outputs=["out:0"],
|
||||
name="dq_graph_output",
|
||||
),
|
||||
],
|
||||
"Main_graph",
|
||||
[
|
||||
helper.make_tensor_value_info("input:0", TensorProto.FLOAT, [3, 2]),
|
||||
],
|
||||
[
|
||||
helper.make_tensor_value_info("out:0", TensorProto.FLOAT, [3, 2]),
|
||||
],
|
||||
[
|
||||
helper.make_tensor("scale0", TensorProto.FLOAT, [1], [20.0]),
|
||||
helper.make_tensor("scale1", TensorProto.FLOAT, [1], [30.0]),
|
||||
helper.make_tensor("scale2", TensorProto.FLOAT, [1], [40.0]),
|
||||
helper.make_tensor("matmul_b_uint8", TensorProto.UINT8, [2, 2], [1, 2, 3, 4]),
|
||||
helper.make_tensor("scale_input", TensorProto.FLOAT, [2], [3.0, 4.0]),
|
||||
],
|
||||
)
|
||||
|
||||
model = helper.make_model(graph_proto)
|
||||
onnx.checker.check_model(model, True)
|
||||
onnx.save(model, "ort_github_issue_19590.onnx")
|
||||
Loading…
Reference in a new issue