From fe827c3891b2ce38e278fe5cbd9de5ee821eae50 Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Mon, 12 Dec 2022 10:15:51 -0800 Subject: [PATCH] [DML EP] Disable DML Graph Fusion for lower graph optimization level OR setOptimizedFilePath true (#13913) ### Description DML EP won't fuse the ONNX Graph if ORT Graph optimization level is <= 1 or `SessionOption::SetOptimizedFilePath` is passed. This is the successor of https://github.com/microsoft/onnxruntime/pull/11346. ### Motivation and Context - **Why is this change required? What problem does it solve?** Requested by a few users (issues below) and also helps in debugging. - **If it fixes an open issue, please link to the issue here:** - https://github.com/microsoft/onnxruntime/issues/13535 - https://github.com/microsoft/onnxruntime/issues/8440 --- .../src/Operators/DmlOperatorPooling.cpp | 23 ++++++++++------- onnxruntime/core/session/inference_session.cc | 25 ++++++++++++------- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp index 18fd53199c..4f8b5a1bc7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp @@ -20,8 +20,21 @@ public: PoolingHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), useGlobalPooling), m_function(function) { - DmlOperator::Initialize(kernelInfo); + const bool hasDilations = + std::any_of( + m_kernel.dilations, + m_kernel.dilations + m_kernel.spatialDimensionCount, + [](auto d) {return d != 1; } + ); + const bool hasOutputIndices = (kernelInfo.GetOutputCount() > 1 && kernelInfo.IsOutputValid(1)); + std::vector> kernelOutputIndices = {0}; + if (function == DML_OPERATOR_MAX_POOLING2 && (hasOutputIndices || hasDilations)) + { + 
kernelOutputIndices.emplace_back(1); + } + DmlOperator::Initialize(kernelInfo, std::nullopt, kernelOutputIndices); + std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); ML_CHECK_VALID_ARGUMENT(inputDescs.size() >= 1, "MaxPool input count must be >=1."); @@ -33,13 +46,6 @@ public: int storageOrder = kernelInfo.GetOptionalAttribute(AttrName::StorageOrder, 0); ORT_THROW_HR_IF(E_NOTIMPL, storageOrder != 0); - const bool hasDilations = - std::any_of( - m_kernel.dilations, - m_kernel.dilations + m_kernel.spatialDimensionCount, - [](auto d) {return d != 1; } - ); - // DML requires that DimensionCount be equal to Input.DimCount - 2 for Pooling uint32_t expectedSpatialDimCount = m_inputTensorDescs[0].GetDimensionCount() - 2; if (m_kernel.spatialDimensionCount < expectedSpatialDimCount) @@ -104,7 +110,6 @@ public: case DML_OPERATOR_MAX_POOLING1: case DML_OPERATOR_MAX_POOLING2: { - bool hasOutputIndices = (outputDescs.size() > 1 && outputDescs[1].Desc != nullptr); if (hasOutputIndices || hasDilations) { DML_MAX_POOLING2_OPERATOR_DESC desc = {}; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 81d668bdce..a8baf340f4 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1403,19 +1403,26 @@ common::Status InferenceSession::Initialize() { #ifdef USE_DML if (execution_providers_.Get(kDmlExecutionProvider)) { - std::unique_ptr dmlGraphFusionTransformer = std::make_unique("DmlGraphFusionTransformer", - execution_providers_.Get(kDmlExecutionProvider)); - if (dmlGraphFusionTransformer == nullptr) { - return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr"); + bool dml_graph_fusion_enabled = session_options_.optimized_model_filepath.empty() && + session_options_.graph_optimization_level >= TransformerLevel::Level3; + if (dml_graph_fusion_enabled) { + std::unique_ptr dmlGraphFusionTransformer 
= std::make_unique("DmlGraphFusionTransformer", + execution_providers_.Get(kDmlExecutionProvider)); + if (dmlGraphFusionTransformer == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr"); + } + ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformation_mgr_.Register(std::move(dmlGraphFusionTransformer), onnxruntime::TransformerLevel::Level3)); } - ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformation_mgr_.Register(std::move(dmlGraphFusionTransformer), onnxruntime::TransformerLevel::Level3)); // This transformer applies DML-specific fusions that go beyond what ORT offers by default - std::unique_ptr dmlOperatorFusionTransformer = std::make_unique("DmlOperatorFusionTransformer"); - if (dmlOperatorFusionTransformer == nullptr) { - return Status(common::ONNXRUNTIME, common::FAIL, "DmlOperatorFusionTransformer is nullptr"); + bool dml_operator_fusion_enabled = session_options_.graph_optimization_level >= TransformerLevel::Level2; + if (dml_operator_fusion_enabled) { + std::unique_ptr dmlOperatorFusionTransformer = std::make_unique("DmlOperatorFusionTransformer"); + if (dmlOperatorFusionTransformer == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "DmlOperatorFusionTransformer is nullptr"); + } + ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformation_mgr_.Register(std::move(dmlOperatorFusionTransformer), onnxruntime::TransformerLevel::Level2)); } - ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformation_mgr_.Register(std::move(dmlOperatorFusionTransformer), onnxruntime::TransformerLevel::Level2)); } #endif