mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-23 02:38:28 +00:00
Add QDQ::Selector::Select to use const GraphViewer instead of mutable Graph (#9621)
* Move qdq selector to use const GraphViewer * minor update * Move qdq logic from NodeSelector to QDQ Selectors * Fix build break * Move selector result to NodesToOptimizeIndexes * fix build break * address CR comments * move indexes -> indices * Pass graph_viewer to avoid recreating many times * Update after merge master * update graph viewer remarks * update comments * Add ut for new qdq selector logic * Increase minimal binary size limit * UT minor update * Address CR comments
This commit is contained in:
parent
65590b049c
commit
a70ae24475
17 changed files with 402 additions and 146 deletions
|
|
@ -80,6 +80,10 @@ class GraphViewer {
|
|||
*/
|
||||
const std::vector<const NodeArg*>& GetOutputs() const noexcept;
|
||||
|
||||
/** Returns true if one or more of the Node outputs are Graph outputs.
|
||||
*/
|
||||
bool NodeProducesGraphOutput(const Node& node) const;
|
||||
|
||||
/** Gets all ValueInfo NodeArg instances in the Graph.
|
||||
@remarks NOT filtered using filter_info_.
|
||||
*/
|
||||
|
|
@ -146,6 +150,16 @@ class GraphViewer {
|
|||
*/
|
||||
bool IsConstantInitializer(const std::string& name, bool check_outer_scope) const;
|
||||
|
||||
/** returns the initializer's TensorProto if 'name' is an initializer, is constant and
|
||||
cannot be overridden at runtime. If the initializer is not found or is not constant, a nullptr is returned.
|
||||
@param check_outer_scope If true and the graph is a subgraph,
|
||||
check ancestor graph/s for 'name' if not found in 'graph'.
|
||||
@remarks This function will return the result from GetConstantInitializer of the underlying Graph,
|
||||
if a const initializer is part of the underlying Graph but not part of this GraphViewer,
|
||||
it will still be returned instead of nullptr
|
||||
*/
|
||||
const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const;
|
||||
|
||||
/** Get the Node containing this Graph if IsSubgraph is true. Returns nullptr otherwise. */
|
||||
const Node* ParentNode() const noexcept { return graph_->ParentNode(); }
|
||||
|
||||
|
|
|
|||
|
|
@ -199,6 +199,17 @@ const std::vector<const NodeArg*>& GraphViewer::GetOutputs() const noexcept {
|
|||
: filtered_node_outputs_;
|
||||
}
|
||||
|
||||
bool GraphViewer::NodeProducesGraphOutput(const Node& node) const {
|
||||
const auto& outputs = GetOutputs();
|
||||
auto end_outputs = outputs.cend();
|
||||
for (auto output_def : node.OutputDefs()) {
|
||||
if (std::find(outputs.cbegin(), end_outputs, output_def) != end_outputs) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get graph value infos.
|
||||
const std::unordered_set<const NodeArg*>& GraphViewer::GetValueInfo() const noexcept {
|
||||
return graph_->GetValueInfo();
|
||||
|
|
@ -262,7 +273,12 @@ bool GraphViewer::IsSubgraph() const {
|
|||
}
|
||||
|
||||
bool GraphViewer::IsConstantInitializer(const std::string& name, bool check_outer_scope) const {
|
||||
return graph_->GetConstantInitializer(name, check_outer_scope) != nullptr;
|
||||
return GetConstantInitializer(name, check_outer_scope) != nullptr;
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* GraphViewer::GetConstantInitializer(const std::string& initializer_name,
|
||||
bool check_outer_scope) const {
|
||||
return graph_->GetConstantInitializer(initializer_name, check_outer_scope);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@
|
|||
namespace onnxruntime {
|
||||
|
||||
/** Struct to serialize the node indices in an ORT format model.
|
||||
Use EmptyNodeIndex for nullptr entries in the vectors for missing optional inputs
|
||||
Use kEmptyNodeIndex for nullptr entries in the vectors for missing optional inputs
|
||||
*/
|
||||
struct NodesToOptimizeIndices {
|
||||
/** Index value that represents an empty node.
|
||||
|
|
|
|||
|
|
@ -19,7 +19,11 @@ static bool CanNodePropagate(const Node& node) {
|
|||
}
|
||||
|
||||
static bool TryCancelOutDQQPair(Graph& graph, Node& dq_node, Node& q_node) {
|
||||
if (!QDQ::IsQDQPairSupported(graph, q_node, dq_node)) {
|
||||
auto get_const_initializer = [&graph](const std::string& initializer_name) {
|
||||
return graph.GetConstantInitializer(initializer_name, true);
|
||||
};
|
||||
|
||||
if (!QDQ::IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph.ModelPath())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -36,7 +40,8 @@ static bool TryCancelOutDQQPair(Graph& graph, Node& dq_node, Node& q_node) {
|
|||
if (dq_input_edge_0) {
|
||||
input_edge_info.first = dq_input_edge_0->GetNode().Index();
|
||||
input_edge_info.second = dq_input_edge_0->GetSrcArgIndex();
|
||||
graph.RemoveEdge(dq_input_edge_0->GetNode().Index(), dq_node.Index(), dq_input_edge_0->GetSrcArgIndex(), dq_input_edge_0->GetDstArgIndex());
|
||||
graph.RemoveEdge(dq_input_edge_0->GetNode().Index(), dq_node.Index(),
|
||||
dq_input_edge_0->GetSrcArgIndex(), dq_input_edge_0->GetDstArgIndex());
|
||||
}
|
||||
|
||||
graph_utils::RemoveNodeOutputEdges(graph, dq_node); // Remove DQ node output edges
|
||||
|
|
@ -45,11 +50,13 @@ static bool TryCancelOutDQQPair(Graph& graph, Node& dq_node, Node& q_node) {
|
|||
graph_utils::RemoveNodeOutputEdges(graph, q_node); // Remove Q node output edges
|
||||
for (auto& output_edge : output_edges) {
|
||||
// set input NodeArg of Q's children to the 1st input of DQ
|
||||
graph.GetNode(output_edge.dst_node)->MutableInputDefs()[output_edge.dst_arg_index] = dq_node.MutableInputDefs()[0];
|
||||
graph.GetNode(output_edge.dst_node)->MutableInputDefs()[output_edge.dst_arg_index] =
|
||||
dq_node.MutableInputDefs()[0];
|
||||
|
||||
// add edge between parent of DQ to children of Q
|
||||
if (input_edge_info.second != -1) {
|
||||
graph.AddEdge(input_edge_info.first, output_edge.dst_node, input_edge_info.second, output_edge.dst_arg_index);
|
||||
graph.AddEdge(input_edge_info.first, output_edge.dst_node,
|
||||
input_edge_info.second, output_edge.dst_arg_index);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,10 @@
|
|||
namespace onnxruntime {
|
||||
namespace QDQ {
|
||||
|
||||
bool IsQDQPairSupported(const Graph& graph, const Node& q_node, const Node& dq_node) {
|
||||
bool IsQDQPairSupported(
|
||||
const Node& q_node, const Node& dq_node,
|
||||
const std::function<const ONNX_NAMESPACE::TensorProto*(const std::string&)>& get_const_initializer,
|
||||
const Path& model_path) {
|
||||
ConstPointerContainer<std::vector<NodeArg*>> dq_input_defs = dq_node.InputDefs();
|
||||
ConstPointerContainer<std::vector<NodeArg*>> q_input_defs = q_node.InputDefs();
|
||||
|
||||
|
|
@ -30,13 +33,13 @@ bool IsQDQPairSupported(const Graph& graph, const Node& q_node, const Node& dq_n
|
|||
|
||||
// if Q/DQ scale and zero point are not constant, return false
|
||||
const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, dq_input_defs[InputIndex::SCALE_ID]->Name());
|
||||
get_const_initializer(dq_input_defs[InputIndex::SCALE_ID]->Name());
|
||||
const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, q_input_defs[InputIndex::SCALE_ID]->Name());
|
||||
get_const_initializer(q_input_defs[InputIndex::SCALE_ID]->Name());
|
||||
const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, dq_input_defs[InputIndex::ZERO_POINT_ID]->Name());
|
||||
get_const_initializer(dq_input_defs[InputIndex::ZERO_POINT_ID]->Name());
|
||||
const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, q_input_defs[InputIndex::ZERO_POINT_ID]->Name());
|
||||
get_const_initializer(q_input_defs[InputIndex::ZERO_POINT_ID]->Name());
|
||||
if (nullptr == q_zp_tensor_proto ||
|
||||
nullptr == dq_zp_tensor_proto ||
|
||||
nullptr == q_scale_tensor_proto ||
|
||||
|
|
@ -45,10 +48,10 @@ bool IsQDQPairSupported(const Graph& graph, const Node& q_node, const Node& dq_n
|
|||
}
|
||||
|
||||
// check Q/DQ have same scale and zero point
|
||||
Initializer q_zp(*q_zp_tensor_proto, graph.ModelPath());
|
||||
Initializer q_scale(*q_scale_tensor_proto, graph.ModelPath());
|
||||
Initializer dq_zp(*dq_zp_tensor_proto, graph.ModelPath());
|
||||
Initializer dq_scale(*dq_scale_tensor_proto, graph.ModelPath());
|
||||
Initializer q_zp(*q_zp_tensor_proto, model_path);
|
||||
Initializer q_scale(*q_scale_tensor_proto, model_path);
|
||||
Initializer dq_zp(*dq_zp_tensor_proto, model_path);
|
||||
Initializer dq_scale(*dq_scale_tensor_proto, model_path);
|
||||
|
||||
return q_zp.data_type() == dq_zp.data_type() &&
|
||||
*q_zp.data<int8_t>() == *dq_zp.data<int8_t>() &&
|
||||
|
|
|
|||
|
|
@ -3,10 +3,17 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
namespace ONNX_NAMESPACE {
|
||||
class TensorProto;
|
||||
}
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
class Graph;
|
||||
class Node;
|
||||
class Path;
|
||||
|
||||
namespace QDQ {
|
||||
|
||||
|
|
@ -24,7 +31,10 @@ enum InputIndex : int {
|
|||
// 1. Q/DQ doesn't have optional input.
|
||||
// 2. scale and zero point is constant scalar
|
||||
// 3. Q and DQ have same scale and zero point
|
||||
bool IsQDQPairSupported(const Graph& graph, const Node& q_node, const Node& dq_node);
|
||||
bool IsQDQPairSupported(
|
||||
const Node& q_node, const Node& dq_node,
|
||||
const std::function<const ONNX_NAMESPACE::TensorProto*(const std::string&)>& get_const_initializer,
|
||||
const Path& model_path);
|
||||
|
||||
} // namespace QDQ
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -21,7 +21,23 @@ int NumActualValues(const Node& node, bool input) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
bool BaseSelector::CheckQDQNodes(const Graph& graph, const Node& node,
|
||||
static std::vector<const Node*> FindQDQNodes(const GraphViewer& graph_viewer, const Node& node, bool find_dq_nodes) {
|
||||
// First get all the upstream (DQ) or downstream (Q) nodes
|
||||
std::vector<const Node*> nodes =
|
||||
find_dq_nodes ? graph_utils::FindParentsByType(node, QDQ::DQOpName)
|
||||
: graph_utils::FindChildrenByType(node, QDQ::QOpName);
|
||||
|
||||
// Remove all the nodes which are not in the graph_viewer
|
||||
nodes.erase(std::remove_if(nodes.begin(), nodes.end(),
|
||||
[&graph_viewer](const Node* _node) {
|
||||
return _node == nullptr || graph_viewer.GetNode(_node->Index()) == nullptr;
|
||||
}),
|
||||
nodes.end());
|
||||
|
||||
return nodes;
|
||||
}
|
||||
|
||||
bool BaseSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes,
|
||||
int num_dq_inputs) const {
|
||||
|
|
@ -31,63 +47,67 @@ bool BaseSelector::CheckQDQNodes(const Graph& graph, const Node& node,
|
|||
|
||||
int num_outputs = NumActualValues(node, false); // number of outputs that exist
|
||||
|
||||
// The input is a Graph Viewer, so cannot use graph_utils or optimizer_utils
|
||||
return num_dq_inputs == gsl::narrow_cast<int>(dq_nodes.size()) &&
|
||||
num_outputs == gsl::narrow_cast<int>(q_nodes.size()) &&
|
||||
optimizer_utils::CheckOutputEdges(graph, node, q_nodes.size());
|
||||
q_nodes.size() == node.GetOutputEdgesCount() &&
|
||||
!graph_viewer.NodeProducesGraphOutput(node);
|
||||
}
|
||||
|
||||
bool BaseSelector::Select(Graph& graph, const Node& node, std::unique_ptr<NodesToOptimize>& selection) const {
|
||||
std::vector<const Node*> dq_nodes = graph_utils::FindParentsByType(node, QDQ::DQOpName);
|
||||
std::vector<const Node*> q_nodes = graph_utils::FindChildrenByType(node, QDQ::QOpName);
|
||||
|
||||
if (!Check(graph, node, dq_nodes, q_nodes)) {
|
||||
return false;
|
||||
std::optional<NodeGroup> BaseSelector::GetQDQSelection(const GraphViewer& graph_viewer, const Node& node) const {
|
||||
std::vector<const Node*> dq_nodes = FindQDQNodes(graph_viewer, node, true);
|
||||
std::vector<const Node*> q_nodes = FindQDQNodes(graph_viewer, node, false);
|
||||
if (!Check(graph_viewer, node, dq_nodes, q_nodes)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto get_mutable_node = [&graph](const Node* node) {
|
||||
// we use the non-const GetNode to convert the const Node* to Node*
|
||||
return graph.GetNode(node->Index());
|
||||
};
|
||||
NodeGroup node_group;
|
||||
node_group.dq_nodes.reserve(dq_nodes.size());
|
||||
node_group.q_nodes.reserve(q_nodes.size());
|
||||
node_group.target_node = node.Index();
|
||||
auto get_node_idx = [&](const Node* n) { return n->Index(); };
|
||||
std::transform(dq_nodes.begin(), dq_nodes.end(), std::back_inserter(node_group.dq_nodes), get_node_idx);
|
||||
std::transform(q_nodes.begin(), q_nodes.end(), std::back_inserter(node_group.q_nodes), get_node_idx);
|
||||
return node_group;
|
||||
}
|
||||
|
||||
NodesToOptimizeBuilder builder;
|
||||
builder.input_nodes.reserve(dq_nodes.size());
|
||||
builder.output_nodes.reserve(q_nodes.size());
|
||||
|
||||
for (const Node* dq_node : dq_nodes) {
|
||||
builder.input_nodes.push_back(dq_node != nullptr ? get_mutable_node(dq_node) : nullptr);
|
||||
std::optional<NodesToOptimizeIndices> BaseSelector::Select(const GraphViewer& graph_viewer, const Node& node) const {
|
||||
const auto qdq_group = GetQDQSelection(graph_viewer, node);
|
||||
if (!qdq_group.has_value()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
builder.target_node = get_mutable_node(&node);
|
||||
|
||||
for (const Node* q_node : q_nodes) {
|
||||
builder.output_nodes.push_back(get_mutable_node(q_node));
|
||||
}
|
||||
NodesToOptimizeIndicesBuilder builder;
|
||||
builder.input_nodes = qdq_group->dq_nodes;
|
||||
builder.output_nodes = qdq_group->q_nodes;
|
||||
builder.target_node = qdq_group->target_node;
|
||||
|
||||
UpdateBuilder(builder);
|
||||
|
||||
selection = builder.Build();
|
||||
|
||||
return true;
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
bool DropDQDNodesSelector::Check(const Graph& graph,
|
||||
bool DropDQDNodesSelector::Check(const GraphViewer& graph_viewer,
|
||||
const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes) const {
|
||||
if (!CheckQDQNodes(graph, node, dq_nodes, q_nodes, 1)) {
|
||||
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const Node& dq_node = *dq_nodes.front();
|
||||
const Node& q_node = *q_nodes.front();
|
||||
|
||||
return IsQDQPairSupported(graph, q_node, dq_node);
|
||||
auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) {
|
||||
return graph_viewer.GetConstantInitializer(initializer_name, true);
|
||||
};
|
||||
|
||||
return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath());
|
||||
}
|
||||
|
||||
bool UnarySelector::Check(const Graph& graph, const Node& node,
|
||||
bool UnarySelector::Check(const GraphViewer& graph_viewer, const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes) const {
|
||||
if (!CheckQDQNodes(graph, node, dq_nodes, q_nodes, 1)) {
|
||||
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -100,11 +120,11 @@ bool UnarySelector::Check(const Graph& graph, const Node& node,
|
|||
(int8_allowed_ && dt_output == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)));
|
||||
}
|
||||
|
||||
bool BinarySelector::Check(const Graph& graph,
|
||||
bool BinarySelector::Check(const GraphViewer& graph_viewer,
|
||||
const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes) const {
|
||||
if (!CheckQDQNodes(graph, node, dq_nodes, q_nodes)) {
|
||||
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -116,11 +136,11 @@ bool BinarySelector::Check(const Graph& graph,
|
|||
dt_input_1 == dt_output;
|
||||
}
|
||||
|
||||
bool VariadicSelector::Check(const Graph& graph,
|
||||
bool VariadicSelector::Check(const GraphViewer& graph_viewer,
|
||||
const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes) const {
|
||||
if (!CheckQDQNodes(graph, node, dq_nodes, q_nodes)) {
|
||||
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -136,15 +156,15 @@ bool VariadicSelector::Check(const Graph& graph,
|
|||
return dt_input == dt_output;
|
||||
}
|
||||
|
||||
void VariadicSelector::UpdateBuilder(NodesToOptimizeBuilder& builder) const {
|
||||
void VariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const {
|
||||
builder.num_input_defs = 1; // set to 1 as the first input is variadic
|
||||
}
|
||||
|
||||
bool ConvSelector::Check(const Graph& graph,
|
||||
bool ConvSelector::Check(const GraphViewer& graph_viewer,
|
||||
const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes) const {
|
||||
if (!CheckQDQNodes(graph, node, dq_nodes, q_nodes)) {
|
||||
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -164,11 +184,11 @@ bool ConvSelector::Check(const Graph& graph,
|
|||
return dt_bias == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32;
|
||||
}
|
||||
|
||||
void ConvSelector::UpdateBuilder(NodesToOptimizeBuilder& builder) const {
|
||||
builder.input_nodes.resize(3); // add nullptr for bias if missing
|
||||
void ConvSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const {
|
||||
builder.input_nodes.resize(3, NodesToOptimizeIndices::kEmptyNodeIndex);
|
||||
}
|
||||
|
||||
bool MatMulSelector::Check(const Graph& graph,
|
||||
bool MatMulSelector::Check(const GraphViewer& graph_viewer,
|
||||
const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes) const {
|
||||
|
|
@ -181,7 +201,7 @@ bool MatMulSelector::Check(const Graph& graph,
|
|||
|
||||
if (qlinear) {
|
||||
// QLinearMatMul
|
||||
if (!CheckQDQNodes(graph, node, dq_nodes, q_nodes)) {
|
||||
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,39 +12,51 @@ class Graph;
|
|||
class Node;
|
||||
|
||||
namespace QDQ {
|
||||
|
||||
// Struct to represent a DQ->Op->Q node group
|
||||
struct NodeGroup {
|
||||
std::vector<NodeIndex> dq_nodes;
|
||||
std::vector<NodeIndex> q_nodes;
|
||||
NodeIndex target_node;
|
||||
};
|
||||
|
||||
// Base QDQ checker. Finds and provides the DQ and Q nodes to the operator specific checkers, as the QDQ optimizations
|
||||
// always involve those nodes.
|
||||
class BaseSelector : public NodeSelector {
|
||||
public:
|
||||
bool Select(Graph& graph, const Node& node, std::unique_ptr<NodesToOptimize>& selection) const override;
|
||||
std::optional<NodesToOptimizeIndices> Select(const GraphViewer& graph_viewer, const Node& node) const override;
|
||||
|
||||
// This is a QDQ Selectors only function, will return QDQ::NodeGroup instead of NodesToOptimizeIndices
|
||||
// Can be used in QDQ handling in EPs such as NNAPI
|
||||
std::optional<NodeGroup> GetQDQSelection(const GraphViewer& graph_viewer, const Node& node) const;
|
||||
|
||||
protected:
|
||||
BaseSelector() = default;
|
||||
|
||||
// base check that we have the expected number of QDQ inputs/outputs, and `node` isn't producing a graph output.
|
||||
// num_dq_inputs defaults to the number of inputs `node` has if not explicitly specified
|
||||
bool CheckQDQNodes(const Graph& graph, const Node& node,
|
||||
bool CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes,
|
||||
int num_dq_inputs = -1) const;
|
||||
|
||||
private:
|
||||
// derived classes should implement this check
|
||||
bool virtual Check(const Graph& graph, const Node& node,
|
||||
bool virtual Check(const GraphViewer& graph_viewer, const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes) const = 0;
|
||||
|
||||
// override if you need to adjust the values in NodesToOptimize.
|
||||
// e.g. add entries for missing optional DQ inputs or set num_inputs to handle variadic inputs
|
||||
// Called post-Check, if Check returned `true`
|
||||
virtual void UpdateBuilder(NodesToOptimizeBuilder&) const {}
|
||||
virtual void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const {}
|
||||
};
|
||||
|
||||
// Single DQ -> node that does not change data -> Q.
|
||||
// Zero point and scale are constant scalars and must match
|
||||
class DropDQDNodesSelector : public BaseSelector {
|
||||
private:
|
||||
bool Check(const Graph& graph, const Node& node,
|
||||
bool Check(const GraphViewer& graph_viewer, const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes) const override;
|
||||
};
|
||||
|
|
@ -55,7 +67,7 @@ class UnarySelector : public BaseSelector {
|
|||
UnarySelector(bool int8_allowed = false) : int8_allowed_{int8_allowed} {}
|
||||
|
||||
private:
|
||||
bool Check(const Graph& graph, const Node& node,
|
||||
bool Check(const GraphViewer& graph_viewer, const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes) const override;
|
||||
|
||||
|
|
@ -64,31 +76,36 @@ class UnarySelector : public BaseSelector {
|
|||
|
||||
// 2 DQ nodes providing input -> node -> Q
|
||||
class BinarySelector : public BaseSelector {
|
||||
bool Check(const Graph& graph, const Node& node,
|
||||
bool Check(const GraphViewer& graph_viewer, const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes) const override;
|
||||
};
|
||||
|
||||
// Variadic DQ nodes -> node -> Q
|
||||
class VariadicSelector : public BaseSelector {
|
||||
bool Check(const Graph& graph, const Node& node,
|
||||
public:
|
||||
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
|
||||
|
||||
private:
|
||||
bool Check(const GraphViewer& graph_viewer, const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes) const override;
|
||||
void UpdateBuilder(NodesToOptimizeBuilder&) const override;
|
||||
};
|
||||
|
||||
// DQ nodes for X, W and optionally B -> node -> Q
|
||||
class ConvSelector : public BaseSelector {
|
||||
bool Check(const Graph& graph, const Node& node,
|
||||
public:
|
||||
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
|
||||
|
||||
private:
|
||||
bool Check(const GraphViewer& graph_viewer, const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes) const override;
|
||||
|
||||
void UpdateBuilder(NodesToOptimizeBuilder&) const override;
|
||||
};
|
||||
|
||||
// 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not
|
||||
class MatMulSelector : public BaseSelector {
|
||||
bool Check(const Graph& graph, const Node& node,
|
||||
bool Check(const GraphViewer& graph_viewer, const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes) const override;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -8,6 +8,16 @@ using namespace ::onnxruntime::common;
|
|||
namespace onnxruntime {
|
||||
|
||||
namespace {
|
||||
|
||||
// if the last input/output in num_io is for the variadic input/output,
|
||||
// the variadic input/output could have zero or more values
|
||||
// so we need to special case the zero and count that as one.
|
||||
int NumIOEntries(bool variadic_io, int num_io, int num_variadic_io) {
|
||||
return variadic_io
|
||||
? num_io + std::max(1, num_variadic_io) - 1
|
||||
: num_io;
|
||||
}
|
||||
|
||||
// Move or remove an edge.
|
||||
// - moves edges from src+src_slot to dest node+dest_slot if provided.
|
||||
// - remove edges for the src+src_slot if dest+dest_slot not provided.
|
||||
|
|
@ -116,11 +126,11 @@ Node* GetNodeByNodeIndex(Graph& graph, NodeIndex idx, bool& missing) {
|
|||
return node;
|
||||
}
|
||||
|
||||
bool GetNodesByNodeIndex(Graph& graph, const std::vector<NodeIndex>& indexes, std::vector<Node*>& nodes) {
|
||||
nodes.reserve(indexes.size());
|
||||
bool GetNodesByNodeIndex(Graph& graph, const std::vector<NodeIndex>& indices, std::vector<Node*>& nodes) {
|
||||
nodes.reserve(indices.size());
|
||||
bool missing = false;
|
||||
|
||||
for (auto iter = indexes.cbegin(), end = indexes.cend(); iter != end; ++iter) {
|
||||
for (auto iter = indices.cbegin(), end = indices.cend(); iter != end; ++iter) {
|
||||
nodes.push_back(GetNodeByNodeIndex(graph, *iter, missing));
|
||||
|
||||
// bail if we're missing a node
|
||||
|
|
@ -137,6 +147,50 @@ bool GetNodesByNodeIndex(Graph& graph, const std::vector<NodeIndex>& indexes, st
|
|||
// Selections
|
||||
//
|
||||
|
||||
// Helper to create the NodesToOptimizeIndices
|
||||
// specify num_input_defs/num_output_defs if the last input/output is variadic (default is non-variadic)
|
||||
static NodesToOptimizeIndices GetNodesToOptimizeIndices(
|
||||
const std::vector<NodeIndex>& input_nodes, NodeIndex target_node, const std::vector<NodeIndex>& output_nodes,
|
||||
int num_input_defs, int num_output_defs) {
|
||||
int num_inputs = num_input_defs == -1 ? gsl::narrow_cast<int>(input_nodes.size()) : num_input_defs;
|
||||
int num_outputs = num_output_defs == -1 ? gsl::narrow_cast<int>(output_nodes.size()) : num_output_defs;
|
||||
bool variadic_input = false;
|
||||
bool variadic_output = false;
|
||||
int num_variadic_inputs = 0;
|
||||
int num_variadic_outputs = 0;
|
||||
|
||||
if (num_input_defs != -1) {
|
||||
variadic_input = true;
|
||||
num_variadic_inputs = gsl::narrow_cast<int>(input_nodes.size()) - num_input_defs + 1;
|
||||
}
|
||||
|
||||
if (num_output_defs != -1) {
|
||||
variadic_output = true;
|
||||
num_variadic_outputs = gsl::narrow_cast<int>(output_nodes.size()) - num_output_defs + 1;
|
||||
}
|
||||
|
||||
std::vector<NodeIndex> node_indices;
|
||||
node_indices.reserve(NumIOEntries(variadic_input, num_inputs, num_variadic_inputs) + 1 +
|
||||
NumIOEntries(variadic_output, num_outputs, num_variadic_outputs));
|
||||
std::copy(input_nodes.begin(), input_nodes.end(), std::back_inserter(node_indices));
|
||||
node_indices.push_back(target_node);
|
||||
std::copy(output_nodes.begin(), output_nodes.end(), std::back_inserter(node_indices));
|
||||
|
||||
std::for_each(node_indices.cbegin(), node_indices.cend(), [](NodeIndex node_idx) {
|
||||
ORT_ENFORCE(node_idx <= NodesToOptimizeIndices::kEmptyNodeIndex,
|
||||
"Node index value is too large to save to ORT format model: ", node_idx);
|
||||
});
|
||||
|
||||
return NodesToOptimizeIndices{std::move(node_indices), num_inputs, num_outputs,
|
||||
variadic_input, variadic_output,
|
||||
num_variadic_inputs, num_variadic_outputs};
|
||||
}
|
||||
|
||||
NodesToOptimizeIndices NodesToOptimizeIndicesBuilder::Build() const {
|
||||
ORT_ENFORCE(target_node != NodesToOptimizeIndices::kEmptyNodeIndex, "A target node must be set.");
|
||||
return GetNodesToOptimizeIndices(input_nodes, target_node, output_nodes, num_input_defs, num_output_defs);
|
||||
}
|
||||
|
||||
NodesToOptimize::NodesToOptimize(const std::vector<Node*>& input_nodes,
|
||||
Node& target_node,
|
||||
const std::vector<Node*>& output_nodes,
|
||||
|
|
@ -160,41 +214,39 @@ NodesToOptimize::NodesToOptimize(const std::vector<Node*>& input_nodes,
|
|||
}
|
||||
|
||||
NodesToOptimize::NodesToOptimize(Graph& graph,
|
||||
const NodesToOptimizeIndices& indexes)
|
||||
: num_inputs{indexes.num_inputs},
|
||||
num_outputs{indexes.num_outputs} {
|
||||
bool missing_nodes = GetNodesByNodeIndex(graph, indexes.nodes, nodes_);
|
||||
const NodesToOptimizeIndices& indices)
|
||||
: num_inputs{indices.num_inputs},
|
||||
num_outputs{indices.num_outputs},
|
||||
variadic_input_{indices.variadic_input},
|
||||
variadic_output_{indices.variadic_output},
|
||||
num_variadic_inputs_{indices.num_variadic_inputs},
|
||||
num_variadic_outputs_{indices.num_variadic_outputs} {
|
||||
bool missing_nodes = !GetNodesByNodeIndex(graph, indices.nodes, nodes_);
|
||||
if (missing_nodes) {
|
||||
nodes_.clear(); // this will result in IsValid returning false
|
||||
}
|
||||
}
|
||||
|
||||
NodesToOptimizeIndices NodesToOptimize::ToIndices() const {
|
||||
NodesToOptimizeIndices indexes;
|
||||
|
||||
indexes.nodes.reserve(nodes_.size());
|
||||
std::for_each(nodes_.cbegin(), nodes_.cend(), [&indexes](const Node* node) {
|
||||
std::vector<NodeIndex> node_indices;
|
||||
node_indices.reserve(nodes_.size());
|
||||
std::for_each(nodes_.cbegin(), nodes_.cend(), [&node_indices](const Node* node) {
|
||||
const NodeIndex node_idx = node != nullptr ? node->Index() : NodesToOptimizeIndices::kEmptyNodeIndex;
|
||||
ORT_ENFORCE(node_idx <= NodesToOptimizeIndices::kEmptyNodeIndex,
|
||||
"Node index value is too large to save to ORT format model: ", node_idx);
|
||||
indexes.nodes.push_back(node_idx);
|
||||
node_indices.push_back(node_idx);
|
||||
});
|
||||
|
||||
indexes.num_inputs = num_inputs;
|
||||
indexes.num_outputs = num_outputs;
|
||||
indexes.variadic_input = variadic_input_;
|
||||
indexes.variadic_output = variadic_output_;
|
||||
indexes.num_variadic_inputs = num_variadic_inputs_;
|
||||
indexes.num_variadic_outputs = num_variadic_outputs_;
|
||||
|
||||
return indexes;
|
||||
return NodesToOptimizeIndices{std::move(node_indices), num_inputs, num_outputs,
|
||||
variadic_input_, variadic_output_,
|
||||
num_variadic_inputs_, num_variadic_outputs_};
|
||||
}
|
||||
|
||||
std::vector<Node*> NodesToOptimize::Inputs(const std::vector<int>& indexes, bool required) const {
|
||||
std::vector<Node*> NodesToOptimize::Inputs(const std::vector<int>& indices, bool required) const {
|
||||
std::vector<Node*> results;
|
||||
results.reserve(NumInputEntries());
|
||||
|
||||
for (auto idx : indexes) {
|
||||
for (auto idx : indices) {
|
||||
if (idx == num_inputs - 1 && HasVariadicInput()) {
|
||||
for (int i = 0, end = NumVariadicInputs(); i < end; ++i) {
|
||||
results.push_back(GetNode(idx + i, required));
|
||||
|
|
@ -207,14 +259,14 @@ std::vector<Node*> NodesToOptimize::Inputs(const std::vector<int>& indexes, bool
|
|||
return results;
|
||||
}
|
||||
|
||||
std::vector<Node*> NodesToOptimize::Outputs(const std::vector<int>& indexes, bool required) const {
|
||||
std::vector<Node*> NodesToOptimize::Outputs(const std::vector<int>& indices, bool required) const {
|
||||
std::vector<Node*> results;
|
||||
results.reserve(NumOutputEntries());
|
||||
|
||||
// offset by all the inputs and the target node
|
||||
const int offset = NumInputEntries() + 1;
|
||||
|
||||
for (auto idx : indexes) {
|
||||
for (auto idx : indices) {
|
||||
if (idx == num_outputs - 1 && HasVariadicOutput()) {
|
||||
for (int i = 0, end = NumVariadicOutputs(); i < end; ++i) {
|
||||
results.push_back(GetNode(offset + idx + i, required));
|
||||
|
|
@ -236,6 +288,14 @@ std::vector<Node*> NodesToOptimize::GetNodesAtLocation(const NodeLocation& locat
|
|||
return {&Target()};
|
||||
};
|
||||
|
||||
int NodesToOptimize::NumInputEntries() const {
|
||||
return NumIOEntries(variadic_input_, num_inputs, num_variadic_inputs_);
|
||||
}
|
||||
|
||||
int NodesToOptimize::NumOutputEntries() const {
|
||||
return NumIOEntries(variadic_output_, num_outputs, num_variadic_outputs_);
|
||||
}
|
||||
|
||||
//
|
||||
// Actions
|
||||
//
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ class NodesToOptimize {
|
|||
|
||||
// construct from saved NodeIndex values. IsValid() will return false if one or more nodes were missing.
|
||||
// Use NodesToOptimizeIndices::kEmptyNodeIndex for nullptr entries in the vectors for missing optional inputs
|
||||
NodesToOptimize(Graph& graph, const NodesToOptimizeIndices& node_indexes);
|
||||
NodesToOptimize(Graph& graph, const NodesToOptimizeIndices& node_indices);
|
||||
|
||||
NodesToOptimizeIndices ToIndices() const;
|
||||
|
||||
|
|
@ -72,15 +72,15 @@ class NodesToOptimize {
|
|||
bool IsValid() const { return !nodes_.empty(); }
|
||||
|
||||
// fetch an input.
|
||||
// valid indexes are 0 to num_inputs - 1 if no variadic inputs.
|
||||
// if there are variadic inputs, valid indexes are 0 to num_inputs + num_extra_variadic_inputs - 1
|
||||
// valid indices are 0 to num_inputs - 1 if no variadic inputs.
|
||||
// if there are variadic inputs, valid indices are 0 to num_inputs + num_extra_variadic_inputs - 1
|
||||
// e.g. 3 inputs. last is variadic with 3 values. num_inputs=3 num_extra_variadic_inputs=2 for a total of 5 inputs.
|
||||
Node* Input(int idx, bool required = true) const {
|
||||
return GetNode(idx, required);
|
||||
}
|
||||
|
||||
// inputs filtered by index. includes all variadic.
|
||||
std::vector<Node*> Inputs(const std::vector<int>& indexes, bool required = true) const;
|
||||
std::vector<Node*> Inputs(const std::vector<int>& indices, bool required = true) const;
|
||||
|
||||
Node& Target() const {
|
||||
return *GetNode(NumInputEntries() + 0, /*required*/ true);
|
||||
|
|
@ -91,7 +91,7 @@ class NodesToOptimize {
|
|||
}
|
||||
|
||||
// outputs filtered by index. includes all variadic.
|
||||
std::vector<Node*> Outputs(const std::vector<int>& indexes, bool required = true) const;
|
||||
std::vector<Node*> Outputs(const std::vector<int>& indices, bool required = true) const;
|
||||
|
||||
// Get the Node or Nodes (if variadic) at a specific index.
|
||||
std::vector<Node*> GetNodesAtLocation(const NodeLocation& location, bool required = true) const;
|
||||
|
|
@ -109,12 +109,8 @@ class NodesToOptimize {
|
|||
return node;
|
||||
}
|
||||
|
||||
// if the last input in num_inputs is for the variadic input, the variadic input could have zero or more values
|
||||
// so we need to special case the zero and count that as one. same for outputs
|
||||
int NumInputEntries() const { return variadic_input_ ? num_inputs + std::max(1, num_variadic_inputs_) - 1
|
||||
: num_inputs; }
|
||||
int NumOutputEntries() const { return variadic_output_ ? num_outputs + std::max(1, num_variadic_outputs_) - 1
|
||||
: num_outputs; }
|
||||
int NumInputEntries() const;
|
||||
int NumOutputEntries() const;
|
||||
|
||||
bool variadic_input_{false}; // is last input variadic
|
||||
bool variadic_output_{false};
|
||||
|
|
@ -123,19 +119,16 @@ class NodesToOptimize {
|
|||
std::vector<Node*> nodes_;
|
||||
};
|
||||
|
||||
// Helper to build a NodesToOptimize instance
|
||||
// Helper to build a NodesToOptimizeIndices instance
|
||||
// Use in selector to incrementally add pieces
|
||||
struct NodesToOptimizeBuilder {
|
||||
std::vector<Node*> input_nodes;
|
||||
Node* target_node{nullptr};
|
||||
std::vector<Node*> output_nodes;
|
||||
struct NodesToOptimizeIndicesBuilder {
|
||||
std::vector<NodeIndex> input_nodes;
|
||||
NodeIndex target_node{NodesToOptimizeIndices::kEmptyNodeIndex};
|
||||
std::vector<NodeIndex> output_nodes;
|
||||
int num_input_defs{-1};
|
||||
int num_output_defs{-1};
|
||||
|
||||
std::unique_ptr<NodesToOptimize> Build() {
|
||||
ORT_ENFORCE(target_node != nullptr, "A target node must be set.");
|
||||
return std::make_unique<NodesToOptimize>(input_nodes, *target_node, output_nodes, num_input_defs, num_output_defs);
|
||||
}
|
||||
NodesToOptimizeIndices Build() const;
|
||||
};
|
||||
|
||||
//
|
||||
|
|
|
|||
|
|
@ -51,13 +51,10 @@ void SelectorsAndActions::RegisterSelectorAndAction(const std::string& name,
|
|||
ORT_IGNORE_RETURN_VALUE(selectors_and_actions_map_.emplace(name, std::move(entry)));
|
||||
}
|
||||
|
||||
// check if the node matches any of the registered operators.
|
||||
// if it does, run the Selector.
|
||||
// if that selects nodes, run the Action.
|
||||
Status SelectorActionTransformer::MatchAndProcess(Graph& graph, Node& node, bool& modified,
|
||||
Status SelectorActionTransformer::MatchAndProcess(Graph& graph, const GraphViewer& graph_viewer,
|
||||
Node& node, bool& modified,
|
||||
const logging::Logger& logger) const {
|
||||
Status status = Status::OK();
|
||||
|
||||
do {
|
||||
// TODO: for now this just needs to support ONNX ops. If we ever had a transformer that was going to
|
||||
// target non-ONNX ops we'd need to rework a few things to include the op domain in the matches
|
||||
|
|
@ -80,20 +77,23 @@ Status SelectorActionTransformer::MatchAndProcess(Graph& graph, Node& node, bool
|
|||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<NodesToOptimize> node_group;
|
||||
if (!selector_and_action.selector->Select(graph, node, node_group)) {
|
||||
const auto node_selection_opt = selector_and_action.selector->Select(graph_viewer, node);
|
||||
if (!node_selection_opt.has_value()) {
|
||||
break;
|
||||
}
|
||||
const auto& node_selection = *node_selection_opt;
|
||||
|
||||
LOGS(logger, VERBOSE) << "Matched " << node.OpType();
|
||||
|
||||
NodesToOptimize node_group(graph, node_selection);
|
||||
|
||||
if (runtime_optimization_save_context_.has_value()) {
|
||||
#if defined(ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)
|
||||
const auto& action = *selector_and_action.action;
|
||||
|
||||
Action::SavedState action_saved_state{};
|
||||
status = action.RunForSave(graph, *node_group, *runtime_optimization_save_context_, action_saved_state,
|
||||
modified);
|
||||
status = action.RunForSave(graph, node_group, *runtime_optimization_save_context_, action_saved_state,
|
||||
modified);
|
||||
if (!status.IsOK()) {
|
||||
break;
|
||||
}
|
||||
|
|
@ -101,7 +101,7 @@ Status SelectorActionTransformer::MatchAndProcess(Graph& graph, Node& node, bool
|
|||
graph.MutableRuntimeOptimizations().AddRecord(
|
||||
Name(),
|
||||
RuntimeOptimizationRecord{selector_and_action.name,
|
||||
node_group->ToIndices(),
|
||||
node_selection,
|
||||
action_saved_state.produced_nodes});
|
||||
#else
|
||||
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAILED,
|
||||
|
|
@ -109,7 +109,7 @@ Status SelectorActionTransformer::MatchAndProcess(Graph& graph, Node& node, bool
|
|||
break;
|
||||
#endif
|
||||
} else {
|
||||
status = selector_and_action.action->Run(graph, *node_group);
|
||||
status = selector_and_action.action->Run(graph, node_group);
|
||||
if (!status.IsOK()) {
|
||||
break;
|
||||
}
|
||||
|
|
@ -192,7 +192,7 @@ Status SelectorActionTransformer::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
// TODO: use GraphTransformer::GetCompatibleExecutionProviders if we need something more flexible
|
||||
if (node->GetExecutionProviderType() == kCpuExecutionProvider) {
|
||||
ORT_RETURN_IF_ERROR(MatchAndProcess(graph, *node, modified, logger));
|
||||
ORT_RETURN_IF_ERROR(MatchAndProcess(graph, graph_viewer, *node, modified, logger));
|
||||
}
|
||||
#else
|
||||
ORT_RETURN_IF_ERROR(ApplySaved(graph, modified, logger));
|
||||
|
|
|
|||
|
|
@ -14,14 +14,17 @@
|
|||
namespace onnxruntime {
|
||||
|
||||
class Graph;
|
||||
class GraphViewer;
|
||||
class Node;
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
// Base class for a selector which checks for a match and returns the set of nodes involved.
|
||||
struct NodeSelector {
|
||||
// Select one or more nodes for an Action to process if the constraints are satisfied.
|
||||
// `selection` should not be set if this returns false
|
||||
virtual bool Select(Graph& graph, const Node& node, std::unique_ptr<NodesToOptimize>& selection) const = 0;
|
||||
// Select one or more nodes for an Action to process if the constraints are satisfied,
|
||||
// otherwise returns std::nullopt
|
||||
virtual std::optional<NodesToOptimizeIndices> Select(const GraphViewer& graph_viewer, const Node& node) const = 0;
|
||||
|
||||
virtual ~NodeSelector() = default;
|
||||
|
||||
protected:
|
||||
|
|
@ -96,7 +99,7 @@ class SelectorsAndActions {
|
|||
};
|
||||
|
||||
/**
|
||||
Class that implements graph transformation via a set of Selector+Action pairs.
|
||||
Class that implements graph transformation via a set of Selector+Action pairs.
|
||||
This setup allows optimizations to be captured and applied at runtime in a minimal build.
|
||||
*/
|
||||
class SelectorActionTransformer : public GraphTransformer {
|
||||
|
|
@ -116,7 +119,17 @@ class SelectorActionTransformer : public GraphTransformer {
|
|||
SelectorsAndActions selectors_and_actions_;
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
Status MatchAndProcess(Graph& graph, Node& node, bool& modified, const logging::Logger& logger) const;
|
||||
|
||||
// check if the node matches any of the registered operators.
|
||||
// if it does, run the Selector.
|
||||
// if that selects nodes, run the Action.
|
||||
//
|
||||
// Some part of the MatchAndProcess use a GraphViewer of the given graph,
|
||||
// we choose to supply both the graph and the graph_viewer to avoid expensive
|
||||
// and repeatedly construction of the graph_viewer.
|
||||
// NOTE, the graph must be the same as the graph_viewer's underlying graph
|
||||
Status MatchAndProcess(Graph& graph, const GraphViewer& graph_viewer, Node& node,
|
||||
bool& modified, const logging::Logger& logger) const;
|
||||
|
||||
std::unordered_map<std::string, const SelectorAndAction*> op_type_to_selector_and_action_;
|
||||
// If set, save runtime optimization to graph. Otherwise, apply optimization to graph nodes.
|
||||
|
|
|
|||
|
|
@ -32,27 +32,27 @@ class TestTransformer : public SelectorActionTransformer {
|
|||
|
||||
private:
|
||||
struct SurroundingIdentitySelector : NodeSelector {
|
||||
bool Select(Graph& graph, const Node& node, std::unique_ptr<NodesToOptimize>& selection_out) const override {
|
||||
std::optional<NodesToOptimizeIndices> Select(const GraphViewer& graph_viewer, const Node& node) const override {
|
||||
// all inputs are identity
|
||||
const auto inputs = graph_utils::FindParentsByType(node, "Identity");
|
||||
if (inputs.size() != node.GetInputEdgesCount()) return false;
|
||||
if (inputs.size() != node.GetInputEdgesCount()) return std::nullopt;
|
||||
|
||||
// does not produce graph output
|
||||
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) return false;
|
||||
if (graph_viewer.NodeProducesGraphOutput(node)) return std::nullopt;
|
||||
|
||||
// all outputs are identity
|
||||
const auto outputs = graph_utils::FindChildrenByType(node, "Identity");
|
||||
if (outputs.size() != node.GetOutputEdgesCount()) return false;
|
||||
if (outputs.size() != node.GetOutputEdgesCount()) return std::nullopt;
|
||||
|
||||
NodesToOptimizeBuilder builder{};
|
||||
builder.target_node = graph.GetNode(node.Index());
|
||||
NodesToOptimizeIndicesBuilder builder;
|
||||
|
||||
auto get_mutable_node = [&](const Node* n) { return n ? graph.GetNode(n->Index()) : nullptr; };
|
||||
std::transform(inputs.begin(), inputs.end(), std::back_inserter(builder.input_nodes), get_mutable_node);
|
||||
std::transform(outputs.begin(), outputs.end(), std::back_inserter(builder.output_nodes), get_mutable_node);
|
||||
builder.target_node = node.Index();
|
||||
|
||||
selection_out = builder.Build();
|
||||
return true;
|
||||
auto get_node_idx = [&](const Node* n) { return n ? n->Index() : NodesToOptimizeIndices::kEmptyNodeIndex; };
|
||||
std::transform(inputs.begin(), inputs.end(), std::back_inserter(builder.input_nodes), get_node_idx);
|
||||
std::transform(outputs.begin(), outputs.end(), std::back_inserter(builder.output_nodes), get_node_idx);
|
||||
|
||||
return builder.Build();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,16 +1,20 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/compute_capability.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
|
||||
#include "core/providers/partitioning_utils.h"
|
||||
#include "core/session/environment.h"
|
||||
#include "core/session/inference_session.h"
|
||||
|
||||
#include "test/compare_ortvalue.h"
|
||||
#include "test/test_environment.h"
|
||||
#include "test/framework/test_utils.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
#include "test/util/include/inference_session_wrapper.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
|
@ -1629,5 +1633,59 @@ TEST(QDQTransformerTests, Concat_UInt8) {
|
|||
}
|
||||
|
||||
#endif // DISABLE_CONTRIB_OPS
|
||||
|
||||
TEST(QDQTransformerTests, QDQ_Selector_Test) {
|
||||
const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/qdq_conv_model_basic.onnx");
|
||||
|
||||
SessionOptions so;
|
||||
// We want to keep the graph un-optimized to prevent QDQ transformer to kick in
|
||||
so.graph_optimization_level = TransformerLevel::Default;
|
||||
InferenceSessionWrapper session_object{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object.Load(model_file_name));
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
const Graph& graph = session_object.GetGraph();
|
||||
const auto* conv_node = graph.GetNode(3);
|
||||
|
||||
// Make sure node 3 is the conv node
|
||||
ASSERT_TRUE(nullptr != conv_node);
|
||||
ASSERT_EQ("Conv", conv_node->OpType());
|
||||
|
||||
onnxruntime::QDQ::ConvSelector conv_selector;
|
||||
|
||||
// Create a GraphViewer covers the whole graph
|
||||
const GraphViewer whole_graph_viewer(graph);
|
||||
|
||||
// Make sure the conv QDQ group is selected for the full graph
|
||||
{
|
||||
const auto result = conv_selector.GetQDQSelection(whole_graph_viewer, *conv_node);
|
||||
ASSERT_TRUE(result.has_value());
|
||||
const auto& qdq_group = *result;
|
||||
ASSERT_EQ(std::vector<NodeIndex>({0, 1, 2}), qdq_group.dq_nodes);
|
||||
ASSERT_EQ(NodeIndex(3), qdq_group.target_node);
|
||||
ASSERT_EQ(std::vector<NodeIndex>({4}), qdq_group.q_nodes);
|
||||
}
|
||||
|
||||
// Create a graph viewer covers part of the graph
|
||||
// Make sure the qdq conv selector will fail for the partial graph
|
||||
{
|
||||
// Get 3 nodes out of 5 nodes in the graph
|
||||
std::vector<const Node*> nodes{
|
||||
graph.GetNode(0),
|
||||
graph.GetNode(3),
|
||||
graph.GetNode(4),
|
||||
};
|
||||
|
||||
// Generate the indexed subgraph
|
||||
const auto compute_capability = utils::MakeComputeCapability(
|
||||
whole_graph_viewer, nodes,
|
||||
[]() { return "sub_graph"; },
|
||||
"Test Provider");
|
||||
|
||||
const GraphViewer partial_graph_viewer(graph, *compute_capability->sub_graph);
|
||||
ASSERT_EQ(3, partial_graph_viewer.NumberOfNodes());
|
||||
const auto result = conv_selector.GetQDQSelection(partial_graph_viewer, *conv_node);
|
||||
ASSERT_FALSE(result.has_value());
|
||||
}
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/qdq_conv_model_basic.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/qdq_conv_model_basic.onnx
vendored
Normal file
Binary file not shown.
45
onnxruntime/test/testdata/qdq_conv_test.py
vendored
Normal file
45
onnxruntime/test/testdata/qdq_conv_test.py
vendored
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto
|
||||
|
||||
|
||||
# Generate a basic QDQ Conv model
|
||||
def GenerateModel(model_name):
|
||||
nodes = [
|
||||
helper.make_node("DequantizeLinear", ["X", "Scale", "Zero_point_uint8"], ["input_DQ"], "input_DQ"),
|
||||
helper.make_node("DequantizeLinear", ["W", "Scale", "Zero_point_uint8"], ["conv_weight_DQ"], "conv_weight_DQ"),
|
||||
helper.make_node("DequantizeLinear", ["Bias", "Scale", "Zero_point_int32"], ["conv_bias_DQ"], "conv_bias_DQ"),
|
||||
helper.make_node("Conv", ["input_DQ", "conv_weight_DQ", "conv_bias_DQ"], ["conv_output"], "conv"),
|
||||
helper.make_node("QuantizeLinear", ["conv_output", "Scale", "Zero_point_uint8"], ["Y"], "output_Q"),
|
||||
]
|
||||
|
||||
initializers = [
|
||||
helper.make_tensor('Scale', TensorProto.FLOAT, [1], [256.0]),
|
||||
helper.make_tensor('Zero_point_uint8', TensorProto.UINT8, [1], [0]),
|
||||
helper.make_tensor('Zero_point_int32', TensorProto.INT32, [1], [0]),
|
||||
helper.make_tensor('W', TensorProto.UINT8, [1, 1, 3, 3], [128] * 9),
|
||||
helper.make_tensor('Bias', TensorProto.INT32, [1], [64]),
|
||||
]
|
||||
|
||||
inputs = [
|
||||
helper.make_tensor_value_info('X', TensorProto.UINT8, [1, 1, 5, 5]),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
helper.make_tensor_value_info('Y', TensorProto.UINT8, [1, 1, 3, 3]),
|
||||
]
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"QDQ_Conv_Model_Basic",
|
||||
inputs,
|
||||
outputs,
|
||||
initializers
|
||||
)
|
||||
|
||||
model = helper.make_model(graph)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GenerateModel('qdq_conv_model_basic.onnx')
|
||||
|
|
@ -29,7 +29,7 @@ python3 /onnxruntime_src/tools/ci_build/build.py \
|
|||
--include_ops_by_config /home/onnxruntimedev/.test_data/include_no_operators.config
|
||||
|
||||
# set current size limit to BINARY_SIZE_LIMIT_IN_BYTES.
|
||||
BINARY_SIZE_LIMIT_IN_BYTES=1303608
|
||||
BINARY_SIZE_LIMIT_IN_BYTES=1305000
|
||||
echo "The current preset binary size limit is $BINARY_SIZE_LIMIT_IN_BYTES"
|
||||
python3 /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/check_build_binary_size.py \
|
||||
--threshold=$BINARY_SIZE_LIMIT_IN_BYTES \
|
||||
|
|
|
|||
Loading…
Reference in a new issue