From a099bd454b8f4cb5098f2a45f757b1cc40e73869 Mon Sep 17 00:00:00 2001 From: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Date: Tue, 11 Jan 2022 19:41:45 -0800 Subject: [PATCH] [QDQ] Add shared qdq selectors (#10178) * wip * wip * wip * wip * wip * save * minor changes * update test graph name * address pr comments * update * address pr comments * address pr comments * fix warning * minor include fix * update to nodegroupselectors * delete unnecessary includes Co-authored-by: rachguo --- cmake/onnxruntime_optimizer.cmake | 2 + .../selectors_actions/shared/utils.cc | 144 ++++++++++++++++++ .../selectors_actions/shared/utils.h | 76 +++++++++ .../test/optimizer/qdq_transformer_test.cc | 34 +++++ 4 files changed, 256 insertions(+) create mode 100644 onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc create mode 100644 onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index 6350b8a0a7..8b1e84acc4 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -32,6 +32,8 @@ else() "${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/*.cc" "${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/*.h" "${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/*.cc" + "${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" + "${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc" "${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/*.h" "${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/*.cc" "${ONNXRUNTIME_ROOT}/core/optimizer/transpose_optimizer/*.h" diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc new file mode 100644 index 0000000000..329ec66d94 --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "utils.h" + +#include +#include +#include + +#include +#include + +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" + +namespace onnxruntime { +namespace QDQ { + +void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops_and_versions_in, + std::unique_ptr selector_in) { + auto entry = std::make_unique( + ops_and_versions_in, + std::move(selector_in)); + + ORT_IGNORE_RETURN_VALUE(selectors_set_.insert(std::move(entry))); +} + +/* static methods to return different operator's OpVersionMap */ +static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { return {{"Gather", {}}, + {"Reshape", {}}, + {"Transpose", {}}, + {"MaxPool", {12}}, + {"Resize", {}}}; } + +static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() { return {{"AveragePool", {}}, + {"LeakyRelu", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() { return {{"Add", {}}, + {"Mul", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetVariadicOpVersionsMap() { return {{"Concat", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetConvOpVersionsMap() { return {{"Conv", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetMatMulOpVersionsMap() { return {{"MatMul", {}}}; } + +/* Selector rules registration related */ +void RegisterMiscSelectors(Selectors& qdq_selectors) { + /* register selectors for miscellaneous ops */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetMiscOpVersionsMap(), + std::move(selector)); +} + +void RegisterUnarySelectors(Selectors& qdq_selectors) { + /* regsiter selectors for unary ops */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetUnaryOpVersionsMap(), + std::move(selector)); +} + +void RegisterBinarySelectors(Selectors& qdq_selectors) { + /* register selectors for binary ops */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetBinaryOpVersionsMap(), + std::move(selector)); +} + +void RegisterVariadicSelectors(Selectors& qdq_selectors) { + /* register selectors for variadic ops */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetVariadicOpVersionsMap(), + std::move(selector)); +} + +void RegisterConvSelector(Selectors& qdq_selectors) { + /* register selector for conv op */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetConvOpVersionsMap(), + std::move(selector)); +} + +void RegisterMatMulSelector(Selectors& qdq_selectors) { + /* register selector for matmul op */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetMatMulOpVersionsMap(), + std::move(selector)); +} + +void SelectorManager::CreateSelectors() { + RegisterMiscSelectors(qdq_selectors_); + RegisterUnarySelectors(qdq_selectors_); + RegisterBinarySelectors(qdq_selectors_); + RegisterVariadicSelectors(qdq_selectors_); + RegisterConvSelector(qdq_selectors_); + RegisterMatMulSelector(qdq_selectors_); +} + +void SelectorManager::InitializeSelectorsMap() { + for (const auto& entry : qdq_selectors_.SelectorsSet()) { + for (const auto& op_info : entry->op_versions_map) { + bool inserted = op_type_to_selectors_map_.insert({op_info.first, &*entry}).second; + ORT_ENFORCE(inserted, "Multiple entries for operator is not supported. OpType=", op_info.first); + } + } +} + +void SelectorManager::Initialize() { + CreateSelectors(); + InitializeSelectorsMap(); +} + +std::vector SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer) const { + std::vector qdq_selections; + for (auto index : graph_viewer.GetNodesInTopologicalOrder()) { + const auto* node = graph_viewer.GetNode(index); + if (node->Domain() != kOnnxDomain) { + continue; + } + + auto op_rule = op_type_to_selectors_map_.find(node->OpType()); + if (op_rule == op_type_to_selectors_map_.cend()) { + continue; + } + + const auto& op_versions_and_selector = *op_rule->second; + + // check the supported versions if specified + const auto& versions = op_versions_and_selector.op_versions_map.find(node->OpType())->second; + if (!versions.empty()) { + if (std::find(versions.cbegin(), versions.cend(), node->SinceVersion()) == versions.cend()) { + LOGS_DEFAULT(VERBOSE) << "Op version is not supported for" << node->OpType(); + continue; + } + } + + const auto qdq_node_group_selection = op_versions_and_selector.selector->GetQDQSelection(graph_viewer, *node); + if (qdq_node_group_selection.has_value()) { + const auto& qdq_group = *qdq_node_group_selection; + qdq_selections.push_back(qdq_group); + } + } + + return qdq_selections; +} + +} // namespace QDQ +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h new file mode 100644 index 0000000000..cb44ed2fa5 --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/graph/basic_types.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/selectors_actions/helpers.h" + +namespace onnxruntime { + +class GraphViewer; +class Node; + +namespace QDQ { + +// struct that provides a join between selector and op versions supported +struct OpVersionsAndSelector { + using OpVersionsMap = std::unordered_map>; + + OpVersionsAndSelector(const OpVersionsMap& ops_and_versions_in, + std::unique_ptr selector_in) + : op_versions_map{ops_and_versions_in}, + selector{std::move(selector_in)} {} + + OpVersionsMap op_versions_map; + std::unique_ptr selector; + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpVersionsAndSelector); +}; + +// class that manages a set of node group selectors +class Selectors { + public: + Selectors() = default; + + // register a selector for the specified ops. + void RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops_and_versions_in, + std::unique_ptr selector_in); + + const std::unordered_set>& SelectorsSet() const { + return selectors_set_; + } + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Selectors); + + private: + std::unordered_set> selectors_set_; +}; + +// class that manages qdq node group selections +class SelectorManager { + public: + SelectorManager() = default; + + void Initialize(); + + // Methods that finds and returns a vector of QDQ::NodeGroup in a given graph + // Can be used in QDQ support in different EPs + std::vector GetQDQSelections(const GraphViewer& graph_viewer) const; + + private: + Selectors qdq_selectors_; + + std::unordered_map op_type_to_selectors_map_; + + void InitializeSelectorsMap(); + + void CreateSelectors(); + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SelectorManager); +}; + +} // namespace QDQ +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 4470eb4293..4eea3be246 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -7,6 +7,7 @@ #include "core/mlas/inc/mlas.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/providers/partitioning_utils.h" #include "core/session/environment.h" #include "core/session/inference_session.h" @@ -1826,5 +1827,38 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { ASSERT_FALSE(result.has_value()); } } + +TEST(QDQTransformerTests, QDQ_Shared_GetSelectors_Test) { + const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/transform/qdq_conv.onnx"); + + SessionOptions so; + so.graph_optimization_level = TransformerLevel::Default; + InferenceSessionWrapper session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(model_file_name)); + ASSERT_STATUS_OK(session_object.Initialize()); + const Graph& graph = session_object.GetGraph(); + const auto* conv_node = graph.GetNode(3); + + // Make sure node 3 is the conv node + ASSERT_TRUE(nullptr != conv_node); + ASSERT_EQ("Conv", conv_node->OpType()); + + const GraphViewer graph_viewer(graph); + + // Initialize SelectorManager + QDQ::SelectorManager selector_mgr; + selector_mgr.Initialize(); + + // Check if SelectorManager get a conv qdq group selection as expected + { + const auto result = selector_mgr.GetQDQSelections(graph_viewer); + ASSERT_EQ(false, result.empty()); + const auto& qdq_group = result.at(0); + ASSERT_EQ(std::vector({0, 1, 2}), qdq_group.dq_nodes); + ASSERT_EQ(NodeIndex(3), qdq_group.target_node); + ASSERT_EQ(std::vector({4}), qdq_group.q_nodes); + } +} + } // namespace test } // namespace onnxruntime