Enable QDQ quantization for DML EP

This commit is contained in:
Jeff Bloomfield 2023-08-30 11:27:53 -07:00
parent 7520974970
commit 33c1d9bc3e
4 changed files with 53 additions and 41 deletions

View file

@ -91,7 +91,8 @@ void UnaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
std::unique_ptr<Action> action = std::make_unique<QDQ::UnaryReplaceWithQLinear>(kMSDomain);
#if !defined(ORT_MINIMAL_BUILD)
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::UnarySelector>();
std::vector<const char*> providers = {kCpuExecutionProvider};
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::UnarySelector>(providers);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"AveragePool", {}},
{"LeakyRelu", {}},
@ -108,16 +109,31 @@ void UnaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
void BinaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
// 4 nodes. 2 x DQ for inputs, target, Q
// Replace with internal QLinear version of operator. Delete all original nodes.
const std::string action_name{"2DQ"};
std::unique_ptr<Action> action = std::make_unique<QDQ::BinaryReplaceWithQLinear>(kMSDomain);
#if !defined(ORT_MINIMAL_BUILD)
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::BinarySelector>();
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"Add", {}},
{"Mul", {}}},
std::move(selector),
std::move(action));
{
const std::string action_name{"2DQ_Mul"};
std::unique_ptr<Action> action = std::make_unique<QDQ::BinaryReplaceWithQLinear>(kMSDomain);
std::vector<const char*> providers = {kCpuExecutionProvider};
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::BinarySelector>(providers);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"Mul", {}}},
std::move(selector),
std::move(action));
}
{
const std::string action_name{"2DQ_Add"};
std::unique_ptr<Action> action = std::make_unique<QDQ::BinaryReplaceWithQLinear>(kMSDomain);
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::BinarySelector>();
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"Add", {}}},
std::move(selector),
std::move(action));
}
#else
qdq_selector_action_registry.RegisterAction(action_name, std::move(action));
@ -195,7 +211,8 @@ void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
std::unique_ptr<Action> action = std::make_unique<QDQ::GemmReplaceWithQuant>();
#if !defined(ORT_MINIMAL_BUILD)
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::GemmSelector>();
std::vector<const char*> providers = {kCpuExecutionProvider};
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::GemmSelector>(providers);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"Gemm", {}}},
std::move(selector),
@ -215,7 +232,9 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
std::unique_ptr<Action> action = std::make_unique<QDQ::WhereReplaceWithQLinear>();
#if !defined(ORT_MINIMAL_BUILD)
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::WhereSelector>();
std::vector<const char*> providers = {kCpuExecutionProvider};
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::WhereSelector>(providers);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"Where", {}}},
std::move(selector),
@ -250,8 +269,8 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer(
"QDQSelectorActionTransformer",
CreateSelectorActionRegistry(is_int8_allowed),
apply_context,
// this transformer is only compatible with the CPU EP
{kCpuExecutionProvider}} {
// this transformer is only compatible with the CPU and DML EP
{kCpuExecutionProvider, kDmlExecutionProvider}} {
}
} // namespace onnxruntime

View file

@ -85,6 +85,14 @@ std::optional<NodeGroup> NodeGroupSelector::GetQDQSelection(const GraphViewer& g
}
std::optional<NodesToOptimizeIndices> BaseSelector::Select(const GraphViewer& graph_viewer, const Node& node) const {
const std::string_view node_ep = node.GetExecutionProviderType();
if (!compatible_providers_.empty() &&
std::find(compatible_providers_.begin(), compatible_providers_.end(), node_ep) == compatible_providers_.end()) {
return std::nullopt;
}
const auto qdq_group = node_group_selector_->GetQDQSelection(graph_viewer, node);
if (!qdq_group.has_value()) {
return std::nullopt;

View file

@ -179,12 +179,16 @@ class BaseSelector : public NodeSelector {
// We std::move SelectorActionRegistry into the SelectorActionTransformer so this class needs to have a move ctor
BaseSelector(BaseSelector&& rhs) noexcept
: node_group_selector_{std::move(rhs.node_group_selector_)} {
: node_group_selector_{std::move(rhs.node_group_selector_)},
compatible_providers_{std::move(rhs.compatible_providers_)} {
}
protected:
BaseSelector(std::unique_ptr<NodeGroupSelector> node_group_selector)
: node_group_selector_{std::move(node_group_selector)} {}
BaseSelector(std::unique_ptr<NodeGroupSelector> node_group_selector, std::vector<const char*> compatible_providers = {})
: node_group_selector_{std::move(node_group_selector)}
{
compatible_providers_.assign(compatible_providers.begin(), compatible_providers.end());
}
// override if you need to adjust the values in NodesToOptimize.
// e.g. add entries for missing optional DQ inputs or set num_inputs to handle variadic inputs
@ -193,6 +197,7 @@ class BaseSelector : public NodeSelector {
private:
std::unique_ptr<NodeGroupSelector> node_group_selector_;
std::vector<std::string> compatible_providers_;
};
class DropQDQNodesSelector : public BaseSelector {
@ -207,12 +212,12 @@ class DropDQNodesSelector : public BaseSelector {
class UnarySelector : public BaseSelector {
public:
UnarySelector() : BaseSelector(std::make_unique<UnaryNodeGroupSelector>()) {}
UnarySelector(std::vector<const char*> compatible_providers) : BaseSelector(std::make_unique<UnaryNodeGroupSelector>(), compatible_providers) {}
};
class BinarySelector : public BaseSelector {
public:
BinarySelector() : BaseSelector(std::make_unique<BinaryNodeGroupSelector>()) {}
BinarySelector(std::vector<const char*> compatible_providers = {}) : BaseSelector(std::make_unique<BinaryNodeGroupSelector>(), compatible_providers) {}
};
// Variadic DQ nodes -> node -> Q
@ -240,7 +245,7 @@ class ConvSelector : public BaseSelector {
};
class WhereSelector : public BaseSelector {
public:
WhereSelector() : BaseSelector(std::make_unique<WhereNodeGroupSelector>()) {}
WhereSelector(std::vector<const char*> compatible_providers = {}) : BaseSelector(std::make_unique<WhereNodeGroupSelector>(), compatible_providers) {}
};
// 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not
class MatMulSelector : public BaseSelector {
@ -253,8 +258,8 @@ class MatMulSelector : public BaseSelector {
// Output: optional Q node for Y
class GemmSelector : public BaseSelector {
public:
GemmSelector()
: BaseSelector(std::make_unique<GemmNodeGroupSelector>()) {}
GemmSelector(std::vector<const char*> compatible_providers = {})
: BaseSelector(std::make_unique<GemmNodeGroupSelector>(), compatible_providers) {}
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
};

View file

@ -555,26 +555,6 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr
session_options_.enable_mem_pattern = false;
}
// Default this option to true when the DML EP is registered.
// This should be removed if QDQ is supported for DML through QDQSelectorActionTransformer and the DML EP does not
// rely on the constant folding pass for DequantizeLinear.
optional<std::string> disable_quant_qdq = session_options_.config_options.GetConfigEntry(kOrtSessionOptionsDisableQuantQDQ);
if (disable_quant_qdq == std::nullopt) {
LOGS(*session_logger_, INFO)
<< "QDQ quantization is not supported while using the DML Execution Provider. "
<< "So disabling it for this session since it uses the DML Execution Provider.";
auto st = session_options_.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1");
if (!st.IsOK()) {
return st;
}
} else if (*disable_quant_qdq != "1") {
LOGS(*session_logger_, WARNING)
<< "QDQ quantization is not supported while using the DML Execution Provider. "
<< "It is enabled within session options which may result in lower performance.";
}
// Parallel execution mode does not support DML EP
if (session_options_.execution_mode != ExecutionMode::ORT_SEQUENTIAL) {
LOGS(*session_logger_, INFO)