DML EP Fix InstanceNormalization with 3D tensors (#12693)

Fix InstanceNormalization with 3D tensors
This commit is contained in:
Dwayne Robinson 2022-08-24 14:58:38 -07:00 committed by GitHub
parent 94f76b944e
commit 3f47119f33
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -15,18 +15,21 @@ class DmlOperatorInstanceNormalization : public DmlOperator
IN_BIAS
};
// Rebuilds the tensor descriptions of `count` consecutive inputs, starting at
// input `index`, so that each one's last dimension size lands on the C axis.
//
// NOTE(review): this span is a rendered diff without +/- markers. The one-line
// signature taking `destinationAxis` and the multi-line signature taking
// `minimumDimensionCount` both appear, as do two CreateTensorDescFromInput
// calls; only the call that passes `minimumDimensionCount` matches the
// multi-line signature.
//
// kernelCreationContext - forwarded to CreateTensorDescFromInput.
// index                 - first input index to rewrite.
// count                 - number of consecutive inputs to rewrite.
// minimumDimensionCount - forwarded unchanged to CreateTensorDescFromInput;
//                         presumably the rank the rebuilt desc is padded to
//                         (TODO confirm against that helper).
void Shift1DInputsTensorDesc(const MLOperatorKernelCreationContext& kernelCreationContext, int index, int count, uint32_t destinationAxis)
void Shift1DInputsTensorDesc(
const MLOperatorKernelCreationContext& kernelCreationContext,
int index,
int count,
uint32_t minimumDimensionCount
)
{
for (int i = index; i != index + count; ++i)
{
// Shift a single dimension size to the C channel.
// e.g. [7] or [1,1,1,7] becomes [1,7,1,1]
// e.g. Given a 4D input (X), then a 1D scale of [7] becomes [1,7,1,1].
TensorDesc& tensorDesc = m_inputTensorDescs[i];
gsl::span<const uint32_t> sizes = tensorDesc.GetSizes();
// Only the trailing dimension size is kept; it is what gets passed to
// CreateTensorDescFromInput below.
// NOTE(review): last(1) presumably requires at least one dimension — confirm.
gsl::span<const uint32_t> lastDimension = sizes.last(1);
// Argument checks: the desc must have exactly NchwDimensionCount dimensions,
// and the N, C and H sizes must all be 1.
// NOTE(review): these checks look incompatible with the 3D inputs named in the
// commit title, and are presumably the removed side of the diff — confirm.
ML_CHECK_VALID_ARGUMENT(tensorDesc.GetDimensionCount() == OperatorHelper::NchwDimensionCount);
ML_CHECK_VALID_ARGUMENT(sizes.size() >=4 && sizes[N] == 1 && sizes[C] == 1 && sizes[H] == 1);
// Replace the stored desc with one built from input i, passing TensorAxis::C
// and the single trailing dimension size.
m_inputTensorDescs[i] = CreateTensorDescFromInput(kernelCreationContext, i, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned, lastDimension);
m_inputTensorDescs[i] = CreateTensorDescFromInput(kernelCreationContext, i, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned, lastDimension, minimumDimensionCount);
}
}
@ -35,29 +38,46 @@ public:
: DmlOperator(kernelCreationContext)
{
std::vector<std::optional<uint32_t>> kernelInputIndices = {0, 1, 2};
DmlOperator::Initialize(kernelCreationContext, kernelInputIndices);
DmlOperator::Initialize(
kernelCreationContext,
kernelInputIndices,
std::nullopt,
std::nullopt,
std::nullopt,
/*minimumDimensionCount*/ 1
);
// "Instance" normalization is really spatial normalization,
// where the spatial dimensions are reduced and normalized, while
// batch and channel remain independent. So pass a list of axes
// just beyond the leading batch and channel dimensions (starting
// at axis 2 up to the last spatial dimension).
const uint32_t inputDimensionCount = m_inputTensorDescs.front().GetDimensionCount();
std::vector<uint32_t> axes(inputDimensionCount - 2);
std::iota(axes.begin(), axes.end(), 2u);
const float epsilon = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Epsilon, DefaultEpsilon);
const std::optional<ActivationOperatorDesc> fusedActivation = FusionHelpers::TryGetFusedActivationDesc(kernelCreationContext);
DML_OPERATOR_DESC fusedActivationDmlDesc = fusedActivation ? fusedActivation->GetDmlDesc() : DML_OPERATOR_DESC();
// Shift IN_SCALE and IN_BIAS input tensor descs {1, C, 1, 1} out of 1D tensors.
Shift1DInputsTensorDesc(kernelCreationContext, IN_SCALE, 2, C);
Shift1DInputsTensorDesc(kernelCreationContext, IN_SCALE, 2, inputDimensionCount);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC operatorDesc = {};
DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.ScaleTensor = &inputDescs[1];
operatorDesc.BiasTensor = &inputDescs[2];
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.CrossChannel = false;
operatorDesc.Axes = axes.data();
operatorDesc.AxisCount = static_cast<uint32_t>(axes.size());
operatorDesc.NormalizeVariance = true;
operatorDesc.Epsilon = epsilon;
operatorDesc.FusedActivation = fusedActivation ? &fusedActivationDmlDesc : nullptr;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION, &operatorDesc };
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};