mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
[DML EP] Force instance norm inputs to be 4D to better target metacommands (#14020)
### Description Force instance norm inputs to be 4D to better target metacommands ### Motivation and Context This may improve performance on some hardware by allowing the driver to return valid layouts to DML when querying for metacommand support.
This commit is contained in:
parent
589612106a
commit
e7f9d40dde
2 changed files with 17 additions and 5 deletions
|
|
@ -44,8 +44,21 @@ public:
|
|||
std::nullopt,
|
||||
std::nullopt,
|
||||
std::nullopt,
|
||||
/*minimumDimensionCount*/ 1
|
||||
);
|
||||
/*minimumDimensionCount*/ 1);
|
||||
|
||||
const uint32_t dmlDimensionCount = std::max<uint32_t>(4u, m_inputTensorDescs[0].GetDimensionCount());
|
||||
|
||||
// Shift IN_SCALE and IN_BIAS input tensor descs {1, C, 1, 1} out of 1D tensors.
|
||||
Shift1DInputsTensorDesc(kernelCreationContext, IN_SCALE, 2, dmlDimensionCount);
|
||||
|
||||
// Pad the input and the output with trailing 1's until they are at least 4D
|
||||
auto sizes = m_inputTensorDescs[0].GetSizes();
|
||||
std::vector<uint32_t> tensorShape(sizes.begin(), sizes.end());
|
||||
tensorShape.resize(static_cast<size_t>(dmlDimensionCount), 1);
|
||||
m_inputTensorDescs[0] = TensorDesc(
|
||||
m_inputTensorDescs[0].GetDmlDataType(),
|
||||
tensorShape);
|
||||
m_outputTensorDescs[0] = m_inputTensorDescs[0];
|
||||
|
||||
// "Instance" normalization is really spatial normalization,
|
||||
// where the spatial channels are reduced and normalized, while
|
||||
|
|
@ -60,9 +73,6 @@ public:
|
|||
const std::optional<ActivationOperatorDesc> fusedActivation = FusionHelpers::TryGetFusedActivationDesc(kernelCreationContext);
|
||||
DML_OPERATOR_DESC fusedActivationDmlDesc = fusedActivation ? fusedActivation->GetDmlDesc() : DML_OPERATOR_DESC();
|
||||
|
||||
// Shift IN_SCALE and IN_BIAS input tensor descs {1, C, 1, 1} out of 1D tensors.
|
||||
Shift1DInputsTensorDesc(kernelCreationContext, IN_SCALE, 2, inputDimensionCount);
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
|
|
|
|||
|
|
@ -184,8 +184,10 @@ namespace Dml
|
|||
OperatorInfo{ "Relu", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Relu },
|
||||
OperatorInfo{ "Relu", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_Relu },
|
||||
OperatorInfo{ "LeakyRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_LeakyRelu },
|
||||
OperatorInfo{ "LeakyRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet16::sc_sinceVer_LeakyRelu },
|
||||
OperatorInfo{ "PRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_PRelu },
|
||||
OperatorInfo{ "PRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_PRelu },
|
||||
OperatorInfo{ "PRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet16::sc_sinceVer_PRelu },
|
||||
OperatorInfo{ "ThresholdedRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_ThresholdedRelu },
|
||||
OperatorInfo{ "ThresholdedRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet10::sc_sinceVer_ThresholdedRelu },
|
||||
OperatorInfo{ "Elu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Elu },
|
||||
|
|
|
|||
Loading…
Reference in a new issue