mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-27 03:11:28 +00:00
DML EP Fix InstanceNormalization with 3D tensors (#12693)
Fix InstanceNormalization with 3D tensors
This commit is contained in:
parent
94f76b944e
commit
3f47119f33
1 changed files with 30 additions and 10 deletions
|
|
@ -15,18 +15,21 @@ class DmlOperatorInstanceNormalization : public DmlOperator
|
|||
IN_BIAS
|
||||
};
|
||||
|
||||
void Shift1DInputsTensorDesc(const MLOperatorKernelCreationContext& kernelCreationContext, int index, int count, uint32_t destinationAxis)
|
||||
void Shift1DInputsTensorDesc(
|
||||
const MLOperatorKernelCreationContext& kernelCreationContext,
|
||||
int index,
|
||||
int count,
|
||||
uint32_t minimumDimensionCount
|
||||
)
|
||||
{
|
||||
for (int i = index; i != index + count; ++i)
|
||||
{
|
||||
// Shift a single dimension size to the C channel.
|
||||
// e.g. [7] or [1,1,1,7] becomes [1,7,1,1]
|
||||
// e.g. Given a 4D input (X), then a 1D scale of [7] becomes [1,7,1,1].
|
||||
TensorDesc& tensorDesc = m_inputTensorDescs[i];
|
||||
gsl::span<const uint32_t> sizes = tensorDesc.GetSizes();
|
||||
gsl::span<const uint32_t> lastDimension = sizes.last(1);
|
||||
ML_CHECK_VALID_ARGUMENT(tensorDesc.GetDimensionCount() == OperatorHelper::NchwDimensionCount);
|
||||
ML_CHECK_VALID_ARGUMENT(sizes.size() >=4 && sizes[N] == 1 && sizes[C] == 1 && sizes[H] == 1);
|
||||
m_inputTensorDescs[i] = CreateTensorDescFromInput(kernelCreationContext, i, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned, lastDimension);
|
||||
m_inputTensorDescs[i] = CreateTensorDescFromInput(kernelCreationContext, i, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned, lastDimension, minimumDimensionCount);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -35,29 +38,46 @@ public:
|
|||
: DmlOperator(kernelCreationContext)
|
||||
{
|
||||
std::vector<std::optional<uint32_t>> kernelInputIndices = {0, 1, 2};
|
||||
DmlOperator::Initialize(kernelCreationContext, kernelInputIndices);
|
||||
DmlOperator::Initialize(
|
||||
kernelCreationContext,
|
||||
kernelInputIndices,
|
||||
std::nullopt,
|
||||
std::nullopt,
|
||||
std::nullopt,
|
||||
/*minimumDimensionCount*/ 1
|
||||
);
|
||||
|
||||
// "Instance" normalization is really spatial normalization,
|
||||
// where the spatial channels are reduced and normalized, while
|
||||
// batch and channel remain independent. So pass a list of axes
|
||||
// just beyond the leading batch and channel dimensions (starting
|
||||
// at axis 2 up to the last spatial dimension).
|
||||
const uint32_t inputDimensionCount = m_inputTensorDescs.front().GetDimensionCount();
|
||||
std::vector<uint32_t> axes(inputDimensionCount - 2);
|
||||
std::iota(axes.begin(), axes.end(), 2u);
|
||||
|
||||
const float epsilon = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Epsilon, DefaultEpsilon);
|
||||
const std::optional<ActivationOperatorDesc> fusedActivation = FusionHelpers::TryGetFusedActivationDesc(kernelCreationContext);
|
||||
DML_OPERATOR_DESC fusedActivationDmlDesc = fusedActivation ? fusedActivation->GetDmlDesc() : DML_OPERATOR_DESC();
|
||||
|
||||
// Shift IN_SCALE and IN_BIAS input tensor descs {1, C, 1, 1} out of 1D tensors.
|
||||
Shift1DInputsTensorDesc(kernelCreationContext, IN_SCALE, 2, C);
|
||||
Shift1DInputsTensorDesc(kernelCreationContext, IN_SCALE, 2, inputDimensionCount);
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC operatorDesc = {};
|
||||
DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.InputTensor = &inputDescs[0];
|
||||
operatorDesc.ScaleTensor = &inputDescs[1];
|
||||
operatorDesc.BiasTensor = &inputDescs[2];
|
||||
operatorDesc.OutputTensor = outputDescs.data();
|
||||
operatorDesc.CrossChannel = false;
|
||||
operatorDesc.Axes = axes.data();
|
||||
operatorDesc.AxisCount = static_cast<uint32_t>(axes.size());
|
||||
operatorDesc.NormalizeVariance = true;
|
||||
operatorDesc.Epsilon = epsilon;
|
||||
operatorDesc.FusedActivation = fusedActivation ? &fusedActivationDmlDesc : nullptr;
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION, &operatorDesc };
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &operatorDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelCreationContext);
|
||||
}
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue