diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 99c265729d..00289c145c 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -1181,6 +1181,7 @@ Do not modify directly.*
|FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)|
+|NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)|
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConvInteger.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConvInteger.cpp
index 4a689b3065..6a2f867ed8 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConvInteger.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConvInteger.cpp
@@ -9,14 +9,14 @@ namespace Dml
class DmlOperatorConvInteger : public DmlOperator, public ConvolutionHelperBase
{
private:
- enum InputTensors
- {
- IN_X,
- IN_X_ZERO_POINT,
- IN_F,
- IN_F_ZERO_POINT,
+ enum InputTensors
+ {
+ IN_X,
+ IN_X_ZERO_POINT,
+ IN_F,
+ IN_F_ZERO_POINT,
};
-
+
public:
using Self = DmlOperatorConvInteger;
@@ -24,15 +24,15 @@ public:
const MLOperatorKernelCreationContext& kernelInfo
)
: DmlOperator(kernelInfo),
- ConvolutionHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), false, false, 0, 1)
+ ConvolutionHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), false, false, false, 0, 1)
{
std::vector> kernelInputIndices = {0, 2, 1, 3};
std::vector> kernelOutputIndices = {0};
DmlOperator::Initialize(kernelInfo, kernelInputIndices);
- // DirectML is limited to handle only 2D. So for 1D tensors, massage the tensor descriptions. By default, the
- // TensorDesc simply right aligns all the values up to 4D (padding the leading dimensions with 1's),
+ // DirectML is limited to handle only 2D. So for 1D tensors, massage the tensor descriptions. By default, the
+ // TensorDesc simply right aligns all the values up to 4D (padding the leading dimensions with 1's),
// but 1D tensors actually need to insert the 1 between C and W. e.g. [2,3,4] becomes [2,3,1,4]
m_inputTensorDescs[IN_X] = CreateTensorDescFromInput(kernelInfo, 0/*Onnx Index*/, TensorAxis::DoNotCoerce, TensorAxis::NoPlacementAdjustment, NonspatialDimensionCount, std::nullopt);
m_inputTensorDescs[IN_F] = CreateTensorDescFromInput(kernelInfo, 1/*Onnx Index*/, TensorAxis::DoNotCoerce, TensorAxis::NoPlacementAdjustment, NonspatialDimensionCount, std::nullopt);
@@ -42,9 +42,9 @@ public:
// Resize the Filter ZeroPoint to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the C channel.
m_inputTensorDescs[IN_F_ZERO_POINT] = CreateTensorDescFromInput(
- kernelInfo,
- 3/*Onnx Index*/,
- TensorAxis::DoNotCoerce,
+ kernelInfo,
+ 3/*Onnx Index*/,
+ TensorAxis::DoNotCoerce,
TensorAxis::C,
TensorAxis::LeftAligned,
std::nullopt,
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConvolution.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConvolution.cpp
index 548ab65cc0..9c343e9a76 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConvolution.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConvolution.cpp
@@ -15,10 +15,11 @@ public:
const MLOperatorKernelCreationContext& kernelInfo,
DML_CONVOLUTION_MODE mode,
DML_CONVOLUTION_DIRECTION direction,
- bool hasDynamicPads
+ bool hasDynamicPads,
+ bool isNhwc
)
: DmlOperator(kernelInfo),
- ConvolutionHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), direction == DML_CONVOLUTION_DIRECTION_BACKWARD, hasDynamicPads, 0, 1)
+ ConvolutionHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), direction == DML_CONVOLUTION_DIRECTION_BACKWARD, hasDynamicPads, isNhwc, 0, 1)
{
uint32_t biasIndex = hasDynamicPads ? 3 : 2;
bool hasBiasInput = kernelInfo.GetInputCount() > biasIndex;
@@ -33,6 +34,43 @@ public:
// e.g. [2,3,4] becomes [2,3,1,4]
m_inputTensorDescs[0] = CreateTensorDescFromInput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::NoPlacementAdjustment, NonspatialDimensionCount, std::nullopt);
m_inputTensorDescs[1] = CreateTensorDescFromInput(kernelInfo, 1, TensorAxis::DoNotCoerce, TensorAxis::NoPlacementAdjustment, NonspatialDimensionCount, std::nullopt);
+ m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::NoPlacementAdjustment, NonspatialDimensionCount, std::nullopt);
+
+ if (isNhwc)
+ {
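+ // Restrict the NHWC input to 4D like other implementations, then describe it to DML as logical NCHW sizes with strides that walk the NHWC buffer (only the tensor description changes).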
+ ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[0].GetDimensionCount() == 4);
+ const auto inputSizes = m_inputTensorDescs[0].GetSizes();
+ const uint32_t inputBatch = inputSizes[0];
+ const uint32_t inputHeight = inputSizes[1];
+ const uint32_t inputWidth = inputSizes[2];
+ const uint32_t inputChannels = inputSizes[3];
+ const std::array nchwInputSizes = {inputBatch, inputChannels, inputHeight, inputWidth};
+ const std::array nchwInputStrides = {inputHeight * inputWidth * inputChannels, 1, inputWidth * inputChannels, inputChannels};
+ m_inputTensorDescs[0] = TensorDesc(m_inputTensorDescs[0].GetDmlDataType(), nchwInputSizes, nchwInputStrides);
+
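+ // Restrict the filter to 4D like other implementations; it is read as [featureMaps, kernelHeight, kernelWidth, channelsPerGroup] and re-described with NCHW-ordered sizes and matching strides.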
+ ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[1].GetDimensionCount() == 4);
+ const auto weightSizes = m_inputTensorDescs[1].GetSizes();
+ const uint32_t featureMaps = weightSizes[0];
+ const uint32_t kernelHeight = weightSizes[1];
+ const uint32_t kernelWidth = weightSizes[2];
+ const uint32_t channelsPerGroup = weightSizes[3];
+ const std::array nchwKernelSizes = {featureMaps, channelsPerGroup, kernelHeight, kernelWidth};
+ const std::array nchwKernelStrides = {kernelHeight * kernelWidth * channelsPerGroup, 1, kernelWidth * channelsPerGroup, channelsPerGroup};
+ m_inputTensorDescs[1] = TensorDesc(m_inputTensorDescs[1].GetDmlDataType(), nchwKernelSizes, nchwKernelStrides);
+
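+ // Restrict the output to 4D like other implementations; it is read as [batch, height, width, channels] and re-described with NCHW-ordered sizes and strides for the NHWC buffer.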
+ ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs[0].GetDimensionCount() == 4);
+ const auto outputSizes = m_outputTensorDescs[0].GetSizes();
+ const uint32_t outputBatch = outputSizes[0];
+ const uint32_t outputHeight = outputSizes[1];
+ const uint32_t outputWidth = outputSizes[2];
+ const uint32_t outputChannels = outputSizes[3];
+ const std::array nchwOutputSizes = {outputBatch, outputChannels, outputHeight, outputWidth};
+ const std::array nchwOutputStrides = {outputHeight * outputWidth * outputChannels, 1, outputWidth * outputChannels, outputChannels};
+ m_outputTensorDescs[0] = TensorDesc(m_outputTensorDescs[0].GetDmlDataType(), nchwOutputSizes, nchwOutputStrides);
+ }
// Bias is optional so only adjust it if it exists.
if (hasBiasInput)
@@ -47,9 +85,9 @@ public:
// Resize the bias to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the C channel.
m_inputTensorDescs[biasIndex] = CreateTensorDescFromInput(
- kernelInfo,
- biasIndex,
- TensorAxis::DoNotCoerce,
+ kernelInfo,
+ biasIndex,
+ TensorAxis::DoNotCoerce,
TensorAxis::C,
TensorAxis::LeftAligned,
std::nullopt,
@@ -57,8 +95,6 @@ public:
);
}
- m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::NoPlacementAdjustment, NonspatialDimensionCount, std::nullopt);
-
std::optional fusedActivation = FusionHelpers::TryGetFusedActivationDesc(kernelInfo);
DML_OPERATOR_DESC fusedActivationDmlDesc = fusedActivation ? fusedActivation->GetDmlDesc() : DML_OPERATOR_DESC();
std::vector inputDescs = GetDmlInputDescs();
@@ -95,20 +131,21 @@ public:
};
// A specific type of operation for registration.
-template
+template
class DmlOperatorConvolutionTemplate : public DmlOperatorConvolution
{
public:
DmlOperatorConvolutionTemplate(const MLOperatorKernelCreationContext& kernelInfo)
- : DmlOperatorConvolution(kernelInfo, Mode, Direction, hasDynamicPads)
+ : DmlOperatorConvolution(kernelInfo, Mode, Direction, hasDynamicPads, isNhwc)
{
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Conv, DmlOperatorConvolutionTemplate);
+DML_OP_DEFINE_CREATION_FUNCTION(NhwcConv, DmlOperatorConvolutionTemplate);
DML_OP_DEFINE_CREATION_FUNCTION(ConvTranspose, DmlOperatorConvolutionTemplate);
-DML_OP_DEFINE_CREATION_FUNCTION(DmlFusedConv, DmlOperatorConvolutionTemplate);
-DML_OP_DEFINE_CREATION_FUNCTION(DmlFusedConvTranspose, DmlOperatorConvolutionTemplate);
+DML_OP_DEFINE_CREATION_FUNCTION(DmlFusedConv, DmlOperatorConvolutionTemplate);
+DML_OP_DEFINE_CREATION_FUNCTION(DmlFusedConvTranspose, DmlOperatorConvolutionTemplate);
DML_OP_DEFINE_CREATION_FUNCTION(ConvTransposeWithDynamicPads, DmlOperatorConvolutionTemplate);
} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp
index 91f730da7a..d45fdef3c8 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp
@@ -9,19 +9,19 @@ namespace Dml
class DmlOperatorQLinearConv : public DmlOperator, public ConvolutionHelperBase
{
private:
- enum InputTensors
- {
- IN_X,
+ enum InputTensors
+ {
+ IN_X,
IN_X_SCALE,
- IN_X_ZERO_POINT,
- IN_F,
+ IN_X_ZERO_POINT,
+ IN_F,
IN_F_SCALE,
- IN_F_ZERO_POINT,
+ IN_F_ZERO_POINT,
IN_BIAS,
IN_Y_SCALE,
IN_Y_ZERO_POINT
};
-
+
public:
using Self = DmlOperatorQLinearConv;
@@ -29,15 +29,15 @@ public:
const MLOperatorKernelCreationContext& kernelInfo
)
: DmlOperator(kernelInfo),
- ConvolutionHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), false, false, 0, 3)
+ ConvolutionHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), false, false, false, 0, 3)
{
std::vector> kernelInputIndices = {0, 1, 2, 3, 4, 5, 8, 6, 7};
std::vector> kernelOutputIndices = {0};
DmlOperator::Initialize(kernelInfo, kernelInputIndices);
- // DirectML is limited to handle only 2D. So for 1D tensors, massage the tensor descriptions. By default, the
- // TensorDesc simply right aligns all the values up to 4D (padding the leading dimensions with 1's),
+ // DirectML is limited to handle only 2D. So for 1D tensors, massage the tensor descriptions. By default, the
+ // TensorDesc simply right aligns all the values up to 4D (padding the leading dimensions with 1's),
// but 1D tensors actually need to insert the 1 between C and W. e.g. [2,3,4] becomes [2,3,1,4]
m_inputTensorDescs[IN_X] = CreateTensorDescFromInput(kernelInfo, 0/*Onnx Index*/, TensorAxis::DoNotCoerce, TensorAxis::NoPlacementAdjustment, NonspatialDimensionCount, std::nullopt);
m_inputTensorDescs[IN_F] = CreateTensorDescFromInput(kernelInfo, 3/*Onnx Index*/, TensorAxis::DoNotCoerce, TensorAxis::NoPlacementAdjustment, NonspatialDimensionCount, std::nullopt);
@@ -56,9 +56,9 @@ public:
// Resize the bias to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the C channel.
m_inputTensorDescs[IN_BIAS] = CreateTensorDescFromInput(
- kernelInfo,
- 8/*Onnx Index*/,
- TensorAxis::DoNotCoerce,
+ kernelInfo,
+ 8/*Onnx Index*/,
+ TensorAxis::DoNotCoerce,
TensorAxis::C,
TensorAxis::LeftAligned,
std::nullopt,
@@ -69,9 +69,9 @@ public:
// Resize the Filter ZeroPoint to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the C channel.
m_inputTensorDescs[IN_F_ZERO_POINT] = CreateTensorDescFromInput(
- kernelInfo,
- 5/*Onnx Index*/,
- TensorAxis::DoNotCoerce,
+ kernelInfo,
+ 5/*Onnx Index*/,
+ TensorAxis::DoNotCoerce,
TensorAxis::C,
TensorAxis::LeftAligned,
std::nullopt,
@@ -80,9 +80,9 @@ public:
// Resize the Filter Scale to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the C channel.
m_inputTensorDescs[IN_F_SCALE] = CreateTensorDescFromInput(
- kernelInfo,
- 4/*Onnx Index*/,
- TensorAxis::DoNotCoerce,
+ kernelInfo,
+ 4/*Onnx Index*/,
+ TensorAxis::DoNotCoerce,
TensorAxis::C,
TensorAxis::LeftAligned,
std::nullopt,
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index e3067a4a22..f8e9e1a262 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -194,6 +194,7 @@ struct OperatorRegistrationInformation
DML_OP_EXTERN_CREATION_FUNCTION(Copy);
DML_OP_EXTERN_CREATION_FUNCTION(FC);
DML_OP_EXTERN_CREATION_FUNCTION(Conv);
+DML_OP_EXTERN_CREATION_FUNCTION(NhwcConv);
DML_OP_EXTERN_CREATION_FUNCTION(ConvTranspose);
DML_OP_EXTERN_CREATION_FUNCTION(ConvTransposeWithDynamicPads);
DML_OP_EXTERN_CREATION_FUNCTION(AveragePool);
@@ -528,6 +529,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
// Deep Learning Standard Layers
{REG_INFO( 7, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
+ {REG_INFO_MS( 1, NhwcConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
index 00e3e1fcd5..d91f7be447 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
@@ -292,11 +292,12 @@ namespace OperatorHelper
// are ordered such that they are at the end (e.g. NCHW or NCDHW).
std::vector InitializeKernelOutputDimensions(
gsl::span inputDimensions,
- const KernelArgs& args
+ const KernelArgs& args,
+ bool isNhwc
)
{
ML_CHECK_VALID_ARGUMENT(gsl::narrow_cast(inputDimensions.size()) >= args.spatialDimensionCount);
- int dimOffset = gsl::narrow_cast(inputDimensions.size()) - args.spatialDimensionCount;
+ int dimOffset = isNhwc ? 1 : gsl::narrow_cast(inputDimensions.size()) - args.spatialDimensionCount;
std::vector outputDimensions(inputDimensions.begin(), inputDimensions.end());
@@ -478,7 +479,8 @@ namespace OperatorHelper
void ResolveAutoPadding(
KernelArgs& args,
- gsl::span inputDimensions
+ gsl::span inputDimensions,
+ bool isNhwc
)
{
if (!args.autoPad)
@@ -490,7 +492,9 @@ namespace OperatorHelper
uint32_t spatialDimensionCount = gsl::narrow_cast(inputDimensions.size()) - NonspatialDimensionCount;
ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); // Support up to 3D convolution (in 5D tensor).
- const int dimOffset = gsl::narrow_cast(inputDimensions.size()) - spatialDimensionCount;
+ ML_CHECK_VALID_ARGUMENT(!isNhwc || inputDimensions.size() == 4);
+
+ const int dimOffset = isNhwc ? 1 : gsl::narrow_cast(inputDimensions.size()) - spatialDimensionCount;
for (size_t dim = 0; dim < spatialDimensionCount; ++dim)
{
@@ -763,8 +767,16 @@ namespace OperatorHelper
ResolvingPadding(inputDimensions);
m_outputShapes.resize(1);
- m_outputShapes[0] = InitializeKernelOutputDimensions(inputDimensions, m_kernel);
- m_outputShapes[0].GetShape()[C] = filterDims[K];
+ m_outputShapes[0] = InitializeKernelOutputDimensions(inputDimensions, m_kernel, m_isNhwc);
+
+ if (m_isNhwc)
+ {
+ m_outputShapes[0].GetShape()[static_cast(NhwcInputDims::C)] = filterDims[K];
+ }
+ else
+ {
+ m_outputShapes[0].GetShape()[C] = filterDims[K];
+ }
}
void ConvolutionHelperBase::InitializeKernelAndShapesTransposed(
@@ -868,7 +880,7 @@ namespace OperatorHelper
void ConvolutionHelperBase::ResolvingPadding(gsl::span inputDimensions)
{
- ResolveAutoPadding(m_kernel, inputDimensions);
+ ResolveAutoPadding(m_kernel, inputDimensions, m_isNhwc);
}
std::vector GemmHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index 7ce27bfebe..bc1ef2dac0 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -204,7 +204,8 @@ struct KernelArgs
std::vector InitializeKernelOutputDimensions(
gsl::span inputDimensions,
- const KernelArgs& args);
+ const KernelArgs& args,
+ bool isNhwc = false);
std::vector InitializeKernelOutputDimsTranspose(
gsl::span inputDimensions,
@@ -219,7 +220,8 @@ KernelArgs InitializeKernel(
void ResolveAutoPadding(
KernelArgs& args,
- gsl::span inputDimensions);
+ gsl::span inputDimensions,
+ bool isNhwc = false);
void MatMulShapeMapping(
std::vector& inputShape0,
@@ -450,13 +452,15 @@ class ConvolutionHelperBase
public:
enum FilterDims { K };
enum InputDims { N, C, H, W };
+ enum class NhwcInputDims { N, H, W, C };
public:
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
template
- ConvolutionHelperBase(const Info_t& info, const Shape_t& shape, bool transpose, bool hasDynamicPads, uint32_t inputTensorIndex, uint32_t filterTensorIndex) :
+ ConvolutionHelperBase(const Info_t& info, const Shape_t& shape, bool transpose, bool hasDynamicPads, bool isNhwc, uint32_t inputTensorIndex, uint32_t filterTensorIndex) :
m_inputTensorIndex(inputTensorIndex),
m_filterTensorIndex(filterTensorIndex),
+ m_isNhwc(isNhwc),
m_kernel(InitializeKernel(info, shape.GetInputTensorDimensionCount(inputTensorIndex), shape.GetInputTensorShape(filterTensorIndex)))
{
m_groupCount = info.template GetOptionalAttribute(AttrName::Group, 1);
@@ -487,6 +491,7 @@ protected:
uint32_t m_groupCount;
uint32_t m_inputTensorIndex;
uint32_t m_filterTensorIndex;
+ bool m_isNhwc;
KernelArgs m_kernel;
std::vector m_outputShapes;
};
@@ -495,28 +500,35 @@ class ConvHelper : public ConvolutionHelperBase
{
public:
template
- ConvHelper(const Info_t& info, const Shape_t& shape) : ConvolutionHelperBase(info, shape, false, false, 0, 1) {}
+ ConvHelper(const Info_t& info, const Shape_t& shape) : ConvolutionHelperBase(info, shape, false, false, false, 0, 1) {}
+};
+
+class NhwcConvHelper : public ConvolutionHelperBase
+{
+public:
+ template
+ NhwcConvHelper(const Info_t& info, const Shape_t& shape) : ConvolutionHelperBase(info, shape, false, false, true, 0, 1) {}
};
class ConvTransposeHelper : public ConvolutionHelperBase
{
public:
template
- ConvTransposeHelper(const Info_t& info, const Shape_t& shape) : ConvolutionHelperBase(info, shape, true, false, 0, 1) {}
+ ConvTransposeHelper(const Info_t& info, const Shape_t& shape) : ConvolutionHelperBase(info, shape, true, false, false, 0, 1) {}
};
class ConvTransposeWithDynamicPadsHelper : public ConvolutionHelperBase
{
public:
template
- ConvTransposeWithDynamicPadsHelper(const Info_t& info, const Shape_t& shape) : ConvolutionHelperBase(info, shape, true, true, 0, 1) {}
+ ConvTransposeWithDynamicPadsHelper(const Info_t& info, const Shape_t& shape) : ConvolutionHelperBase(info, shape, true, true, false, 0, 1) {}
};
class QLinearConvHelper : public ConvolutionHelperBase
{
public:
template
- QLinearConvHelper(const Info_t& info, const Shape_t& shape) : ConvolutionHelperBase(info, shape, false, false, 0, 3) {}
+ QLinearConvHelper(const Info_t& info, const Shape_t& shape) : ConvolutionHelperBase(info, shape, false, false, false, 0, 3) {}
};
class GemmHelper
@@ -1416,6 +1428,7 @@ public:
};
using ShapeInferenceHelper_Conv = ConvHelper;
+using ShapeInferenceHelper_NhwcConv = NhwcConvHelper;
using ShapeInferenceHelper_ConvTranspose = ConvTransposeHelper;
using ShapeInferenceHelper_ConvTransposeWithDynamicPads = ConvTransposeWithDynamicPadsHelper;
using ShapeInferenceHelper_ConvInteger = ConvHelper;
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index d5e4f8f134..b6fe40e5b7 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -410,6 +410,7 @@ namespace OperatorHelper
static const int sc_sinceVer_Attention = 1;
static const int sc_sinceVer_SkipLayerNormalization = 1;
static const int sc_sinceVer_EmbedLayerNormalization = 1;
+ static const int sc_sinceVer_NhwcConv = 1;
static const int sc_sinceVer_BiasAdd = 1;
static const int sc_sinceVer_QuickGelu = 1;
static const int sc_sinceVer_GroupNorm = 1;
diff --git a/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc b/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc
index a79677357c..3f298b0a8f 100644
--- a/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc
+++ b/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc
@@ -32,8 +32,9 @@ void TestNhwcConvOp(const NhwcConvOpAndTestAttributes& attributes,
int min_cuda_architecture = use_float16 ? 530 : 0;
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get());
+ bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get());
- if (enable_cuda || enable_rocm) {
+ if (enable_cuda || enable_rocm || enable_dml) {
OpTester test("NhwcConv", 1, onnxruntime::kMSDomain);
test.AddAttribute("group", attributes.group);
test.AddAttribute("kernel_shape", attributes.kernel_shape);
@@ -82,6 +83,10 @@ void TestNhwcConvOp(const NhwcConvOpAndTestAttributes& attributes,
execution_providers.push_back(DefaultRocmExecutionProvider());
}
+ if (enable_dml) {
+ execution_providers.push_back(DefaultDmlExecutionProvider());
+ }
+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}