mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
[DML EP] Disable DML Graph Fusion for lower graph optimization level OR setOptimizedFilePath true (#13913)
### Description DML EP won't fuse the ONNX Graph if ORT Graph optimization level is <= 1 or `SessionOption::SetOptimizedFilePath` is passed. This is the successor of https://github.com/microsoft/onnxruntime/pull/11346. ### Motivation and Context - **Why is this change required? What problem does it solve?** Requested by few a users (issues below) and also helps in debugging. - **If it fixes an open issue, please link to the issue here:** - https://github.com/microsoft/onnxruntime/issues/13535 - https://github.com/microsoft/onnxruntime/issues/8440
This commit is contained in:
parent
8cfbc4fe91
commit
fe827c3891
2 changed files with 30 additions and 18 deletions
|
|
@ -20,8 +20,21 @@ public:
|
|||
PoolingHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), useGlobalPooling),
|
||||
m_function(function)
|
||||
{
|
||||
DmlOperator::Initialize(kernelInfo);
|
||||
const bool hasDilations =
|
||||
std::any_of(
|
||||
m_kernel.dilations,
|
||||
m_kernel.dilations + m_kernel.spatialDimensionCount,
|
||||
[](auto d) {return d != 1; }
|
||||
);
|
||||
const bool hasOutputIndices = (kernelInfo.GetOutputCount() > 1 && kernelInfo.IsOutputValid(1));
|
||||
std::vector<std::optional<uint32_t>> kernelOutputIndices = {0};
|
||||
|
||||
if (function == DML_OPERATOR_MAX_POOLING2 && (hasOutputIndices || hasDilations))
|
||||
{
|
||||
kernelOutputIndices.emplace_back(1);
|
||||
}
|
||||
DmlOperator::Initialize(kernelInfo, std::nullopt, kernelOutputIndices);
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
ML_CHECK_VALID_ARGUMENT(inputDescs.size() >= 1, "MaxPool input count must be >=1.");
|
||||
|
|
@ -33,13 +46,6 @@ public:
|
|||
int storageOrder = kernelInfo.GetOptionalAttribute<int>(AttrName::StorageOrder, 0);
|
||||
ORT_THROW_HR_IF(E_NOTIMPL, storageOrder != 0);
|
||||
|
||||
const bool hasDilations =
|
||||
std::any_of(
|
||||
m_kernel.dilations,
|
||||
m_kernel.dilations + m_kernel.spatialDimensionCount,
|
||||
[](auto d) {return d != 1; }
|
||||
);
|
||||
|
||||
// DML requires that DimensionCount be equal to Input.DimCount - 2 for Pooling
|
||||
uint32_t expectedSpatialDimCount = m_inputTensorDescs[0].GetDimensionCount() - 2;
|
||||
if (m_kernel.spatialDimensionCount < expectedSpatialDimCount)
|
||||
|
|
@ -104,7 +110,6 @@ public:
|
|||
case DML_OPERATOR_MAX_POOLING1:
|
||||
case DML_OPERATOR_MAX_POOLING2:
|
||||
{
|
||||
bool hasOutputIndices = (outputDescs.size() > 1 && outputDescs[1].Desc != nullptr);
|
||||
if (hasOutputIndices || hasDilations)
|
||||
{
|
||||
DML_MAX_POOLING2_OPERATOR_DESC desc = {};
|
||||
|
|
|
|||
|
|
@ -1403,19 +1403,26 @@ common::Status InferenceSession::Initialize() {
|
|||
|
||||
#ifdef USE_DML
|
||||
if (execution_providers_.Get(kDmlExecutionProvider)) {
|
||||
std::unique_ptr<onnxruntime::GraphTransformer> dmlGraphFusionTransformer = std::make_unique<Dml::DmlGraphFusionTransformer>("DmlGraphFusionTransformer",
|
||||
execution_providers_.Get(kDmlExecutionProvider));
|
||||
if (dmlGraphFusionTransformer == nullptr) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr");
|
||||
bool dml_graph_fusion_enabled = session_options_.optimized_model_filepath.empty() &&
|
||||
session_options_.graph_optimization_level >= TransformerLevel::Level3;
|
||||
if (dml_graph_fusion_enabled) {
|
||||
std::unique_ptr<onnxruntime::GraphTransformer> dmlGraphFusionTransformer = std::make_unique<Dml::DmlGraphFusionTransformer>("DmlGraphFusionTransformer",
|
||||
execution_providers_.Get(kDmlExecutionProvider));
|
||||
if (dmlGraphFusionTransformer == nullptr) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr");
|
||||
}
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformation_mgr_.Register(std::move(dmlGraphFusionTransformer), onnxruntime::TransformerLevel::Level3));
|
||||
}
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformation_mgr_.Register(std::move(dmlGraphFusionTransformer), onnxruntime::TransformerLevel::Level3));
|
||||
|
||||
// This transformer applies DML-specific fusions that go beyond what ORT offers by default
|
||||
std::unique_ptr<onnxruntime::GraphTransformer> dmlOperatorFusionTransformer = std::make_unique<Dml::GraphTransformer>("DmlOperatorFusionTransformer");
|
||||
if (dmlOperatorFusionTransformer == nullptr) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "DmlOperatorFusionTransformer is nullptr");
|
||||
bool dml_operator_fusion_enabled = session_options_.graph_optimization_level >= TransformerLevel::Level2;
|
||||
if (dml_operator_fusion_enabled) {
|
||||
std::unique_ptr<onnxruntime::GraphTransformer> dmlOperatorFusionTransformer = std::make_unique<Dml::GraphTransformer>("DmlOperatorFusionTransformer");
|
||||
if (dmlOperatorFusionTransformer == nullptr) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "DmlOperatorFusionTransformer is nullptr");
|
||||
}
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformation_mgr_.Register(std::move(dmlOperatorFusionTransformer), onnxruntime::TransformerLevel::Level2));
|
||||
}
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformation_mgr_.Register(std::move(dmlOperatorFusionTransformer), onnxruntime::TransformerLevel::Level2));
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue