[QDQ] Add shared qdq selectors (#10178)

* wip

* wip

* wip

* wip

* wip

* save

* minor changes

* update test graph name

* address pr comments

* update

* address pr comments

* address pr comments

* fix warning

* minor include fix

* update to nodegroupselectors

* delete unnecessary includes

Co-authored-by: rachguo <rachguo@rachguos-Mini.attlocal.net>
This commit is contained in:
Rachel Guo 2022-01-11 19:41:45 -08:00 committed by GitHub
parent 79d2a0d185
commit a099bd454b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 256 additions and 0 deletions

View file

@ -32,6 +32,8 @@ else()
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/*.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/*.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/*.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/*.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/*.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/transpose_optimizer/*.h"

View file

@ -0,0 +1,144 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "utils.h"
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include <core/graph/graph_viewer.h>
#include <core/providers/common.h>
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
namespace onnxruntime {
namespace QDQ {
// Registers `selector_in` for every op listed in `ops_and_versions_in`.
// The op/version table is copied and bundled with the selector into a single
// OpVersionsAndSelector, whose ownership is handed to selectors_set_.
void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops_and_versions_in,
                                 std::unique_ptr<NodeGroupSelector> selector_in) {
  ORT_IGNORE_RETURN_VALUE(selectors_set_.insert(
      std::make_unique<OpVersionsAndSelector>(ops_and_versions_in, std::move(selector_in))));
}
/* file-local helper functions, each returning the OpVersionsMap for one category of operators */
// Op types (and their accepted opset versions) registered with the drop-QDQ selector.
// An empty version list places no restriction on the version; MaxPool is limited to opset 12.
static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap misc_ops{
      {"Gather", {}},
      {"Reshape", {}},
      {"Transpose", {}},
      {"MaxPool", {12}},
      {"Resize", {}},
  };
  return misc_ops;
}
// Op types registered with the unary selector; the empty lists place no restriction on version.
static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap unary_ops{
      {"AveragePool", {}},
      {"LeakyRelu", {}},
  };
  return unary_ops;
}
// Op types registered with the binary selector; the empty lists place no restriction on version.
static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap binary_ops{
      {"Add", {}},
      {"Mul", {}},
  };
  return binary_ops;
}
// Op types registered with the variadic selector; the empty list places no restriction on version.
static const OpVersionsAndSelector::OpVersionsMap GetVariadicOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap variadic_ops{
      {"Concat", {}},
  };
  return variadic_ops;
}
// Op types registered with the Conv selector; the empty list places no restriction on version.
static const OpVersionsAndSelector::OpVersionsMap GetConvOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap conv_ops{
      {"Conv", {}},
  };
  return conv_ops;
}
// Op types registered with the MatMul selector; the empty list places no restriction on version.
static const OpVersionsAndSelector::OpVersionsMap GetMatMulOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap matmul_ops{
      {"MatMul", {}},
  };
  return matmul_ops;
}
/* Selector rules registration related */
// Registers a DropQDQNodeGroupSelector for the miscellaneous ops
// listed by GetMiscOpVersionsMap().
void RegisterMiscSelectors(Selectors& qdq_selectors) {
  qdq_selectors.RegisterSelector(GetMiscOpVersionsMap(),
                                 std::make_unique<DropQDQNodeGroupSelector>());
}
// Registers a UnaryNodeGroupSelector for the unary ops
// listed by GetUnaryOpVersionsMap().
void RegisterUnarySelectors(Selectors& qdq_selectors) {
  qdq_selectors.RegisterSelector(GetUnaryOpVersionsMap(),
                                 std::make_unique<UnaryNodeGroupSelector>());
}
// Registers a BinaryNodeGroupSelector for the binary ops
// listed by GetBinaryOpVersionsMap().
void RegisterBinarySelectors(Selectors& qdq_selectors) {
  qdq_selectors.RegisterSelector(GetBinaryOpVersionsMap(),
                                 std::make_unique<BinaryNodeGroupSelector>());
}
// Registers a VariadicNodeGroupSelector for the variadic ops
// listed by GetVariadicOpVersionsMap().
void RegisterVariadicSelectors(Selectors& qdq_selectors) {
  qdq_selectors.RegisterSelector(GetVariadicOpVersionsMap(),
                                 std::make_unique<VariadicNodeGroupSelector>());
}
// Registers a ConvNodeGroupSelector for the op(s)
// listed by GetConvOpVersionsMap().
void RegisterConvSelector(Selectors& qdq_selectors) {
  qdq_selectors.RegisterSelector(GetConvOpVersionsMap(),
                                 std::make_unique<ConvNodeGroupSelector>());
}
// Registers a MatMulNodeGroupSelector for the op(s)
// listed by GetMatMulOpVersionsMap().
void RegisterMatMulSelector(Selectors& qdq_selectors) {
  qdq_selectors.RegisterSelector(GetMatMulOpVersionsMap(),
                                 std::make_unique<MatMulNodeGroupSelector>());
}
void SelectorManager::CreateSelectors() {
RegisterMiscSelectors(qdq_selectors_);
RegisterUnarySelectors(qdq_selectors_);
RegisterBinarySelectors(qdq_selectors_);
RegisterVariadicSelectors(qdq_selectors_);
RegisterConvSelector(qdq_selectors_);
RegisterMatMulSelector(qdq_selectors_);
}
void SelectorManager::InitializeSelectorsMap() {
for (const auto& entry : qdq_selectors_.SelectorsSet()) {
for (const auto& op_info : entry->op_versions_map) {
bool inserted = op_type_to_selectors_map_.insert({op_info.first, &*entry}).second;
ORT_ENFORCE(inserted, "Multiple entries for operator is not supported. OpType=", op_info.first);
}
}
}
// Prepares the manager for use: registers the selectors, then indexes them by op type.
void SelectorManager::Initialize() {
  // The lookup table is derived from the registered selectors, so they must exist first.
  CreateSelectors();
  InitializeSelectorsMap();
}
/**
 * Collects every QDQ node group that a registered selector matches in the graph.
 *
 * Nodes are visited in the topological order reported by the graph viewer. A node is
 * considered only when it is in the ONNX domain, its op type has a registered selector,
 * and (when the registration lists specific versions) its SinceVersion() is one of them.
 *
 * @param graph_viewer Read-only view of the graph to scan.
 * @return The selected node groups in visiting order; empty when nothing matches.
 */
std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer) const {
  std::vector<NodeGroup> qdq_selections;
  for (auto index : graph_viewer.GetNodesInTopologicalOrder()) {
    const auto* node = graph_viewer.GetNode(index);
    if (node->Domain() != kOnnxDomain) {
      continue;
    }

    const auto op_rule = op_type_to_selectors_map_.find(node->OpType());
    if (op_rule == op_type_to_selectors_map_.cend()) {
      continue;
    }

    const auto& op_versions_and_selector = *op_rule->second;

    // The lookup table is keyed by the entries of each registration's op_versions_map,
    // so this find() is expected to succeed. An empty version list accepts any version.
    const auto& versions = op_versions_and_selector.op_versions_map.find(node->OpType())->second;
    if (!versions.empty() &&
        std::find(versions.cbegin(), versions.cend(), node->SinceVersion()) == versions.cend()) {
      LOGS_DEFAULT(VERBOSE) << "Op version is not supported for " << node->OpType();
      continue;
    }

    // Deliberately non-const so the selected group can be moved into the result
    // rather than copied.
    auto qdq_node_group_selection = op_versions_and_selector.selector->GetQDQSelection(graph_viewer, *node);
    if (qdq_node_group_selection.has_value()) {
      qdq_selections.push_back(std::move(*qdq_node_group_selection));
    }
  }
  return qdq_selections;
}
} // namespace QDQ
} // namespace onnxruntime

View file

@ -0,0 +1,76 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include "core/graph/basic_types.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/selectors_actions/helpers.h"
namespace onnxruntime {
class GraphViewer;
class Node;
namespace QDQ {
// Pairs a NodeGroupSelector with the op types, and the opset versions of each,
// that it is registered for.
struct OpVersionsAndSelector {
  // Maps an op type to a list of opset versions.
  using OpVersionsMap = std::unordered_map<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>;

  // Copies `ops_and_versions_in` and takes ownership of `selector_in`.
  OpVersionsAndSelector(const OpVersionsMap& ops_and_versions_in,
                        std::unique_ptr<NodeGroupSelector> selector_in)
      : op_versions_map(ops_and_versions_in),
        selector(std::move(selector_in)) {}

  OpVersionsMap op_versions_map;
  std::unique_ptr<NodeGroupSelector> selector;

  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpVersionsAndSelector);
};
// Owns a set of selector registrations (OpVersionsAndSelector instances).
class Selectors {
 public:
  Selectors() = default;

  // Adds a registration that associates `selector_in` with the ops (and versions)
  // in `ops_and_versions_in`.
  void RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops_and_versions_in,
                        std::unique_ptr<NodeGroupSelector> selector_in);

  // Read-only access to all registrations added so far.
  const std::unordered_set<std::unique_ptr<OpVersionsAndSelector>>& SelectorsSet() const {
    return selectors_set_;
  }

  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Selectors);

 private:
  std::unordered_set<std::unique_ptr<OpVersionsAndSelector>> selectors_set_;
};
// Finds QDQ node groups in a graph using a set of registered selectors.
class SelectorManager {
 public:
  SelectorManager() = default;

  // Registers the selectors and builds the op-type lookup table.
  void Initialize();

  // Finds and returns the QDQ::NodeGroup instances selected in the given graph.
  // Can be used for QDQ support in different EPs.
  std::vector<NodeGroup> GetQDQSelections(const GraphViewer& graph_viewer) const;

 private:
  // Owner of the selector registrations.
  Selectors qdq_selectors_;

  // Op type -> registration; the pointers are non-owning.
  std::unordered_map<std::string, const OpVersionsAndSelector*> op_type_to_selectors_map_;

  void InitializeSelectorsMap();
  void CreateSelectors();

  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SelectorManager);
};
} // namespace QDQ
} // namespace onnxruntime

View file

@ -7,6 +7,7 @@
#include "core/mlas/inc/mlas.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
#include "core/providers/partitioning_utils.h"
#include "core/session/environment.h"
#include "core/session/inference_session.h"
@ -1826,5 +1827,38 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
ASSERT_FALSE(result.has_value());
}
}
// Loads a QDQ Conv model and verifies that QDQ::SelectorManager reports the expected
// node group (DQ nodes 0-2, Conv target node 3, Q node 4) for it.
TEST(QDQTransformerTests, QDQ_Shared_GetSelectors_Test) {
  const ORTCHAR_T* model_path = ORT_TSTR("testdata/transform/qdq_conv.onnx");

  SessionOptions session_options;
  session_options.graph_optimization_level = TransformerLevel::Default;

  InferenceSessionWrapper session{session_options, GetEnvironment()};
  ASSERT_STATUS_OK(session.Load(model_path));
  ASSERT_STATUS_OK(session.Initialize());

  const Graph& graph = session.GetGraph();

  // The expectations below rely on node 3 being the Conv node.
  const auto* conv_node = graph.GetNode(3);
  ASSERT_TRUE(conv_node != nullptr);
  ASSERT_EQ("Conv", conv_node->OpType());

  const GraphViewer graph_viewer(graph);

  QDQ::SelectorManager selector_mgr;
  selector_mgr.Initialize();

  // The first selection returned should be the Conv QDQ group.
  const auto selections = selector_mgr.GetQDQSelections(graph_viewer);
  ASSERT_FALSE(selections.empty());

  const auto& conv_group = selections.at(0);
  ASSERT_EQ(std::vector<NodeIndex>({0, 1, 2}), conv_group.dq_nodes);
  ASSERT_EQ(NodeIndex(3), conv_group.target_node);
  ASSERT_EQ(std::vector<NodeIndex>({4}), conv_group.q_nodes);
}
} // namespace test
} // namespace onnxruntime