Add QDQ::Selector::Select to use const GraphViewer instead of mutable Graph (#9621)

* Move qdq selector to use const GraphViewer

* minor update

* Move qdq logic from NodeSelector to QDQ Selectors

* Fix build break

* Move selector result to NodesToOptimizeIndexes

* fix build break

* address CR comments

* move indexes -> indices

* Pass graph_viewer to avoid recreating it many times

* Update after merge master

* update graph viewer remarks

* update comments

* Add ut for new qdq selector logic

* Increase minimal binary size limit

* UT minor update

* Address CR comments
This commit is contained in:
Guoyu Wang 2021-11-08 21:36:29 -08:00 committed by GitHub
parent 65590b049c
commit a70ae24475
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 402 additions and 146 deletions

View file

@ -80,6 +80,10 @@ class GraphViewer {
*/
const std::vector<const NodeArg*>& GetOutputs() const noexcept;
/** Returns true if one or more of the Node outputs are Graph outputs.
*/
bool NodeProducesGraphOutput(const Node& node) const;
/** Gets all ValueInfo NodeArg instances in the Graph.
@remarks NOT filtered using filter_info_.
*/
@ -146,6 +150,16 @@ class GraphViewer {
*/
bool IsConstantInitializer(const std::string& name, bool check_outer_scope) const;
/** returns the initializer's TensorProto if 'name' is an initializer, is constant and
cannot be overridden at runtime. If the initializer is not found or is not constant, a nullptr is returned.
@param check_outer_scope If true and the graph is a subgraph,
check ancestor graph/s for 'name' if not found in 'graph'.
@remarks This function returns the result of GetConstantInitializer of the underlying Graph.
If a constant initializer is part of the underlying Graph but not part of this GraphViewer,
it will still be returned instead of nullptr.
*/
const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const;
/** Get the Node containing this Graph if IsSubgraph is true. Returns nullptr otherwise. */
const Node* ParentNode() const noexcept { return graph_->ParentNode(); }

View file

@ -199,6 +199,17 @@ const std::vector<const NodeArg*>& GraphViewer::GetOutputs() const noexcept {
: filtered_node_outputs_;
}
bool GraphViewer::NodeProducesGraphOutput(const Node& node) const {
const auto& outputs = GetOutputs();
auto end_outputs = outputs.cend();
for (auto output_def : node.OutputDefs()) {
if (std::find(outputs.cbegin(), end_outputs, output_def) != end_outputs) {
return true;
}
}
return false;
}
// Get graph value infos.
const std::unordered_set<const NodeArg*>& GraphViewer::GetValueInfo() const noexcept {
return graph_->GetValueInfo();
@ -262,7 +273,12 @@ bool GraphViewer::IsSubgraph() const {
}
bool GraphViewer::IsConstantInitializer(const std::string& name, bool check_outer_scope) const {
return graph_->GetConstantInitializer(name, check_outer_scope) != nullptr;
return GetConstantInitializer(name, check_outer_scope) != nullptr;
}
const ONNX_NAMESPACE::TensorProto* GraphViewer::GetConstantInitializer(const std::string& initializer_name,
                                                                       bool check_outer_scope) const {
  // Pure delegation to the wrapped Graph; no viewer-level lookup is performed here.
  const auto* initializer = graph_->GetConstantInitializer(initializer_name, check_outer_scope);
  return initializer;
}
} // namespace onnxruntime

View file

@ -35,7 +35,7 @@
namespace onnxruntime {
/** Struct to serialize the node indices in an ORT format model.
Use EmptyNodeIndex for nullptr entries in the vectors for missing optional inputs
Use kEmptyNodeIndex for nullptr entries in the vectors for missing optional inputs
*/
struct NodesToOptimizeIndices {
/** Index value that represents an empty node.

View file

@ -19,7 +19,11 @@ static bool CanNodePropagate(const Node& node) {
}
static bool TryCancelOutDQQPair(Graph& graph, Node& dq_node, Node& q_node) {
if (!QDQ::IsQDQPairSupported(graph, q_node, dq_node)) {
auto get_const_initializer = [&graph](const std::string& initializer_name) {
return graph.GetConstantInitializer(initializer_name, true);
};
if (!QDQ::IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph.ModelPath())) {
return false;
}
@ -36,7 +40,8 @@ static bool TryCancelOutDQQPair(Graph& graph, Node& dq_node, Node& q_node) {
if (dq_input_edge_0) {
input_edge_info.first = dq_input_edge_0->GetNode().Index();
input_edge_info.second = dq_input_edge_0->GetSrcArgIndex();
graph.RemoveEdge(dq_input_edge_0->GetNode().Index(), dq_node.Index(), dq_input_edge_0->GetSrcArgIndex(), dq_input_edge_0->GetDstArgIndex());
graph.RemoveEdge(dq_input_edge_0->GetNode().Index(), dq_node.Index(),
dq_input_edge_0->GetSrcArgIndex(), dq_input_edge_0->GetDstArgIndex());
}
graph_utils::RemoveNodeOutputEdges(graph, dq_node); // Remove DQ node output edges
@ -45,11 +50,13 @@ static bool TryCancelOutDQQPair(Graph& graph, Node& dq_node, Node& q_node) {
graph_utils::RemoveNodeOutputEdges(graph, q_node); // Remove Q node output edges
for (auto& output_edge : output_edges) {
// set input NodeArg of Q's children to the 1st input of DQ
graph.GetNode(output_edge.dst_node)->MutableInputDefs()[output_edge.dst_arg_index] = dq_node.MutableInputDefs()[0];
graph.GetNode(output_edge.dst_node)->MutableInputDefs()[output_edge.dst_arg_index] =
dq_node.MutableInputDefs()[0];
// add edge between parent of DQ to children of Q
if (input_edge_info.second != -1) {
graph.AddEdge(input_edge_info.first, output_edge.dst_node, input_edge_info.second, output_edge.dst_arg_index);
graph.AddEdge(input_edge_info.first, output_edge.dst_node,
input_edge_info.second, output_edge.dst_arg_index);
}
}

View file

@ -13,7 +13,10 @@
namespace onnxruntime {
namespace QDQ {
bool IsQDQPairSupported(const Graph& graph, const Node& q_node, const Node& dq_node) {
bool IsQDQPairSupported(
const Node& q_node, const Node& dq_node,
const std::function<const ONNX_NAMESPACE::TensorProto*(const std::string&)>& get_const_initializer,
const Path& model_path) {
ConstPointerContainer<std::vector<NodeArg*>> dq_input_defs = dq_node.InputDefs();
ConstPointerContainer<std::vector<NodeArg*>> q_input_defs = q_node.InputDefs();
@ -30,13 +33,13 @@ bool IsQDQPairSupported(const Graph& graph, const Node& q_node, const Node& dq_n
// if Q/DQ scale and zero point are not constant, return false
const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto =
graph_utils::GetConstantInitializer(graph, dq_input_defs[InputIndex::SCALE_ID]->Name());
get_const_initializer(dq_input_defs[InputIndex::SCALE_ID]->Name());
const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto =
graph_utils::GetConstantInitializer(graph, q_input_defs[InputIndex::SCALE_ID]->Name());
get_const_initializer(q_input_defs[InputIndex::SCALE_ID]->Name());
const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto =
graph_utils::GetConstantInitializer(graph, dq_input_defs[InputIndex::ZERO_POINT_ID]->Name());
get_const_initializer(dq_input_defs[InputIndex::ZERO_POINT_ID]->Name());
const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto =
graph_utils::GetConstantInitializer(graph, q_input_defs[InputIndex::ZERO_POINT_ID]->Name());
get_const_initializer(q_input_defs[InputIndex::ZERO_POINT_ID]->Name());
if (nullptr == q_zp_tensor_proto ||
nullptr == dq_zp_tensor_proto ||
nullptr == q_scale_tensor_proto ||
@ -45,10 +48,10 @@ bool IsQDQPairSupported(const Graph& graph, const Node& q_node, const Node& dq_n
}
// check Q/DQ have same scale and zero point
Initializer q_zp(*q_zp_tensor_proto, graph.ModelPath());
Initializer q_scale(*q_scale_tensor_proto, graph.ModelPath());
Initializer dq_zp(*dq_zp_tensor_proto, graph.ModelPath());
Initializer dq_scale(*dq_scale_tensor_proto, graph.ModelPath());
Initializer q_zp(*q_zp_tensor_proto, model_path);
Initializer q_scale(*q_scale_tensor_proto, model_path);
Initializer dq_zp(*dq_zp_tensor_proto, model_path);
Initializer dq_scale(*dq_scale_tensor_proto, model_path);
return q_zp.data_type() == dq_zp.data_type() &&
*q_zp.data<int8_t>() == *dq_zp.data<int8_t>() &&

View file

@ -3,10 +3,17 @@
#pragma once
#include <functional>
#include <string>
namespace ONNX_NAMESPACE {
class TensorProto;
}
namespace onnxruntime {
class Graph;
class Node;
class Path;
namespace QDQ {
@ -24,7 +31,10 @@ enum InputIndex : int {
// 1. Q/DQ doesn't have optional input.
// 2. scale and zero point is constant scalar
// 3. Q and DQ have same scale and zero point
bool IsQDQPairSupported(const Graph& graph, const Node& q_node, const Node& dq_node);
bool IsQDQPairSupported(
const Node& q_node, const Node& dq_node,
const std::function<const ONNX_NAMESPACE::TensorProto*(const std::string&)>& get_const_initializer,
const Path& model_path);
} // namespace QDQ
} // namespace onnxruntime

View file

@ -21,7 +21,23 @@ int NumActualValues(const Node& node, bool input) {
}
} // namespace
bool BaseSelector::CheckQDQNodes(const Graph& graph, const Node& node,
static std::vector<const Node*> FindQDQNodes(const GraphViewer& graph_viewer, const Node& node, bool find_dq_nodes) {
// First get all the upstream (DQ) or downstream (Q) nodes
std::vector<const Node*> nodes =
find_dq_nodes ? graph_utils::FindParentsByType(node, QDQ::DQOpName)
: graph_utils::FindChildrenByType(node, QDQ::QOpName);
// Remove all the nodes which are not in the graph_viewer
nodes.erase(std::remove_if(nodes.begin(), nodes.end(),
[&graph_viewer](const Node* _node) {
return _node == nullptr || graph_viewer.GetNode(_node->Index()) == nullptr;
}),
nodes.end());
return nodes;
}
bool BaseSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes,
int num_dq_inputs) const {
@ -31,63 +47,67 @@ bool BaseSelector::CheckQDQNodes(const Graph& graph, const Node& node,
int num_outputs = NumActualValues(node, false); // number of outputs that exist
// The input is a Graph Viewer, so cannot use graph_utils or optimizer_utils
return num_dq_inputs == gsl::narrow_cast<int>(dq_nodes.size()) &&
num_outputs == gsl::narrow_cast<int>(q_nodes.size()) &&
optimizer_utils::CheckOutputEdges(graph, node, q_nodes.size());
q_nodes.size() == node.GetOutputEdgesCount() &&
!graph_viewer.NodeProducesGraphOutput(node);
}
bool BaseSelector::Select(Graph& graph, const Node& node, std::unique_ptr<NodesToOptimize>& selection) const {
std::vector<const Node*> dq_nodes = graph_utils::FindParentsByType(node, QDQ::DQOpName);
std::vector<const Node*> q_nodes = graph_utils::FindChildrenByType(node, QDQ::QOpName);
if (!Check(graph, node, dq_nodes, q_nodes)) {
return false;
std::optional<NodeGroup> BaseSelector::GetQDQSelection(const GraphViewer& graph_viewer, const Node& node) const {
std::vector<const Node*> dq_nodes = FindQDQNodes(graph_viewer, node, true);
std::vector<const Node*> q_nodes = FindQDQNodes(graph_viewer, node, false);
if (!Check(graph_viewer, node, dq_nodes, q_nodes)) {
return std::nullopt;
}
auto get_mutable_node = [&graph](const Node* node) {
// we use the non-const GetNode to convert the const Node* to Node*
return graph.GetNode(node->Index());
};
NodeGroup node_group;
node_group.dq_nodes.reserve(dq_nodes.size());
node_group.q_nodes.reserve(q_nodes.size());
node_group.target_node = node.Index();
auto get_node_idx = [&](const Node* n) { return n->Index(); };
std::transform(dq_nodes.begin(), dq_nodes.end(), std::back_inserter(node_group.dq_nodes), get_node_idx);
std::transform(q_nodes.begin(), q_nodes.end(), std::back_inserter(node_group.q_nodes), get_node_idx);
return node_group;
}
NodesToOptimizeBuilder builder;
builder.input_nodes.reserve(dq_nodes.size());
builder.output_nodes.reserve(q_nodes.size());
for (const Node* dq_node : dq_nodes) {
builder.input_nodes.push_back(dq_node != nullptr ? get_mutable_node(dq_node) : nullptr);
std::optional<NodesToOptimizeIndices> BaseSelector::Select(const GraphViewer& graph_viewer, const Node& node) const {
const auto qdq_group = GetQDQSelection(graph_viewer, node);
if (!qdq_group.has_value()) {
return std::nullopt;
}
builder.target_node = get_mutable_node(&node);
for (const Node* q_node : q_nodes) {
builder.output_nodes.push_back(get_mutable_node(q_node));
}
NodesToOptimizeIndicesBuilder builder;
builder.input_nodes = qdq_group->dq_nodes;
builder.output_nodes = qdq_group->q_nodes;
builder.target_node = qdq_group->target_node;
UpdateBuilder(builder);
selection = builder.Build();
return true;
return builder.Build();
}
bool DropDQDNodesSelector::Check(const Graph& graph,
bool DropDQDNodesSelector::Check(const GraphViewer& graph_viewer,
const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
if (!CheckQDQNodes(graph, node, dq_nodes, q_nodes, 1)) {
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 1)) {
return false;
}
const Node& dq_node = *dq_nodes.front();
const Node& q_node = *q_nodes.front();
return IsQDQPairSupported(graph, q_node, dq_node);
auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) {
return graph_viewer.GetConstantInitializer(initializer_name, true);
};
return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath());
}
bool UnarySelector::Check(const Graph& graph, const Node& node,
bool UnarySelector::Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
if (!CheckQDQNodes(graph, node, dq_nodes, q_nodes, 1)) {
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 1)) {
return false;
}
@ -100,11 +120,11 @@ bool UnarySelector::Check(const Graph& graph, const Node& node,
(int8_allowed_ && dt_output == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)));
}
bool BinarySelector::Check(const Graph& graph,
bool BinarySelector::Check(const GraphViewer& graph_viewer,
const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
if (!CheckQDQNodes(graph, node, dq_nodes, q_nodes)) {
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) {
return false;
}
@ -116,11 +136,11 @@ bool BinarySelector::Check(const Graph& graph,
dt_input_1 == dt_output;
}
bool VariadicSelector::Check(const Graph& graph,
bool VariadicSelector::Check(const GraphViewer& graph_viewer,
const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
if (!CheckQDQNodes(graph, node, dq_nodes, q_nodes)) {
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) {
return false;
}
@ -136,15 +156,15 @@ bool VariadicSelector::Check(const Graph& graph,
return dt_input == dt_output;
}
void VariadicSelector::UpdateBuilder(NodesToOptimizeBuilder& builder) const {
void VariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const {
builder.num_input_defs = 1; // set to 1 as the first input is variadic
}
bool ConvSelector::Check(const Graph& graph,
bool ConvSelector::Check(const GraphViewer& graph_viewer,
const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
if (!CheckQDQNodes(graph, node, dq_nodes, q_nodes)) {
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) {
return false;
}
@ -164,11 +184,11 @@ bool ConvSelector::Check(const Graph& graph,
return dt_bias == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32;
}
void ConvSelector::UpdateBuilder(NodesToOptimizeBuilder& builder) const {
builder.input_nodes.resize(3); // add nullptr for bias if missing
void ConvSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const {
builder.input_nodes.resize(3, NodesToOptimizeIndices::kEmptyNodeIndex);
}
bool MatMulSelector::Check(const Graph& graph,
bool MatMulSelector::Check(const GraphViewer& graph_viewer,
const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
@ -181,7 +201,7 @@ bool MatMulSelector::Check(const Graph& graph,
if (qlinear) {
// QLinearMatMul
if (!CheckQDQNodes(graph, node, dq_nodes, q_nodes)) {
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) {
return false;
}

View file

@ -12,39 +12,51 @@ class Graph;
class Node;
namespace QDQ {
// Struct to represent a DQ->Op->Q node group.
// All members are node indices rather than Node pointers.
struct NodeGroup {
// Indices of the DQ nodes feeding the target node.
std::vector<NodeIndex> dq_nodes;
// Indices of the Q nodes consuming the target node's output.
std::vector<NodeIndex> q_nodes;
// Index of the target (Op) node between the DQ and Q nodes.
// NOTE: has no default member initializer; it must be assigned before being read.
NodeIndex target_node;
};
// Base QDQ checker. Finds and provides the DQ and Q nodes to the operator specific checkers, as the QDQ optimizations
// always involve those nodes.
class BaseSelector : public NodeSelector {
public:
bool Select(Graph& graph, const Node& node, std::unique_ptr<NodesToOptimize>& selection) const override;
std::optional<NodesToOptimizeIndices> Select(const GraphViewer& graph_viewer, const Node& node) const override;
// This is a QDQ Selectors only function, will return QDQ::NodeGroup instead of NodesToOptimizeIndices
// Can be used in QDQ handling in EPs such as NNAPI
std::optional<NodeGroup> GetQDQSelection(const GraphViewer& graph_viewer, const Node& node) const;
protected:
BaseSelector() = default;
// base check that we have the expected number of QDQ inputs/outputs, and `node` isn't producing a graph output.
// num_dq_inputs defaults to the number of inputs `node` has if not explicitly specified
bool CheckQDQNodes(const Graph& graph, const Node& node,
bool CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes,
int num_dq_inputs = -1) const;
private:
// derived classes should implement this check
bool virtual Check(const Graph& graph, const Node& node,
bool virtual Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const = 0;
// override if you need to adjust the values in NodesToOptimize.
// e.g. add entries for missing optional DQ inputs or set num_inputs to handle variadic inputs
// Called post-Check, if Check returned `true`
virtual void UpdateBuilder(NodesToOptimizeBuilder&) const {}
virtual void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const {}
};
// Single DQ -> node that does not change data -> Q.
// Zero point and scale are constant scalars and must match
class DropDQDNodesSelector : public BaseSelector {
private:
bool Check(const Graph& graph, const Node& node,
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const override;
};
@ -55,7 +67,7 @@ class UnarySelector : public BaseSelector {
UnarySelector(bool int8_allowed = false) : int8_allowed_{int8_allowed} {}
private:
bool Check(const Graph& graph, const Node& node,
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const override;
@ -64,31 +76,36 @@ class UnarySelector : public BaseSelector {
// 2 DQ nodes providing input -> node -> Q
class BinarySelector : public BaseSelector {
bool Check(const Graph& graph, const Node& node,
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const override;
};
// Variadic DQ nodes -> node -> Q
class VariadicSelector : public BaseSelector {
bool Check(const Graph& graph, const Node& node,
public:
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const override;
void UpdateBuilder(NodesToOptimizeBuilder&) const override;
};
// DQ nodes for X, W and optionally B -> node -> Q
class ConvSelector : public BaseSelector {
bool Check(const Graph& graph, const Node& node,
public:
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const override;
void UpdateBuilder(NodesToOptimizeBuilder&) const override;
};
// 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not
class MatMulSelector : public BaseSelector {
bool Check(const Graph& graph, const Node& node,
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const override;
};

View file

@ -8,6 +8,16 @@ using namespace ::onnxruntime::common;
namespace onnxruntime {
namespace {
// Number of entries needed to hold `num_io` inputs (or outputs).
// When the last input/output is variadic it may carry zero or more values;
// zero values still occupies one entry, so the variadic count is clamped to at least 1.
int NumIOEntries(bool variadic_io, int num_io, int num_variadic_io) {
  if (!variadic_io) {
    return num_io;
  }

  // The variadic slot replaces the last of the `num_io` entries, hence the `- 1`.
  const int variadic_entries = std::max(1, num_variadic_io);
  return num_io + variadic_entries - 1;
}
// Move or remove an edge.
// - moves edges from src+src_slot to dest node+dest_slot if provided.
// - remove edges for the src+src_slot if dest+dest_slot not provided.
@ -116,11 +126,11 @@ Node* GetNodeByNodeIndex(Graph& graph, NodeIndex idx, bool& missing) {
return node;
}
bool GetNodesByNodeIndex(Graph& graph, const std::vector<NodeIndex>& indexes, std::vector<Node*>& nodes) {
nodes.reserve(indexes.size());
bool GetNodesByNodeIndex(Graph& graph, const std::vector<NodeIndex>& indices, std::vector<Node*>& nodes) {
nodes.reserve(indices.size());
bool missing = false;
for (auto iter = indexes.cbegin(), end = indexes.cend(); iter != end; ++iter) {
for (auto iter = indices.cbegin(), end = indices.cend(); iter != end; ++iter) {
nodes.push_back(GetNodeByNodeIndex(graph, *iter, missing));
// bail if we're missing a node
@ -137,6 +147,50 @@ bool GetNodesByNodeIndex(Graph& graph, const std::vector<NodeIndex>& indexes, st
// Selections
//
// Builds a NodesToOptimizeIndices from separate input / target / output node indices.
// The resulting node list is laid out as: all input nodes, the target node, all output nodes.
// Pass -1 for num_input_defs / num_output_defs when the last input / output is NOT variadic;
// any other value marks the last input / output as variadic and is used as the number of defs.
static NodesToOptimizeIndices GetNodesToOptimizeIndices(
    const std::vector<NodeIndex>& input_nodes, NodeIndex target_node, const std::vector<NodeIndex>& output_nodes,
    int num_input_defs, int num_output_defs) {
  const int input_node_count = gsl::narrow_cast<int>(input_nodes.size());
  const int output_node_count = gsl::narrow_cast<int>(output_nodes.size());

  // An explicit def count (anything other than -1) means the last input/output is variadic.
  const bool variadic_input = num_input_defs != -1;
  const bool variadic_output = num_output_defs != -1;

  const int num_inputs = variadic_input ? num_input_defs : input_node_count;
  const int num_outputs = variadic_output ? num_output_defs : output_node_count;

  // Number of values bound to the variadic slot: the nodes beyond the non-variadic defs.
  const int num_variadic_inputs = variadic_input ? input_node_count - num_input_defs + 1 : 0;
  const int num_variadic_outputs = variadic_output ? output_node_count - num_output_defs + 1 : 0;

  std::vector<NodeIndex> node_indices;
  node_indices.reserve(NumIOEntries(variadic_input, num_inputs, num_variadic_inputs) + 1 +
                       NumIOEntries(variadic_output, num_outputs, num_variadic_outputs));

  node_indices.insert(node_indices.end(), input_nodes.begin(), input_nodes.end());
  node_indices.push_back(target_node);
  node_indices.insert(node_indices.end(), output_nodes.begin(), output_nodes.end());

  // Every stored index must be representable in the ORT format model.
  std::for_each(node_indices.cbegin(), node_indices.cend(), [](NodeIndex node_idx) {
    ORT_ENFORCE(node_idx <= NodesToOptimizeIndices::kEmptyNodeIndex,
                "Node index value is too large to save to ORT format model: ", node_idx);
  });

  return NodesToOptimizeIndices{std::move(node_indices), num_inputs, num_outputs,
                                variadic_input, variadic_output,
                                num_variadic_inputs, num_variadic_outputs};
}
// Produces the NodesToOptimizeIndices for the values accumulated in this builder.
// Enforces that target_node has been changed from its kEmptyNodeIndex sentinel, then forwards
// input_nodes, target_node, output_nodes and the num_*_defs values to GetNodesToOptimizeIndices.
NodesToOptimizeIndices NodesToOptimizeIndicesBuilder::Build() const {
// A selection without a target node is not valid.
ORT_ENFORCE(target_node != NodesToOptimizeIndices::kEmptyNodeIndex, "A target node must be set.");
return GetNodesToOptimizeIndices(input_nodes, target_node, output_nodes, num_input_defs, num_output_defs);
}
NodesToOptimize::NodesToOptimize(const std::vector<Node*>& input_nodes,
Node& target_node,
const std::vector<Node*>& output_nodes,
@ -160,41 +214,39 @@ NodesToOptimize::NodesToOptimize(const std::vector<Node*>& input_nodes,
}
NodesToOptimize::NodesToOptimize(Graph& graph,
const NodesToOptimizeIndices& indexes)
: num_inputs{indexes.num_inputs},
num_outputs{indexes.num_outputs} {
bool missing_nodes = GetNodesByNodeIndex(graph, indexes.nodes, nodes_);
const NodesToOptimizeIndices& indices)
: num_inputs{indices.num_inputs},
num_outputs{indices.num_outputs},
variadic_input_{indices.variadic_input},
variadic_output_{indices.variadic_output},
num_variadic_inputs_{indices.num_variadic_inputs},
num_variadic_outputs_{indices.num_variadic_outputs} {
bool missing_nodes = !GetNodesByNodeIndex(graph, indices.nodes, nodes_);
if (missing_nodes) {
nodes_.clear(); // this will result in IsValid returning false
}
}
NodesToOptimizeIndices NodesToOptimize::ToIndices() const {
NodesToOptimizeIndices indexes;
indexes.nodes.reserve(nodes_.size());
std::for_each(nodes_.cbegin(), nodes_.cend(), [&indexes](const Node* node) {
std::vector<NodeIndex> node_indices;
node_indices.reserve(nodes_.size());
std::for_each(nodes_.cbegin(), nodes_.cend(), [&node_indices](const Node* node) {
const NodeIndex node_idx = node != nullptr ? node->Index() : NodesToOptimizeIndices::kEmptyNodeIndex;
ORT_ENFORCE(node_idx <= NodesToOptimizeIndices::kEmptyNodeIndex,
"Node index value is too large to save to ORT format model: ", node_idx);
indexes.nodes.push_back(node_idx);
node_indices.push_back(node_idx);
});
indexes.num_inputs = num_inputs;
indexes.num_outputs = num_outputs;
indexes.variadic_input = variadic_input_;
indexes.variadic_output = variadic_output_;
indexes.num_variadic_inputs = num_variadic_inputs_;
indexes.num_variadic_outputs = num_variadic_outputs_;
return indexes;
return NodesToOptimizeIndices{std::move(node_indices), num_inputs, num_outputs,
variadic_input_, variadic_output_,
num_variadic_inputs_, num_variadic_outputs_};
}
std::vector<Node*> NodesToOptimize::Inputs(const std::vector<int>& indexes, bool required) const {
std::vector<Node*> NodesToOptimize::Inputs(const std::vector<int>& indices, bool required) const {
std::vector<Node*> results;
results.reserve(NumInputEntries());
for (auto idx : indexes) {
for (auto idx : indices) {
if (idx == num_inputs - 1 && HasVariadicInput()) {
for (int i = 0, end = NumVariadicInputs(); i < end; ++i) {
results.push_back(GetNode(idx + i, required));
@ -207,14 +259,14 @@ std::vector<Node*> NodesToOptimize::Inputs(const std::vector<int>& indexes, bool
return results;
}
std::vector<Node*> NodesToOptimize::Outputs(const std::vector<int>& indexes, bool required) const {
std::vector<Node*> NodesToOptimize::Outputs(const std::vector<int>& indices, bool required) const {
std::vector<Node*> results;
results.reserve(NumOutputEntries());
// offset by all the inputs and the target node
const int offset = NumInputEntries() + 1;
for (auto idx : indexes) {
for (auto idx : indices) {
if (idx == num_outputs - 1 && HasVariadicOutput()) {
for (int i = 0, end = NumVariadicOutputs(); i < end; ++i) {
results.push_back(GetNode(offset + idx + i, required));
@ -236,6 +288,14 @@ std::vector<Node*> NodesToOptimize::GetNodesAtLocation(const NodeLocation& locat
return {&Target()};
};
// Number of entries the inputs take up, computed by NumIOEntries from
// variadic_input_, num_inputs and num_variadic_inputs_.
int NodesToOptimize::NumInputEntries() const {
return NumIOEntries(variadic_input_, num_inputs, num_variadic_inputs_);
}
// Number of entries the outputs take up, computed by NumIOEntries from
// variadic_output_, num_outputs and num_variadic_outputs_.
int NodesToOptimize::NumOutputEntries() const {
return NumIOEntries(variadic_output_, num_outputs, num_variadic_outputs_);
}
//
// Actions
//

View file

@ -43,7 +43,7 @@ class NodesToOptimize {
// construct from saved NodeIndex values. IsValid() will return false if one or more nodes were missing.
// Use NodesToOptimizeIndices::kEmptyNodeIndex for nullptr entries in the vectors for missing optional inputs
NodesToOptimize(Graph& graph, const NodesToOptimizeIndices& node_indexes);
NodesToOptimize(Graph& graph, const NodesToOptimizeIndices& node_indices);
NodesToOptimizeIndices ToIndices() const;
@ -72,15 +72,15 @@ class NodesToOptimize {
bool IsValid() const { return !nodes_.empty(); }
// fetch an input.
// valid indexes are 0 to num_inputs - 1 if no variadic inputs.
// if there are variadic inputs, valid indexes are 0 to num_inputs + num_extra_variadic_inputs - 1
// valid indices are 0 to num_inputs - 1 if no variadic inputs.
// if there are variadic inputs, valid indices are 0 to num_inputs + num_extra_variadic_inputs - 1
// e.g. 3 inputs. last is variadic with 3 values. num_inputs=3 num_extra_variadic_inputs=2 for a total of 5 inputs.
Node* Input(int idx, bool required = true) const {
return GetNode(idx, required);
}
// inputs filtered by index. includes all variadic.
std::vector<Node*> Inputs(const std::vector<int>& indexes, bool required = true) const;
std::vector<Node*> Inputs(const std::vector<int>& indices, bool required = true) const;
Node& Target() const {
return *GetNode(NumInputEntries() + 0, /*required*/ true);
@ -91,7 +91,7 @@ class NodesToOptimize {
}
// outputs filtered by index. includes all variadic.
std::vector<Node*> Outputs(const std::vector<int>& indexes, bool required = true) const;
std::vector<Node*> Outputs(const std::vector<int>& indices, bool required = true) const;
// Get the Node or Nodes (if variadic) at a specific index.
std::vector<Node*> GetNodesAtLocation(const NodeLocation& location, bool required = true) const;
@ -109,12 +109,8 @@ class NodesToOptimize {
return node;
}
// if the last input in num_inputs is for the variadic input, the variadic input could have zero or more values
// so we need to special case the zero and count that as one. same for outputs
int NumInputEntries() const { return variadic_input_ ? num_inputs + std::max(1, num_variadic_inputs_) - 1
: num_inputs; }
int NumOutputEntries() const { return variadic_output_ ? num_outputs + std::max(1, num_variadic_outputs_) - 1
: num_outputs; }
int NumInputEntries() const;
int NumOutputEntries() const;
bool variadic_input_{false}; // is last input variadic
bool variadic_output_{false};
@ -123,19 +119,16 @@ class NodesToOptimize {
std::vector<Node*> nodes_;
};
// Helper to build a NodesToOptimize instance
// Helper to build a NodesToOptimizeIndices instance
// Use in selector to incrementally add pieces
struct NodesToOptimizeBuilder {
std::vector<Node*> input_nodes;
Node* target_node{nullptr};
std::vector<Node*> output_nodes;
struct NodesToOptimizeIndicesBuilder {
std::vector<NodeIndex> input_nodes;
NodeIndex target_node{NodesToOptimizeIndices::kEmptyNodeIndex};
std::vector<NodeIndex> output_nodes;
int num_input_defs{-1};
int num_output_defs{-1};
std::unique_ptr<NodesToOptimize> Build() {
ORT_ENFORCE(target_node != nullptr, "A target node must be set.");
return std::make_unique<NodesToOptimize>(input_nodes, *target_node, output_nodes, num_input_defs, num_output_defs);
}
NodesToOptimizeIndices Build() const;
};
//

View file

@ -51,13 +51,10 @@ void SelectorsAndActions::RegisterSelectorAndAction(const std::string& name,
ORT_IGNORE_RETURN_VALUE(selectors_and_actions_map_.emplace(name, std::move(entry)));
}
// check if the node matches any of the registered operators.
// if it does, run the Selector.
// if that selects nodes, run the Action.
Status SelectorActionTransformer::MatchAndProcess(Graph& graph, Node& node, bool& modified,
Status SelectorActionTransformer::MatchAndProcess(Graph& graph, const GraphViewer& graph_viewer,
Node& node, bool& modified,
const logging::Logger& logger) const {
Status status = Status::OK();
do {
// TODO: for now this just needs to support ONNX ops. If we ever had a transformer that was going to
// target non-ONNX ops we'd need to rework a few things to include the op domain in the matches
@ -80,20 +77,23 @@ Status SelectorActionTransformer::MatchAndProcess(Graph& graph, Node& node, bool
}
}
std::unique_ptr<NodesToOptimize> node_group;
if (!selector_and_action.selector->Select(graph, node, node_group)) {
const auto node_selection_opt = selector_and_action.selector->Select(graph_viewer, node);
if (!node_selection_opt.has_value()) {
break;
}
const auto& node_selection = *node_selection_opt;
LOGS(logger, VERBOSE) << "Matched " << node.OpType();
NodesToOptimize node_group(graph, node_selection);
if (runtime_optimization_save_context_.has_value()) {
#if defined(ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)
const auto& action = *selector_and_action.action;
Action::SavedState action_saved_state{};
status = action.RunForSave(graph, *node_group, *runtime_optimization_save_context_, action_saved_state,
modified);
status = action.RunForSave(graph, node_group, *runtime_optimization_save_context_, action_saved_state,
modified);
if (!status.IsOK()) {
break;
}
@ -101,7 +101,7 @@ Status SelectorActionTransformer::MatchAndProcess(Graph& graph, Node& node, bool
graph.MutableRuntimeOptimizations().AddRecord(
Name(),
RuntimeOptimizationRecord{selector_and_action.name,
node_group->ToIndices(),
node_selection,
action_saved_state.produced_nodes});
#else
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAILED,
@ -109,7 +109,7 @@ Status SelectorActionTransformer::MatchAndProcess(Graph& graph, Node& node, bool
break;
#endif
} else {
status = selector_and_action.action->Run(graph, *node_group);
status = selector_and_action.action->Run(graph, node_group);
if (!status.IsOK()) {
break;
}
@ -192,7 +192,7 @@ Status SelectorActionTransformer::ApplyImpl(Graph& graph, bool& modified, int gr
#if !defined(ORT_MINIMAL_BUILD)
// TODO: use GraphTransformer::GetCompatibleExecutionProviders if we need something more flexible
if (node->GetExecutionProviderType() == kCpuExecutionProvider) {
ORT_RETURN_IF_ERROR(MatchAndProcess(graph, *node, modified, logger));
ORT_RETURN_IF_ERROR(MatchAndProcess(graph, graph_viewer, *node, modified, logger));
}
#else
ORT_RETURN_IF_ERROR(ApplySaved(graph, modified, logger));

View file

@ -14,14 +14,17 @@
namespace onnxruntime {
class Graph;
class GraphViewer;
class Node;
#if !defined(ORT_MINIMAL_BUILD)
// Base class for a selector which checks for a match and returns the set of nodes involved.
struct NodeSelector {
// Select one or more nodes for an Action to process if the constraints are satisfied.
// `selection` should not be set if this returns false
virtual bool Select(Graph& graph, const Node& node, std::unique_ptr<NodesToOptimize>& selection) const = 0;
// Select one or more nodes for an Action to process if the constraints are satisfied,
// otherwise returns std::nullopt
virtual std::optional<NodesToOptimizeIndices> Select(const GraphViewer& graph_viewer, const Node& node) const = 0;
virtual ~NodeSelector() = default;
protected:
@ -96,7 +99,7 @@ class SelectorsAndActions {
};
/**
Class that implements graph transformation via a set of Selector+Action pairs.
Class that implements graph transformation via a set of Selector+Action pairs.
This setup allows optimizations to be captured and applied at runtime in a minimal build.
*/
class SelectorActionTransformer : public GraphTransformer {
@ -116,7 +119,17 @@ class SelectorActionTransformer : public GraphTransformer {
SelectorsAndActions selectors_and_actions_;
#if !defined(ORT_MINIMAL_BUILD)
Status MatchAndProcess(Graph& graph, Node& node, bool& modified, const logging::Logger& logger) const;
// check if the node matches any of the registered operators.
// if it does, run the Selector.
// if that selects nodes, run the Action.
//
// Parts of MatchAndProcess use a GraphViewer of the given graph.
// Both the graph and the graph_viewer are passed in to avoid repeatedly
// constructing the graph_viewer, which is expensive.
// NOTE: the graph must be the same as the graph_viewer's underlying graph.
Status MatchAndProcess(Graph& graph, const GraphViewer& graph_viewer, Node& node,
bool& modified, const logging::Logger& logger) const;
std::unordered_map<std::string, const SelectorAndAction*> op_type_to_selector_and_action_;
// If set, save runtime optimization to graph. Otherwise, apply optimization to graph nodes.

View file

@ -32,27 +32,27 @@ class TestTransformer : public SelectorActionTransformer {
private:
struct SurroundingIdentitySelector : NodeSelector {
bool Select(Graph& graph, const Node& node, std::unique_ptr<NodesToOptimize>& selection_out) const override {
std::optional<NodesToOptimizeIndices> Select(const GraphViewer& graph_viewer, const Node& node) const override {
// all inputs are identity
const auto inputs = graph_utils::FindParentsByType(node, "Identity");
if (inputs.size() != node.GetInputEdgesCount()) return false;
if (inputs.size() != node.GetInputEdgesCount()) return std::nullopt;
// does not produce graph output
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) return false;
if (graph_viewer.NodeProducesGraphOutput(node)) return std::nullopt;
// all outputs are identity
const auto outputs = graph_utils::FindChildrenByType(node, "Identity");
if (outputs.size() != node.GetOutputEdgesCount()) return false;
if (outputs.size() != node.GetOutputEdgesCount()) return std::nullopt;
NodesToOptimizeBuilder builder{};
builder.target_node = graph.GetNode(node.Index());
NodesToOptimizeIndicesBuilder builder;
auto get_mutable_node = [&](const Node* n) { return n ? graph.GetNode(n->Index()) : nullptr; };
std::transform(inputs.begin(), inputs.end(), std::back_inserter(builder.input_nodes), get_mutable_node);
std::transform(outputs.begin(), outputs.end(), std::back_inserter(builder.output_nodes), get_mutable_node);
builder.target_node = node.Index();
selection_out = builder.Build();
return true;
auto get_node_idx = [&](const Node* n) { return n ? n->Index() : NodesToOptimizeIndices::kEmptyNodeIndex; };
std::transform(inputs.begin(), inputs.end(), std::back_inserter(builder.input_nodes), get_node_idx);
std::transform(outputs.begin(), outputs.end(), std::back_inserter(builder.output_nodes), get_node_idx);
return builder.Build();
}
};

View file

@ -1,16 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/compute_capability.h"
#include "core/graph/model.h"
#include "core/graph/onnx_protobuf.h"
#include "core/mlas/inc/mlas.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
#include "core/providers/partitioning_utils.h"
#include "core/session/environment.h"
#include "core/session/inference_session.h"
#include "test/compare_ortvalue.h"
#include "test/test_environment.h"
#include "test/framework/test_utils.h"
#include "test/util/include/asserts.h"
#include "test/util/include/inference_session_wrapper.h"
#include "gtest/gtest.h"
@ -1629,5 +1633,59 @@ TEST(QDQTransformerTests, Concat_UInt8) {
}
#endif // DISABLE_CONTRIB_OPS
// Checks QDQ::ConvSelector::GetQDQSelection against two views of the same graph:
// a GraphViewer over the whole graph (the Conv QDQ group must be selected) and a
// GraphViewer over only some of the nodes (nothing must be selected).
TEST(QDQTransformerTests, QDQ_Selector_Test) {
  // Keep the graph un-optimized so that the QDQ transformer does not kick in.
  SessionOptions session_options;
  session_options.graph_optimization_level = TransformerLevel::Default;

  const ORTCHAR_T* model_path = ORT_TSTR("testdata/qdq_conv_model_basic.onnx");
  InferenceSessionWrapper session{session_options, GetEnvironment()};
  ASSERT_STATUS_OK(session.Load(model_path));
  ASSERT_STATUS_OK(session.Initialize());

  // The rest of the test addresses nodes by index, so first confirm node 3 is the Conv.
  const Graph& graph = session.GetGraph();
  const auto* conv_node = graph.GetNode(3);
  ASSERT_TRUE(nullptr != conv_node);
  ASSERT_EQ("Conv", conv_node->OpType());

  onnxruntime::QDQ::ConvSelector conv_selector;

  // A GraphViewer that covers the whole graph.
  const GraphViewer full_viewer(graph);

  // With the full graph visible, the Conv QDQ group must be selected.
  {
    const std::vector<NodeIndex> expected_dq_nodes{0, 1, 2};
    const std::vector<NodeIndex> expected_q_nodes{4};

    const auto selection = conv_selector.GetQDQSelection(full_viewer, *conv_node);
    ASSERT_TRUE(selection.has_value());

    const auto& group = *selection;
    ASSERT_EQ(expected_dq_nodes, group.dq_nodes);
    ASSERT_EQ(NodeIndex(3), group.target_node);
    ASSERT_EQ(expected_q_nodes, group.q_nodes);
  }

  // With a GraphViewer that covers only part of the graph, the selector must fail.
  {
    // Keep 3 of the 5 nodes in the graph.
    std::vector<const Node*> kept_nodes{
        graph.GetNode(0),
        graph.GetNode(3),
        graph.GetNode(4),
    };

    // Build the indexed subgraph that backs the partial viewer.
    const auto capability = utils::MakeComputeCapability(
        full_viewer, kept_nodes,
        []() { return "sub_graph"; },
        "Test Provider");

    const GraphViewer partial_viewer(graph, *capability->sub_graph);
    ASSERT_EQ(3, partial_viewer.NumberOfNodes());

    const auto selection = conv_selector.GetQDQSelection(partial_viewer, *conv_node);
    ASSERT_FALSE(selection.has_value());
  }
}
} // namespace test
} // namespace onnxruntime

Binary file not shown.

View file

@ -0,0 +1,45 @@
import onnx
from onnx import helper
from onnx import TensorProto
def GenerateModel(model_name):
    """Generate a basic QDQ Conv model and save it to `model_name`.

    The graph dequantizes the input, weight and bias, runs a Conv on the three
    dequantized tensors, and quantizes the Conv output. Scale, zero points,
    weight and bias are all provided as initializers.
    """
    scale = "Scale"
    zero_point_uint8 = "Zero_point_uint8"
    zero_point_int32 = "Zero_point_int32"

    def dequantize(quantized, zero_point, name):
        # The node name is also used as the name of its single output.
        return helper.make_node("DequantizeLinear", [quantized, scale, zero_point], [name], name)

    # Keep this order: QDQ_Selector_Test looks nodes up by index
    # (DQ nodes 0-2, Conv 3, Q 4).
    nodes = [
        dequantize("X", zero_point_uint8, "input_DQ"),
        dequantize("W", zero_point_uint8, "conv_weight_DQ"),
        dequantize("Bias", zero_point_int32, "conv_bias_DQ"),
        helper.make_node("Conv", ["input_DQ", "conv_weight_DQ", "conv_bias_DQ"], ["conv_output"], "conv"),
        helper.make_node("QuantizeLinear", ["conv_output", scale, zero_point_uint8], ["Y"], "output_Q"),
    ]

    initializers = [
        helper.make_tensor(scale, TensorProto.FLOAT, [1], [256.0]),
        helper.make_tensor(zero_point_uint8, TensorProto.UINT8, [1], [0]),
        helper.make_tensor(zero_point_int32, TensorProto.INT32, [1], [0]),
        helper.make_tensor("W", TensorProto.UINT8, [1, 1, 3, 3], [128] * 9),
        helper.make_tensor("Bias", TensorProto.INT32, [1], [64]),
    ]

    graph_input = helper.make_tensor_value_info("X", TensorProto.UINT8, [1, 1, 5, 5])
    graph_output = helper.make_tensor_value_info("Y", TensorProto.UINT8, [1, 1, 3, 3])

    graph = helper.make_graph(nodes, "QDQ_Conv_Model_Basic", [graph_input], [graph_output], initializers)
    onnx.save(helper.make_model(graph), model_name)


if __name__ == "__main__":
    GenerateModel("qdq_conv_model_basic.onnx")

View file

@ -29,7 +29,7 @@ python3 /onnxruntime_src/tools/ci_build/build.py \
--include_ops_by_config /home/onnxruntimedev/.test_data/include_no_operators.config
# set current size limit to BINARY_SIZE_LIMIT_IN_BYTES.
BINARY_SIZE_LIMIT_IN_BYTES=1303608
BINARY_SIZE_LIMIT_IN_BYTES=1305000
echo "The current preset binary size limit is $BINARY_SIZE_LIMIT_IN_BYTES"
python3 /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/check_build_binary_size.py \
--threshold=$BINARY_SIZE_LIMIT_IN_BYTES \