[DML EP] Disable DML Graph Fusion for lower graph optimization level OR setOptimizedFilePath true (#13913)

### Description
DML EP won't fuse the ONNX Graph if ORT Graph optimization level is <= 1
or `SessionOption::SetOptimizedFilePath` is passed.

This is the successor of
https://github.com/microsoft/onnxruntime/pull/11346.

### Motivation and Context
- **Why is this change required? What problem does it solve?**  
Requested by few a users (issues below) and also helps in debugging.
- **If it fixes an open issue, please link to the issue here:**
  - https://github.com/microsoft/onnxruntime/issues/13535
  - https://github.com/microsoft/onnxruntime/issues/8440
This commit is contained in:
Sumit Agarwal 2022-12-12 10:15:51 -08:00 committed by GitHub
parent 8cfbc4fe91
commit fe827c3891
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 18 deletions

View file

@ -20,8 +20,21 @@ public:
PoolingHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), useGlobalPooling),
m_function(function)
{
DmlOperator::Initialize(kernelInfo);
const bool hasDilations =
std::any_of(
m_kernel.dilations,
m_kernel.dilations + m_kernel.spatialDimensionCount,
[](auto d) {return d != 1; }
);
const bool hasOutputIndices = (kernelInfo.GetOutputCount() > 1 && kernelInfo.IsOutputValid(1));
std::vector<std::optional<uint32_t>> kernelOutputIndices = {0};
if (function == DML_OPERATOR_MAX_POOLING2 && (hasOutputIndices || hasDilations))
{
kernelOutputIndices.emplace_back(1);
}
DmlOperator::Initialize(kernelInfo, std::nullopt, kernelOutputIndices);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
ML_CHECK_VALID_ARGUMENT(inputDescs.size() >= 1, "MaxPool input count must be >=1.");
@ -33,13 +46,6 @@ public:
int storageOrder = kernelInfo.GetOptionalAttribute<int>(AttrName::StorageOrder, 0);
ORT_THROW_HR_IF(E_NOTIMPL, storageOrder != 0);
const bool hasDilations =
std::any_of(
m_kernel.dilations,
m_kernel.dilations + m_kernel.spatialDimensionCount,
[](auto d) {return d != 1; }
);
// DML requires that DimensionCount be equal to Input.DimCount - 2 for Pooling
uint32_t expectedSpatialDimCount = m_inputTensorDescs[0].GetDimensionCount() - 2;
if (m_kernel.spatialDimensionCount < expectedSpatialDimCount)
@ -104,7 +110,6 @@ public:
case DML_OPERATOR_MAX_POOLING1:
case DML_OPERATOR_MAX_POOLING2:
{
bool hasOutputIndices = (outputDescs.size() > 1 && outputDescs[1].Desc != nullptr);
if (hasOutputIndices || hasDilations)
{
DML_MAX_POOLING2_OPERATOR_DESC desc = {};

View file

@ -1403,19 +1403,26 @@ common::Status InferenceSession::Initialize() {
#ifdef USE_DML
if (execution_providers_.Get(kDmlExecutionProvider)) {
std::unique_ptr<onnxruntime::GraphTransformer> dmlGraphFusionTransformer = std::make_unique<Dml::DmlGraphFusionTransformer>("DmlGraphFusionTransformer",
execution_providers_.Get(kDmlExecutionProvider));
if (dmlGraphFusionTransformer == nullptr) {
return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr");
bool dml_graph_fusion_enabled = session_options_.optimized_model_filepath.empty() &&
session_options_.graph_optimization_level >= TransformerLevel::Level3;
if (dml_graph_fusion_enabled) {
std::unique_ptr<onnxruntime::GraphTransformer> dmlGraphFusionTransformer = std::make_unique<Dml::DmlGraphFusionTransformer>("DmlGraphFusionTransformer",
execution_providers_.Get(kDmlExecutionProvider));
if (dmlGraphFusionTransformer == nullptr) {
return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr");
}
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformation_mgr_.Register(std::move(dmlGraphFusionTransformer), onnxruntime::TransformerLevel::Level3));
}
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformation_mgr_.Register(std::move(dmlGraphFusionTransformer), onnxruntime::TransformerLevel::Level3));
// This transformer applies DML-specific fusions that go beyond what ORT offers by default
std::unique_ptr<onnxruntime::GraphTransformer> dmlOperatorFusionTransformer = std::make_unique<Dml::GraphTransformer>("DmlOperatorFusionTransformer");
if (dmlOperatorFusionTransformer == nullptr) {
return Status(common::ONNXRUNTIME, common::FAIL, "DmlOperatorFusionTransformer is nullptr");
bool dml_operator_fusion_enabled = session_options_.graph_optimization_level >= TransformerLevel::Level2;
if (dml_operator_fusion_enabled) {
std::unique_ptr<onnxruntime::GraphTransformer> dmlOperatorFusionTransformer = std::make_unique<Dml::GraphTransformer>("DmlOperatorFusionTransformer");
if (dmlOperatorFusionTransformer == nullptr) {
return Status(common::ONNXRUNTIME, common::FAIL, "DmlOperatorFusionTransformer is nullptr");
}
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformation_mgr_.Register(std::move(dmlOperatorFusionTransformer), onnxruntime::TransformerLevel::Level2));
}
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformation_mgr_.Register(std::move(dmlOperatorFusionTransformer), onnxruntime::TransformerLevel::Level2));
}
#endif