[DML EP] Force instance norm inputs to be 4D to better target metacommands (#14020)

### Description
Force instance norm inputs to be 4D to better target metacommands: the input and output tensor descs are padded with trailing 1's until they are at least 4D.

### Motivation and Context
This may improve performance on some hardware by allowing the driver to
return valid layouts to DML when querying for metacommand support.
This commit is contained in:
Patrice Vignola 2023-01-03 12:47:10 -08:00 committed by GitHub
parent 589612106a
commit e7f9d40dde
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 5 deletions

View file

@ -44,8 +44,21 @@ public:
std::nullopt,
std::nullopt,
std::nullopt,
/*minimumDimensionCount*/ 1
);
/*minimumDimensionCount*/ 1);
const uint32_t dmlDimensionCount = std::max<uint32_t>(4u, m_inputTensorDescs[0].GetDimensionCount());
// Shift IN_SCALE and IN_BIAS input tensor descs {1, C, 1, 1} out of 1D tensors.
Shift1DInputsTensorDesc(kernelCreationContext, IN_SCALE, 2, dmlDimensionCount);
// Pad the input and the output with trailing 1's until they are at least 4D
auto sizes = m_inputTensorDescs[0].GetSizes();
std::vector<uint32_t> tensorShape(sizes.begin(), sizes.end());
tensorShape.resize(static_cast<size_t>(dmlDimensionCount), 1);
m_inputTensorDescs[0] = TensorDesc(
m_inputTensorDescs[0].GetDmlDataType(),
tensorShape);
m_outputTensorDescs[0] = m_inputTensorDescs[0];
// "Instance" normalization is really spatial normalization,
// where the spatial channels are reduced and normalized, while
@ -60,9 +73,6 @@ public:
const std::optional<ActivationOperatorDesc> fusedActivation = FusionHelpers::TryGetFusedActivationDesc(kernelCreationContext);
DML_OPERATOR_DESC fusedActivationDmlDesc = fusedActivation ? fusedActivation->GetDmlDesc() : DML_OPERATOR_DESC();
// Shift IN_SCALE and IN_BIAS input tensor descs {1, C, 1, 1} out of 1D tensors.
Shift1DInputsTensorDesc(kernelCreationContext, IN_SCALE, 2, inputDimensionCount);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

View file

@ -184,8 +184,10 @@ namespace Dml
OperatorInfo{ "Relu", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Relu },
OperatorInfo{ "Relu", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_Relu },
OperatorInfo{ "LeakyRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_LeakyRelu },
OperatorInfo{ "LeakyRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet16::sc_sinceVer_LeakyRelu },
OperatorInfo{ "PRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_PRelu },
OperatorInfo{ "PRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_PRelu },
OperatorInfo{ "PRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet16::sc_sinceVer_PRelu },
OperatorInfo{ "ThresholdedRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_ThresholdedRelu },
OperatorInfo{ "ThresholdedRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet10::sc_sinceVer_ThresholdedRelu },
OperatorInfo{ "Elu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Elu },