mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-28 03:20:58 +00:00
Enable QDQ quantization for DML EP
This commit is contained in:
parent
7520974970
commit
33c1d9bc3e
4 changed files with 53 additions and 41 deletions
|
|
@ -91,7 +91,8 @@ void UnaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
|
|||
std::unique_ptr<Action> action = std::make_unique<QDQ::UnaryReplaceWithQLinear>(kMSDomain);
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::UnarySelector>();
|
||||
std::vector<const char*> providers = {kCpuExecutionProvider};
|
||||
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::UnarySelector>(providers);
|
||||
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
|
||||
{{"AveragePool", {}},
|
||||
{"LeakyRelu", {}},
|
||||
|
|
@ -108,16 +109,31 @@ void UnaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
|
|||
void BinaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
|
||||
// 4 nodes. 2 x DQ for inputs, target, Q
|
||||
// Replace with internal QLinear version of operator. Delete all original nodes.
|
||||
const std::string action_name{"2DQ"};
|
||||
std::unique_ptr<Action> action = std::make_unique<QDQ::BinaryReplaceWithQLinear>(kMSDomain);
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::BinarySelector>();
|
||||
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
|
||||
{{"Add", {}},
|
||||
{"Mul", {}}},
|
||||
std::move(selector),
|
||||
std::move(action));
|
||||
{
|
||||
const std::string action_name{"2DQ_Mul"};
|
||||
std::unique_ptr<Action> action = std::make_unique<QDQ::BinaryReplaceWithQLinear>(kMSDomain);
|
||||
|
||||
std::vector<const char*> providers = {kCpuExecutionProvider};
|
||||
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::BinarySelector>(providers);
|
||||
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
|
||||
{{"Mul", {}}},
|
||||
std::move(selector),
|
||||
std::move(action));
|
||||
}
|
||||
|
||||
{
|
||||
const std::string action_name{"2DQ_Add"};
|
||||
std::unique_ptr<Action> action = std::make_unique<QDQ::BinaryReplaceWithQLinear>(kMSDomain);
|
||||
|
||||
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::BinarySelector>();
|
||||
|
||||
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
|
||||
{{"Add", {}}},
|
||||
std::move(selector),
|
||||
std::move(action));
|
||||
}
|
||||
|
||||
#else
|
||||
qdq_selector_action_registry.RegisterAction(action_name, std::move(action));
|
||||
|
|
@ -195,7 +211,8 @@ void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
|
|||
std::unique_ptr<Action> action = std::make_unique<QDQ::GemmReplaceWithQuant>();
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::GemmSelector>();
|
||||
std::vector<const char*> providers = {kCpuExecutionProvider};
|
||||
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::GemmSelector>(providers);
|
||||
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
|
||||
{{"Gemm", {}}},
|
||||
std::move(selector),
|
||||
|
|
@ -215,7 +232,9 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
|
|||
std::unique_ptr<Action> action = std::make_unique<QDQ::WhereReplaceWithQLinear>();
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::WhereSelector>();
|
||||
|
||||
std::vector<const char*> providers = {kCpuExecutionProvider};
|
||||
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::WhereSelector>(providers);
|
||||
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
|
||||
{{"Where", {}}},
|
||||
std::move(selector),
|
||||
|
|
@ -250,8 +269,8 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer(
|
|||
"QDQSelectorActionTransformer",
|
||||
CreateSelectorActionRegistry(is_int8_allowed),
|
||||
apply_context,
|
||||
// this transformer is only compatible with the CPU EP
|
||||
{kCpuExecutionProvider}} {
|
||||
// this transformer is only compatible with the CPU and DML EPs
|
||||
{kCpuExecutionProvider, kDmlExecutionProvider}} {
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -85,6 +85,14 @@ std::optional<NodeGroup> NodeGroupSelector::GetQDQSelection(const GraphViewer& g
|
|||
}
|
||||
|
||||
std::optional<NodesToOptimizeIndices> BaseSelector::Select(const GraphViewer& graph_viewer, const Node& node) const {
|
||||
|
||||
const std::string_view node_ep = node.GetExecutionProviderType();
|
||||
|
||||
if (!compatible_providers_.empty() &&
|
||||
std::find(compatible_providers_.begin(), compatible_providers_.end(), node_ep) == compatible_providers_.end()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
const auto qdq_group = node_group_selector_->GetQDQSelection(graph_viewer, node);
|
||||
if (!qdq_group.has_value()) {
|
||||
return std::nullopt;
|
||||
|
|
|
|||
|
|
@ -179,12 +179,16 @@ class BaseSelector : public NodeSelector {
|
|||
|
||||
// We std::move SelectorActionRegistry into the SelectorActionTransformer so this class needs to have a move ctor
|
||||
BaseSelector(BaseSelector&& rhs) noexcept
|
||||
: node_group_selector_{std::move(rhs.node_group_selector_)} {
|
||||
: node_group_selector_{std::move(rhs.node_group_selector_)},
|
||||
compatible_providers_{std::move(rhs.compatible_providers_)} {
|
||||
}
|
||||
|
||||
protected:
|
||||
BaseSelector(std::unique_ptr<NodeGroupSelector> node_group_selector)
|
||||
: node_group_selector_{std::move(node_group_selector)} {}
|
||||
BaseSelector(std::unique_ptr<NodeGroupSelector> node_group_selector, std::vector<const char*> compatible_providers = {})
|
||||
: node_group_selector_{std::move(node_group_selector)}
|
||||
{
|
||||
compatible_providers_.assign(compatible_providers.begin(), compatible_providers.end());
|
||||
}
|
||||
|
||||
// override if you need to adjust the values in NodesToOptimize.
|
||||
// e.g. add entries for missing optional DQ inputs or set num_inputs to handle variadic inputs
|
||||
|
|
@ -193,6 +197,7 @@ class BaseSelector : public NodeSelector {
|
|||
|
||||
private:
|
||||
std::unique_ptr<NodeGroupSelector> node_group_selector_;
|
||||
std::vector<std::string> compatible_providers_;
|
||||
};
|
||||
|
||||
class DropQDQNodesSelector : public BaseSelector {
|
||||
|
|
@ -207,12 +212,12 @@ class DropDQNodesSelector : public BaseSelector {
|
|||
|
||||
// Selector whose node-group matching is delegated to a UnaryNodeGroupSelector.
// The provider-taking constructor forwards compatible_providers to BaseSelector.
// NOTE(review): this diff view lists two constructor lines; presumably the
// zero-argument one is the pre-change line and the provider-taking one is its
// replacement -- confirm against the commit.
class UnarySelector : public BaseSelector {
|
||||
public:
|
||||
UnarySelector() : BaseSelector(std::make_unique<UnaryNodeGroupSelector>()) {}
|
||||
UnarySelector(std::vector<const char*> compatible_providers) : BaseSelector(std::make_unique<UnaryNodeGroupSelector>(), compatible_providers) {}
|
||||
};
|
||||
|
||||
// Selector whose node-group matching is delegated to a BinaryNodeGroupSelector.
// compatible_providers defaults to an empty list and is forwarded to BaseSelector.
// NOTE(review): this diff view lists two constructor lines; presumably the
// zero-argument one is the pre-change line and the defaulted-parameter one is
// its replacement (keeping both would make `BinarySelector()` ambiguous) --
// confirm against the commit.
class BinarySelector : public BaseSelector {
|
||||
public:
|
||||
BinarySelector() : BaseSelector(std::make_unique<BinaryNodeGroupSelector>()) {}
|
||||
BinarySelector(std::vector<const char*> compatible_providers = {}) : BaseSelector(std::make_unique<BinaryNodeGroupSelector>(), compatible_providers) {}
|
||||
};
|
||||
|
||||
// Variadic DQ nodes -> node -> Q
|
||||
|
|
@ -240,7 +245,7 @@ class ConvSelector : public BaseSelector {
|
|||
};
|
||||
// Selector whose node-group matching is delegated to a WhereNodeGroupSelector.
// compatible_providers defaults to an empty list and is forwarded to BaseSelector.
// NOTE(review): this diff view lists two constructor lines; presumably the
// zero-argument one is the pre-change line and the defaulted-parameter one is
// its replacement -- confirm against the commit.
class WhereSelector : public BaseSelector {
|
||||
public:
|
||||
WhereSelector() : BaseSelector(std::make_unique<WhereNodeGroupSelector>()) {}
|
||||
WhereSelector(std::vector<const char*> compatible_providers = {}) : BaseSelector(std::make_unique<WhereNodeGroupSelector>(), compatible_providers) {}
|
||||
};
|
||||
// 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not
|
||||
class MatMulSelector : public BaseSelector {
|
||||
|
|
@ -253,8 +258,8 @@ class MatMulSelector : public BaseSelector {
|
|||
// Output: optional Q node for Y
|
||||
// Selector whose node-group matching is delegated to a GemmNodeGroupSelector.
// compatible_providers defaults to an empty list and is forwarded to BaseSelector.
// NOTE(review): this diff view lists two constructor forms; presumably the
// zero-argument one is the pre-change text and the defaulted-parameter one is
// its replacement -- confirm against the commit.
class GemmSelector : public BaseSelector {
|
||||
public:
|
||||
GemmSelector()
|
||||
: BaseSelector(std::make_unique<GemmNodeGroupSelector>()) {}
|
||||
GemmSelector(std::vector<const char*> compatible_providers = {})
|
||||
: BaseSelector(std::make_unique<GemmNodeGroupSelector>(), compatible_providers) {}
|
||||
|
||||
// Overrides the BaseSelector hook; only declared here, the definition is not in this view.
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -555,26 +555,6 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr
|
|||
session_options_.enable_mem_pattern = false;
|
||||
}
|
||||
|
||||
// Default this option to true when the DML EP is registered.
|
||||
// This should be removed if QDQ is supported for DML through QDQSelectorActionTransformer and the DML EP does not
|
||||
// rely on the constant folding pass for DequantizeLinear.
|
||||
optional<std::string> disable_quant_qdq = session_options_.config_options.GetConfigEntry(kOrtSessionOptionsDisableQuantQDQ);
|
||||
|
||||
if (disable_quant_qdq == std::nullopt) {
|
||||
LOGS(*session_logger_, INFO)
|
||||
<< "QDQ quantization is not supported while using the DML Execution Provider. "
|
||||
<< "So disabling it for this session since it uses the DML Execution Provider.";
|
||||
|
||||
auto st = session_options_.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1");
|
||||
if (!st.IsOK()) {
|
||||
return st;
|
||||
}
|
||||
} else if (*disable_quant_qdq != "1") {
|
||||
LOGS(*session_logger_, WARNING)
|
||||
<< "QDQ quantization is not supported while using the DML Execution Provider. "
|
||||
<< "It is enabled within session options which may result in lower performance.";
|
||||
}
|
||||
|
||||
// Parallel execution mode does not support DML EP
|
||||
if (session_options_.execution_mode != ExecutionMode::ORT_SEQUENTIAL) {
|
||||
LOGS(*session_logger_, INFO)
|
||||
|
|
|
|||
Loading…
Reference in a new issue