mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
[QDQ] Add shared qdq selectors (#10178)
* wip * wip * wip * wip * wip * save * minor changes * update test graph name * address pr comments * update * address pr comments * address pr comments * fix warning * minor include fix * update to nodegroupselectors * delete unnecessary includes Co-authored-by: rachguo <rachguo@rachguos-Mini.attlocal.net>
This commit is contained in:
parent
79d2a0d185
commit
a099bd454b
4 changed files with 256 additions and 0 deletions
|
|
@ -32,6 +32,8 @@ else()
|
|||
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/*.cc"
|
||||
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/*.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/*.cc"
|
||||
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc"
|
||||
"${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/*.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/*.cc"
|
||||
"${ONNXRUNTIME_ROOT}/core/optimizer/transpose_optimizer/*.h"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,144 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <core/graph/graph_viewer.h>
|
||||
#include <core/providers/common.h>
|
||||
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace QDQ {
|
||||
|
||||
void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops_and_versions_in,
|
||||
std::unique_ptr<NodeGroupSelector> selector_in) {
|
||||
auto entry = std::make_unique<OpVersionsAndSelector>(
|
||||
ops_and_versions_in,
|
||||
std::move(selector_in));
|
||||
|
||||
ORT_IGNORE_RETURN_VALUE(selectors_set_.insert(std::move(entry)));
|
||||
}
|
||||
|
||||
/* static methods to return different operator's OpVersionMap */
|
||||
static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { return {{"Gather", {}},
|
||||
{"Reshape", {}},
|
||||
{"Transpose", {}},
|
||||
{"MaxPool", {12}},
|
||||
{"Resize", {}}}; }
|
||||
|
||||
static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() { return {{"AveragePool", {}},
|
||||
{"LeakyRelu", {}}}; }
|
||||
static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() { return {{"Add", {}},
|
||||
{"Mul", {}}}; }
|
||||
static const OpVersionsAndSelector::OpVersionsMap GetVariadicOpVersionsMap() { return {{"Concat", {}}}; }
|
||||
static const OpVersionsAndSelector::OpVersionsMap GetConvOpVersionsMap() { return {{"Conv", {}}}; }
|
||||
static const OpVersionsAndSelector::OpVersionsMap GetMatMulOpVersionsMap() { return {{"MatMul", {}}}; }
|
||||
|
||||
/* Selector rules registration related */
|
||||
void RegisterMiscSelectors(Selectors& qdq_selectors) {
|
||||
/* register selectors for miscellaneous ops */
|
||||
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<DropQDQNodeGroupSelector>();
|
||||
qdq_selectors.RegisterSelector(GetMiscOpVersionsMap(),
|
||||
std::move(selector));
|
||||
}
|
||||
|
||||
void RegisterUnarySelectors(Selectors& qdq_selectors) {
|
||||
/* regsiter selectors for unary ops */
|
||||
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<UnaryNodeGroupSelector>();
|
||||
qdq_selectors.RegisterSelector(GetUnaryOpVersionsMap(),
|
||||
std::move(selector));
|
||||
}
|
||||
|
||||
void RegisterBinarySelectors(Selectors& qdq_selectors) {
|
||||
/* register selectors for binary ops */
|
||||
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<BinaryNodeGroupSelector>();
|
||||
qdq_selectors.RegisterSelector(GetBinaryOpVersionsMap(),
|
||||
std::move(selector));
|
||||
}
|
||||
|
||||
void RegisterVariadicSelectors(Selectors& qdq_selectors) {
|
||||
/* register selectors for variadic ops */
|
||||
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<VariadicNodeGroupSelector>();
|
||||
qdq_selectors.RegisterSelector(GetVariadicOpVersionsMap(),
|
||||
std::move(selector));
|
||||
}
|
||||
|
||||
void RegisterConvSelector(Selectors& qdq_selectors) {
|
||||
/* register selector for conv op */
|
||||
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<ConvNodeGroupSelector>();
|
||||
qdq_selectors.RegisterSelector(GetConvOpVersionsMap(),
|
||||
std::move(selector));
|
||||
}
|
||||
|
||||
void RegisterMatMulSelector(Selectors& qdq_selectors) {
|
||||
/* register selector for matmul op */
|
||||
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<MatMulNodeGroupSelector>();
|
||||
qdq_selectors.RegisterSelector(GetMatMulOpVersionsMap(),
|
||||
std::move(selector));
|
||||
}
|
||||
|
||||
void SelectorManager::CreateSelectors() {
|
||||
RegisterMiscSelectors(qdq_selectors_);
|
||||
RegisterUnarySelectors(qdq_selectors_);
|
||||
RegisterBinarySelectors(qdq_selectors_);
|
||||
RegisterVariadicSelectors(qdq_selectors_);
|
||||
RegisterConvSelector(qdq_selectors_);
|
||||
RegisterMatMulSelector(qdq_selectors_);
|
||||
}
|
||||
|
||||
void SelectorManager::InitializeSelectorsMap() {
|
||||
for (const auto& entry : qdq_selectors_.SelectorsSet()) {
|
||||
for (const auto& op_info : entry->op_versions_map) {
|
||||
bool inserted = op_type_to_selectors_map_.insert({op_info.first, &*entry}).second;
|
||||
ORT_ENFORCE(inserted, "Multiple entries for operator is not supported. OpType=", op_info.first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SelectorManager::Initialize() {
|
||||
CreateSelectors();
|
||||
InitializeSelectorsMap();
|
||||
}
|
||||
|
||||
std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer) const {
|
||||
std::vector<NodeGroup> qdq_selections;
|
||||
for (auto index : graph_viewer.GetNodesInTopologicalOrder()) {
|
||||
const auto* node = graph_viewer.GetNode(index);
|
||||
if (node->Domain() != kOnnxDomain) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto op_rule = op_type_to_selectors_map_.find(node->OpType());
|
||||
if (op_rule == op_type_to_selectors_map_.cend()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& op_versions_and_selector = *op_rule->second;
|
||||
|
||||
// check the supported versions if specified
|
||||
const auto& versions = op_versions_and_selector.op_versions_map.find(node->OpType())->second;
|
||||
if (!versions.empty()) {
|
||||
if (std::find(versions.cbegin(), versions.cend(), node->SinceVersion()) == versions.cend()) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Op version is not supported for" << node->OpType();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const auto qdq_node_group_selection = op_versions_and_selector.selector->GetQDQSelection(graph_viewer, *node);
|
||||
if (qdq_node_group_selection.has_value()) {
|
||||
const auto& qdq_group = *qdq_node_group_selection;
|
||||
qdq_selections.push_back(qdq_group);
|
||||
}
|
||||
}
|
||||
|
||||
return qdq_selections;
|
||||
}
|
||||
|
||||
} // namespace QDQ
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "core/graph/basic_types.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
#include "core/optimizer/selectors_actions/helpers.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
class GraphViewer;
|
||||
class Node;
|
||||
|
||||
namespace QDQ {
|
||||
|
||||
// struct that provides a join between selector and op versions supported
|
||||
struct OpVersionsAndSelector {
|
||||
using OpVersionsMap = std::unordered_map<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>;
|
||||
|
||||
OpVersionsAndSelector(const OpVersionsMap& ops_and_versions_in,
|
||||
std::unique_ptr<NodeGroupSelector> selector_in)
|
||||
: op_versions_map{ops_and_versions_in},
|
||||
selector{std::move(selector_in)} {}
|
||||
|
||||
OpVersionsMap op_versions_map;
|
||||
std::unique_ptr<NodeGroupSelector> selector;
|
||||
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpVersionsAndSelector);
|
||||
};
|
||||
|
||||
// class that manages a set of node group selectors
|
||||
class Selectors {
|
||||
public:
|
||||
Selectors() = default;
|
||||
|
||||
// register a selector for the specified ops.
|
||||
void RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops_and_versions_in,
|
||||
std::unique_ptr<NodeGroupSelector> selector_in);
|
||||
|
||||
const std::unordered_set<std::unique_ptr<OpVersionsAndSelector>>& SelectorsSet() const {
|
||||
return selectors_set_;
|
||||
}
|
||||
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Selectors);
|
||||
|
||||
private:
|
||||
std::unordered_set<std::unique_ptr<OpVersionsAndSelector>> selectors_set_;
|
||||
};
|
||||
|
||||
// class that manages qdq node group selections
|
||||
class SelectorManager {
|
||||
public:
|
||||
SelectorManager() = default;
|
||||
|
||||
void Initialize();
|
||||
|
||||
// Methods that finds and returns a vector of QDQ::NodeGroup in a given graph
|
||||
// Can be used in QDQ support in different EPs
|
||||
std::vector<NodeGroup> GetQDQSelections(const GraphViewer& graph_viewer) const;
|
||||
|
||||
private:
|
||||
Selectors qdq_selectors_;
|
||||
|
||||
std::unordered_map<std::string, const OpVersionsAndSelector*> op_type_to_selectors_map_;
|
||||
|
||||
void InitializeSelectorsMap();
|
||||
|
||||
void CreateSelectors();
|
||||
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SelectorManager);
|
||||
};
|
||||
|
||||
} // namespace QDQ
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -7,6 +7,7 @@
|
|||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
|
||||
#include "core/providers/partitioning_utils.h"
|
||||
#include "core/session/environment.h"
|
||||
#include "core/session/inference_session.h"
|
||||
|
|
@ -1826,5 +1827,38 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
|
|||
ASSERT_FALSE(result.has_value());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, QDQ_Shared_GetSelectors_Test) {
|
||||
const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/transform/qdq_conv.onnx");
|
||||
|
||||
SessionOptions so;
|
||||
so.graph_optimization_level = TransformerLevel::Default;
|
||||
InferenceSessionWrapper session_object{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object.Load(model_file_name));
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
const Graph& graph = session_object.GetGraph();
|
||||
const auto* conv_node = graph.GetNode(3);
|
||||
|
||||
// Make sure node 3 is the conv node
|
||||
ASSERT_TRUE(nullptr != conv_node);
|
||||
ASSERT_EQ("Conv", conv_node->OpType());
|
||||
|
||||
const GraphViewer graph_viewer(graph);
|
||||
|
||||
// Initialize SelectorManager
|
||||
QDQ::SelectorManager selector_mgr;
|
||||
selector_mgr.Initialize();
|
||||
|
||||
// Check if SelectorManager get a conv qdq group selection as expected
|
||||
{
|
||||
const auto result = selector_mgr.GetQDQSelections(graph_viewer);
|
||||
ASSERT_EQ(false, result.empty());
|
||||
const auto& qdq_group = result.at(0);
|
||||
ASSERT_EQ(std::vector<NodeIndex>({0, 1, 2}), qdq_group.dq_nodes);
|
||||
ASSERT_EQ(NodeIndex(3), qdq_group.target_node);
|
||||
ASSERT_EQ(std::vector<NodeIndex>({4}), qdq_group.q_nodes);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue