From e7f9d40dde37960e0a4bfa9e98e29d6a37d8435f Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Tue, 3 Jan 2023 12:47:10 -0800 Subject: [PATCH] [DML EP] Force instance norm inputs to be 4D to better target metacommands (#14020) ### Description Force instance norm inputs to be 4D to better target metacommands ### Motivation and Context This may improve performance on some hardware by allowing the driver to return valid layouts to DML when querying for metacommand support. --- .../DmlOperatorInstanceNormalization.cpp | 20 ++++++++++++++----- .../src/Operators/OperatorUtility.cpp | 2 ++ 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorInstanceNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorInstanceNormalization.cpp index a90d5b9e38..d59db28816 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorInstanceNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorInstanceNormalization.cpp @@ -44,8 +44,21 @@ public: std::nullopt, std::nullopt, std::nullopt, - /*minimumDimensionCount*/ 1 - ); + /*minimumDimensionCount*/ 1); + + const uint32_t dmlDimensionCount = std::max(4u, m_inputTensorDescs[0].GetDimensionCount()); + + // Shift IN_SCALE and IN_BIAS input tensor descs {1, C, 1, 1} out of 1D tensors. + Shift1DInputsTensorDesc(kernelCreationContext, IN_SCALE, 2, dmlDimensionCount); + + // Pad the input and the output with trailing 1's until they are at least 4D + auto sizes = m_inputTensorDescs[0].GetSizes(); + std::vector tensorShape(sizes.begin(), sizes.end()); + tensorShape.resize(static_cast(dmlDimensionCount), 1); + m_inputTensorDescs[0] = TensorDesc( + m_inputTensorDescs[0].GetDmlDataType(), + tensorShape); + m_outputTensorDescs[0] = m_inputTensorDescs[0]; // "Instance" normalization is really spatial normalization, // where the spatial channels are reduced and normalized, while @@ -60,9 +73,6 @@ public: const std::optional fusedActivation = FusionHelpers::TryGetFusedActivationDesc(kernelCreationContext); DML_OPERATOR_DESC fusedActivationDmlDesc = fusedActivation ? fusedActivation->GetDmlDesc() : DML_OPERATOR_DESC(); - // Shift IN_SCALE and IN_BIAS input tensor descs {1, C, 1, 1} out of 1D tensors. - Shift1DInputsTensorDesc(kernelCreationContext, IN_SCALE, 2, inputDimensionCount); - std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp index 1d8596b643..e805cb1e7a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp @@ -184,8 +184,10 @@ namespace Dml OperatorInfo{ "Relu", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Relu }, OperatorInfo{ "Relu", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_Relu }, OperatorInfo{ "LeakyRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_LeakyRelu }, + OperatorInfo{ "LeakyRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet16::sc_sinceVer_LeakyRelu }, OperatorInfo{ "PRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_PRelu }, OperatorInfo{ "PRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_PRelu }, + OperatorInfo{ "PRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet16::sc_sinceVer_PRelu }, OperatorInfo{ "ThresholdedRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_ThresholdedRelu }, OperatorInfo{ "ThresholdedRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet10::sc_sinceVer_ThresholdedRelu }, OperatorInfo{ "Elu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Elu },