From 3f47119f3330a77b8d4e071dd3cf5f150119e095 Mon Sep 17 00:00:00 2001 From: Dwayne Robinson Date: Wed, 24 Aug 2022 14:58:38 -0700 Subject: [PATCH] DML EP Fix InstanceNormalization with 3D tensors (#12693) Fix InstanceNormalization with 3D tensors --- .../DmlOperatorInstanceNormalization.cpp | 40 ++++++++++++++----- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorInstanceNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorInstanceNormalization.cpp index be48659b16..c7f3672914 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorInstanceNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorInstanceNormalization.cpp @@ -15,18 +15,21 @@ class DmlOperatorInstanceNormalization : public DmlOperator IN_BIAS }; - void Shift1DInputsTensorDesc(const MLOperatorKernelCreationContext& kernelCreationContext, int index, int count, uint32_t destinationAxis) + void Shift1DInputsTensorDesc( + const MLOperatorKernelCreationContext& kernelCreationContext, + int index, + int count, + uint32_t minimumDimensionCount + ) { for (int i = index; i != index + count; ++i) { // Shift a single dimension size to the C channel. - // e.g. [7] or [1,1,1,7] becomes [1,7,1,1] + // e.g. Given a 4D input (X), then a 1D scale of [7] becomes [1,7,1,1]. TensorDesc& tensorDesc = m_inputTensorDescs[i]; gsl::span sizes = tensorDesc.GetSizes(); gsl::span lastDimension = sizes.last(1); - ML_CHECK_VALID_ARGUMENT(tensorDesc.GetDimensionCount() == OperatorHelper::NchwDimensionCount); - ML_CHECK_VALID_ARGUMENT(sizes.size() >=4 && sizes[N] == 1 && sizes[C] == 1 && sizes[H] == 1); - m_inputTensorDescs[i] = CreateTensorDescFromInput(kernelCreationContext, i, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned, lastDimension); + m_inputTensorDescs[i] = CreateTensorDescFromInput(kernelCreationContext, i, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned, lastDimension, minimumDimensionCount); } } @@ -35,29 +38,46 @@ public: : DmlOperator(kernelCreationContext) { std::vector> kernelInputIndices = {0, 1, 2}; - DmlOperator::Initialize(kernelCreationContext, kernelInputIndices); + DmlOperator::Initialize( + kernelCreationContext, + kernelInputIndices, + std::nullopt, + std::nullopt, + std::nullopt, + /*minimumDimensionCount*/ 1 + ); + + // "Instance" normalization is really spatial normalization, + // where the spatial channels are reduced and normalized, while + // batch and channel remain independent. So pass a list of axes + // just beyond the leading batch and channel dimensions (starting + // at axis 2 up to the last spatial dimension). + const uint32_t inputDimensionCount = m_inputTensorDescs.front().GetDimensionCount(); + std::vector axes(inputDimensionCount - 2); + std::iota(axes.begin(), axes.end(), 2u); const float epsilon = kernelCreationContext.GetOptionalAttribute(AttrName::Epsilon, DefaultEpsilon); const std::optional fusedActivation = FusionHelpers::TryGetFusedActivationDesc(kernelCreationContext); DML_OPERATOR_DESC fusedActivationDmlDesc = fusedActivation ? fusedActivation->GetDmlDesc() : DML_OPERATOR_DESC(); // Shift IN_SCALE and IN_BIAS input tensor descs {1, C, 1, 1} out of 1D tensors. - Shift1DInputsTensorDesc(kernelCreationContext, IN_SCALE, 2, C); + Shift1DInputsTensorDesc(kernelCreationContext, IN_SCALE, 2, inputDimensionCount); std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); - DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC operatorDesc = {}; + DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC operatorDesc = {}; operatorDesc.InputTensor = &inputDescs[0]; operatorDesc.ScaleTensor = &inputDescs[1]; operatorDesc.BiasTensor = &inputDescs[2]; operatorDesc.OutputTensor = outputDescs.data(); - operatorDesc.CrossChannel = false; + operatorDesc.Axes = axes.data(); + operatorDesc.AxisCount = static_cast(axes.size()); operatorDesc.NormalizeVariance = true; operatorDesc.Epsilon = epsilon; operatorDesc.FusedActivation = fusedActivation ? &fusedActivationDmlDesc : nullptr; - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION, &operatorDesc }; + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &operatorDesc }; SetDmlOperatorDesc(opDesc, kernelCreationContext); } };