Enable QLinearAveragePooling DML EP (#17384) (#18240)

[Cherry Pick Reviewed]
DML EP Implementation for

[QLinearAveragePool](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAveragePool)
```
Note: Google Test filter = *QLinear*Pool*
[==========] Running 72 tests from 2 test suites.
[----------] Global test environment set-up.
[----------] 36 tests from QLinearGlobalAveragePool
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x1x32x32
[       OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32 (410 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x32x32x1
[       OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1 (641 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x256x8x8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8 (156 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x8x8x256
[       OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256 (134 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x255x7x7
[       OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7 (160 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x7x7x255
[       OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255 (145 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x255x8x8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8 (148 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x8x8x255
[       OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255 (129 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x256x7x7
[       OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7 (134 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x7x7x256
[       OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256 (131 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x256x8x8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8 (159 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x8x8x256
[       OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256 (168 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x255x7x7
[       OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7 (139 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x7x7x255
[       OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255 (170 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x255x8x8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8 (155 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x8x8x255
[       OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255 (156 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x256x7x7
[       OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7 (133 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x7x7x256
[       OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256 (149 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8 (131 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8
[       OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8 (127 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8 (153 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8
[       OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8 (129 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8 (133 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8
[       OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8 (135 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8 (129 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8
[       OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8 (152 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8 (140 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8
[       OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8 (133 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8 (135 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8
[       OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8 (147 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8 (156 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8
[       OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8 (155 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8 (138 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8
[       OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8 (155 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8 (144 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8
[       OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8 (139 ms)
[----------] 36 tests from QLinearGlobalAveragePool (5968 ms total)

[----------] 36 tests from QLinearPoolTest
[ RUN      ] QLinearPoolTest.AveragePool1D_ExcludePadPixel
[       OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel (480 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_IncludePadPixel
[       OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel (481 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_ExcludePadPixel
[       OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel (512 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_IncludePadPixel
[       OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel (455 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_MultiChannel
[       OK ] QLinearPoolTest.AveragePool2D_MultiChannel (463 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_ExcludePadPixel
[       OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel (448 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_IncludePadPixel
[       OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel (458 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc (171 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc (169 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc (152 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc (660 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc
[       OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc (150 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc (145 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc (146 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_BigImage
[       OK ] QLinearPoolTest.AveragePool2D_BigImage (505 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_BigImage_nhwc
[       OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc (161 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_Global
[       OK ] QLinearPoolTest.AveragePool2D_Global (481 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_Global_nhwc
[       OK ] QLinearPoolTest.AveragePool2D_Global_nhwc (152 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8 (461 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8 (448 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8 (471 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8 (473 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_MultiChannel_S8
[       OK ] QLinearPoolTest.AveragePool2D_MultiChannel_S8 (1507 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8 (477 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8 (493 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8 (158 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8 (146 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8 (146 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8 (158 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8 (157 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8 (145 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8 (147 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_BigImage_S8
[       OK ] QLinearPoolTest.AveragePool2D_BigImage_S8 (537 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8 (173 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_Global_S8
[       OK ] QLinearPoolTest.AveragePool2D_Global_S8 (457 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8 (150 ms)
[----------] 36 tests from QLinearPoolTest (12914 ms total)

[----------] Global test environment tear-down
[==========] 72 tests from 2 test suites ran. (18885 ms total)
[  PASSED  ] 72 tests.
memleakdbg:
----- No memory leaks detected -----
```

### Description
<!-- Describe your changes. -->

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
raoanag 2023-11-06 09:09:11 -08:00 committed by Jeff Bloomfield
parent cb7f28a16a
commit dcfff10f57
12 changed files with 339 additions and 6 deletions

View file

@ -24,7 +24,7 @@ struct EnumTraits<DML_TENSOR_TYPE>
template <>
struct EnumTraits<DML_OPERATOR_TYPE>
{
static constexpr auto ValueCount = 160;
static constexpr auto ValueCount = 161;
static constexpr size_t ActivationFunctionCount = 24;
};
@ -495,6 +495,12 @@ struct OperatorDescTraits<DML_ROI_POOLING_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING;
};
template <>
struct OperatorDescTraits<DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING;
};
template <>
struct OperatorDescTraits<DML_SLICE_OPERATOR_DESC>
{
@ -1496,6 +1502,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING>
using DescType = DML_ROI_POOLING_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING>
{
using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE>
{
@ -2522,6 +2534,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
case DML_OPERATOR_ACTIVATION_GELU:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
#pragma warning(push)
#pragma warning(disable: 4063)
case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING:
return std::invoke(std::forward<Visitor>(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
#pragma warning(pop)
default:
ORT_THROW_HR(E_INVALIDARG);
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);

View file

@ -869,6 +869,31 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA {
DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false },
};
constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA {
"DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING",
static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING),
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
13,
DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },

View file

@ -502,6 +502,24 @@ inline std::vector<OperatorField> GetFields(const DML_ROI_POOLING_OPERATOR_DESC&
OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<DML_SIZE_2D>(desc.PooledSize))),
};
}
inline std::vector<OperatorField> GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputScaleTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputZeroPointTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputZeroPointTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const UINT*>(desc.Strides), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const UINT*>(desc.WindowSize), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<const UINT*>(desc.StartPadding), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast<const UINT*>(desc.Dilations), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast<UINT>(desc.IncludePadding))),
};
}
inline std::vector<OperatorField> GetFields(const DML_SLICE_OPERATOR_DESC& desc)
{
return {
@ -2509,6 +2527,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_ACTIVATION_GELU_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ACTIVATION_GELU_OPERATOR_DESC*>(opDesc.Desc)));
#pragma warning(push)
#pragma warning(disable: 4063)
case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING:
return AbstractOperatorDesc(
&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC*>(opDesc.Desc)));
#pragma warning(pop)
default:
ORT_THROW_HR(E_INVALIDARG);
return AbstractOperatorDesc(

View file

@ -0,0 +1,150 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelperBase
{
// For QLinear Avg Pool ORT and DML have same indexing order
enum OrtInputTensors : uint32_t
{
ortInput,
ortInputScale,
ortInputZeroPoint,
ortOutputScale,
ortOutputZeroPoint,
ortInputCount
};
public:
using Self = DmlOperatorQLinearAveragePooling;
DmlOperatorQLinearAveragePooling(
const MLOperatorKernelCreationContext& kernelInfo,
bool useGlobalPooling
)
: DmlOperator(kernelInfo),
PoolingHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), useGlobalPooling)
{
DmlOperator::Initialize(kernelInfo);
bool isNhwc = m_kernel.channelsLast;
std::vector<DimensionType> inputShape = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortInput);
std::vector<DimensionType> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
uint32_t dmlDimSize = m_inputTensorDescs[OrtInputTensors::ortInput].GetDimensionCount();
ML_CHECK_VALID_ARGUMENT(dmlDimSize >= 2);
// DML requires that DimensionCount be equal to Input.dmlDimSize - 2 for Pooling
uint32_t expectedSpatialDimCount = m_inputTensorDescs[0].GetDimensionCount() - 2;
if (m_kernel.spatialDimensionCount < expectedSpatialDimCount)
{
size_t shift = expectedSpatialDimCount - m_kernel.spatialDimensionCount;
for (int i = gsl::narrow_cast<int>(m_kernel.spatialDimensionCount) - 1; i >= 0; i--)
{
m_kernel.windowSize[i + shift] = m_kernel.windowSize[i];
m_kernel.windowSize[i] = 1;
m_kernel.strides[i + shift] = m_kernel.strides[i];
m_kernel.strides[i] = 1;
m_kernel.startPadding[i + shift] = m_kernel.startPadding[i];
m_kernel.startPadding[i] = 0;
m_kernel.endPadding[i + shift] = m_kernel.endPadding[i];
m_kernel.endPadding[i] = 0;
m_kernel.dilations[i + shift] = m_kernel.dilations[i];
m_kernel.dilations[i] = 1;
}
m_kernel.spatialDimensionCount = expectedSpatialDimCount;
}
// Initialize dimensionMapping for NCHW or NHWC layout
std::vector<uint32_t> dimensionMapping = {0u, dmlDimSize - 1u};
dimensionMapping.resize(dmlDimSize);
if (isNhwc)
{
// Form a remapping for dimensions so C is moved before the spatial dimensions.
// e.g. NWC -> {0,2,1} -> NCW
// NHWC -> {0,3,1,2} -> NCHW
// NDHWC -> {0,4,1,2,3} -> NCDHW
std::iota(dimensionMapping.begin() + 2, dimensionMapping.end(), 1u);
}
else
{
// Use NCHW {0,1,2,3} format with increasing order of indexs
std::iota(dimensionMapping.begin() + 1, dimensionMapping.end(), 1u);
}
m_inputTensorDescs[OrtInputTensors::ortInput].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
// Reshape the Input Scale to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the H channel.
m_inputTensorDescs[OrtInputTensors::ortInputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
// Reshape the Input ZeroPoint to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the H channel.
if (kernelInfo.IsInputValid(OrtInputTensors::ortInputZeroPoint))
{
m_inputTensorDescs[OrtInputTensors::ortInputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
}
// Reshape the Output Scale to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the H channel.
m_inputTensorDescs[OrtInputTensors::ortOutputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
// Reshape the Input ZeroPoint to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the H channel.
if (kernelInfo.IsInputValid(OrtInputTensors::ortOutputZeroPoint))
{
m_inputTensorDescs[OrtInputTensors::ortOutputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
}
// Initialize the output description while overriding the shape
m_outputTensorDescs[0].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
assert(m_kernel.spatialDimensionCount <= ARRAYSIZE(m_kernel.windowSize));
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC qLinearAvgPooldesc = {};
qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput];
qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale];
qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint];
qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];;
qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];;
qLinearAvgPooldesc.OutputTensor = &outputDescs[0];
qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount;
qLinearAvgPooldesc.WindowSize = m_kernel.windowSize;
qLinearAvgPooldesc.Strides = m_kernel.strides;
qLinearAvgPooldesc.StartPadding = m_kernel.startPadding;
qLinearAvgPooldesc.EndPadding = m_kernel.endPadding;
qLinearAvgPooldesc.Dilations = m_kernel.dilations;
qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute<bool>(AttrName::CountIncludePad, false);
DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
};
template <bool UseGlobalPooling>
class DmlOperatorQuantizedPoolingTemplate : public DmlOperatorQLinearAveragePooling
{
public:
DmlOperatorQuantizedPoolingTemplate(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperatorQLinearAveragePooling(kernelInfo, UseGlobalPooling)
{
}
};
DML_OP_DEFINE_CREATION_FUNCTION(QLinearAveragePool, DmlOperatorQuantizedPoolingTemplate<false>);
DML_OP_DEFINE_CREATION_FUNCTION(QLinearGlobalAveragePool, DmlOperatorQuantizedPoolingTemplate<true>);
} // namespace Dml

View file

@ -320,6 +320,8 @@ DML_OP_EXTERN_CREATION_FUNCTION(GlobalMaxPool);
DML_OP_EXTERN_CREATION_FUNCTION(LpPool);
DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool);
DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearAveragePool);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearGlobalAveragePool);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16);
DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization);
@ -634,6 +636,10 @@ constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinea
SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32,
};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListQLinearAveragePool = {
SupportedTensorDataTypes::Ints8Bit
};
template<typename... Args>
constexpr auto requiredConstantCpuInputs(Args... args)
{
@ -1040,6 +1046,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 9, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))},
{REG_INFO( 11, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, // 11 is identical to 9.
{REG_INFO_MS( 1, QLinearAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QLinearGlobalAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)},
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)},
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)},

View file

@ -315,3 +315,39 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm
}
m_bufferTensorDesc.DimensionCount = newDimensionCount;
}
// Uses dimensionMapping to reorder m_sizes and m_strides to match specific Tensor layout
void TensorDesc::PermuteDimensions(gsl::span<const uint32_t> dimensionMapping, const TensorAxis alignment)
{
EnsureStridesExist();
SetDimensionCount(static_cast<uint32_t>(dimensionMapping.size()), alignment);
// Shuffle m_sizes and m_strides according to the indexes pointed by dimensionMapping
std::vector<uint32_t> tempSizes{m_sizes, m_sizes + MaximumDimensionCount};
std::vector<uint32_t> tempStrides{m_strides, m_strides + MaximumDimensionCount};
for (size_t i = 0; i < dimensionMapping.size(); i++)
{
m_sizes[i] = tempSizes[dimensionMapping[i]];
m_strides[i] = tempStrides[dimensionMapping[i]];
}
m_bufferTensorDesc.Sizes = m_sizes;
m_bufferTensorDesc.Strides = m_strides;
}
void TensorDesc::EnsureStridesExist()
{
if (m_bufferTensorDesc.Strides != nullptr)
{
// Strides are populated
return;
}
uint32_t stride = 1;
for (uint32_t i = m_bufferTensorDesc.DimensionCount; i-- > 0;)
{
m_strides[i] = stride;
stride *= m_sizes[i];
}
}

View file

@ -44,6 +44,7 @@ namespace Dml
gsl::span<const uint32_t> GetSizes() const { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; }
gsl::span<const uint32_t> GetStrides() const;
void SetStrides(gsl::span<const uint32_t> strides);
void PermuteDimensions(gsl::span<const uint32_t> dimensionMapping, const TensorAxis alignment);
inline uint64_t GetBufferSizeInBytes() const
{
@ -90,6 +91,8 @@ namespace Dml
uint32_t m_sizes[MaximumDimensionCount] = {};
uint32_t m_strides[MaximumDimensionCount] = {};
DML_BUFFER_TENSOR_DESC m_bufferTensorDesc = {};
void EnsureStridesExist();
};
class TensorDescBuilder

View file

@ -23,8 +23,8 @@ namespace AttrName
static constexpr const char* BlockSize = "blocksize";
static constexpr const char* Border = "border";
static constexpr const char* Broadcast = "broadcast";
static constexpr const char* ChannelsLast = "channels_last";
static constexpr const char* CeilMode = "ceil_mode";
static constexpr const char* ChannelsLast = "channels_last";
static constexpr const char* Clip = "clip";
static constexpr const char* CoordinateTransformationMode = "coordinate_transformation_mode";
static constexpr const char* CountIncludePad = "count_include_pad";

View file

@ -365,13 +365,20 @@ namespace OperatorHelper
}
// Creates a kernel that spans the entire spatial dimensions of the input.
KernelArgs InitializeGlobalKernel(gsl::span<const DimensionType> inputDimensions)
KernelArgs InitializeGlobalKernel(
const MLOperatorAttributes& kernelInfo,
gsl::span<const DimensionType> inputDimensions)
{
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() > NonspatialDimensionCount); // Must be at least 1D convolution (in 3D tensor)
uint32_t spatialDimensionCount = gsl::narrow_cast<uint32_t>(inputDimensions.size()) - NonspatialDimensionCount;
ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); // Support up to 3D convolution (in 5D tensor).
KernelArgs args(spatialDimensionCount);
args.useCeilingOutputShape = kernelInfo.GetOptionalAttribute<bool>(AttrName::CeilMode, 0);
args.channelsLast = kernelInfo.GetOptionalAttribute<bool>(AttrName::ChannelsLast, 0);
// For Global Pooling, kernel size equal to the spatial dimension of input tensor
// NHWC layout need to offset by one dim to acount for channel placed at the end
int dimOffset = args.channelsLast ? 1 : 0;
for (size_t dim = 0; dim < spatialDimensionCount; ++dim)
{
@ -379,7 +386,7 @@ namespace OperatorHelper
args.dilations[dim] = 1;
args.startPadding[dim] = 0;
args.endPadding[dim] = 0;
args.windowSize[dim] = gsl::narrow_cast<uint32_t>(inputDimensions[inputDimensions.size() - spatialDimensionCount + dim]);
args.windowSize[dim] = gsl::narrow_cast<uint32_t>(inputDimensions[inputDimensions.size() - spatialDimensionCount + dim - dimOffset]);
}
return args;
@ -495,6 +502,7 @@ namespace OperatorHelper
}
args.useCeilingOutputShape = kernelInfo.GetOptionalAttribute<bool>(AttrName::CeilMode, 0);
args.channelsLast = kernelInfo.GetOptionalAttribute<bool>(AttrName::ChannelsLast, 0);
return args;
}
@ -2012,7 +2020,37 @@ namespace OperatorHelper
}
return outputShapes;
}
std::vector<EdgeShapes> QLinearAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
auto inputShape = shapeInfo.GetInputTensorShape(0);
std::vector<DimensionType> outputDimensions = InitializeKernelOutputDimensions(inputShape, m_kernel, m_kernel.channelsLast);
const uint32_t outputCount = shapeInfo.GetOutputCount();
std::vector<EdgeShapes> outputShapes;
for (uint32_t i = 0; i < outputCount; ++i)
{
outputShapes.push_back(outputDimensions);
}
return outputShapes;
}
std::vector<EdgeShapes> QLinearGlobalAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
auto inputShape = shapeInfo.GetInputTensorShape(0);
std::vector<DimensionType> outputDimensions = InitializeKernelOutputDimensions(inputShape, m_kernel, m_kernel.channelsLast);
const uint32_t outputCount = shapeInfo.GetOutputCount();
std::vector<EdgeShapes> outputShapes;
for (uint32_t i = 0; i < outputCount; ++i)
{
outputShapes.push_back(outputDimensions);
}
return outputShapes;
}
std::vector<EdgeShapes> RoiPoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
auto roiShape = shapeInfo.GetInputTensorShape(InputTensors::ROIS);

View file

@ -160,6 +160,7 @@ struct KernelArgs
bool autoPad = false;
bool autoPadSameUpper = false;
bool useCeilingOutputShape = false;
bool channelsLast = false;
uint32_t spatialDimensionCount = 0;
KernelArgs(uint32_t spatialDimensionCount) : spatialDimensionCount(spatialDimensionCount)
@ -188,6 +189,7 @@ struct KernelArgs
KernelArgs(KernelArgs const& kernelArgs, uint32_t minimumDimensionCount)
: autoPad(kernelArgs.autoPad),
autoPadSameUpper(kernelArgs.autoPadSameUpper),
channelsLast(kernelArgs.channelsLast),
spatialDimensionCount(std::max(kernelArgs.spatialDimensionCount, minimumDimensionCount))
{
ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount);
@ -211,7 +213,9 @@ std::vector<DimensionType> InitializeKernelOutputDimsTranspose(
gsl::span<const DimensionType> inputDimensions,
const KernelArgs& args);
KernelArgs InitializeGlobalKernel(gsl::span<const DimensionType> inputDimensions);
KernelArgs InitializeGlobalKernel(
const MLOperatorAttributes& kernelInfo,
gsl::span<const DimensionType> inputDimensions);
KernelArgs InitializeKernel(
const MLOperatorAttributes& kernelInfo,
@ -1059,7 +1063,7 @@ public:
bool useGlobalPooling
)
: m_kernel(useGlobalPooling
? InitializeGlobalKernel(shape.GetInputTensorShape(0))
? InitializeGlobalKernel(info, shape.GetInputTensorShape(0))
: InitializeKernel(info, static_cast<uint32_t>(shape.GetInputTensorShape(0).size()), gsl::span<uint32_t>()))
{
if (!useGlobalPooling)
@ -1161,6 +1165,24 @@ public:
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
};
class QLinearAveragePoolingHelper : public PoolingHelperBase
{
public:
template <typename Info_t, typename Shape_t>
QLinearAveragePoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, false) {}
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
};
class QLinearGlobalAveragePoolingHelper : public PoolingHelperBase
{
public:
template <typename Info_t, typename Shape_t>
QLinearGlobalAveragePoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, true) {}
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
};
class SqueezeHelper
{
public:
@ -1490,6 +1512,8 @@ using ShapeInferenceHelper_MaxUnpool = UnpoolingHelper;
using ShapeInferenceHelper_LpPool = PoolingHelper;
using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper;
using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper;
using ShapeInferenceHelper_QLinearAveragePool = QLinearAveragePoolingHelper;
using ShapeInferenceHelper_QLinearGlobalAveragePool = QLinearGlobalAveragePoolingHelper;
using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper<RoiAlignHelper, 10>;
using ShapeInferenceHelper_RoiAlign16 = VersionedOpsetHelper<RoiAlignHelper, 16>;
using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper;

View file

@ -445,6 +445,8 @@ namespace OperatorHelper
static const int sc_sinceVer_GroupNorm = 1;
static const int sc_sinceVer_QLinearConcat = 1;
static const int sc_sinceVer_RotaryEmbedding = 1;
static const int sc_sinceVer_QLinearAveragePool = 1;
static const int sc_sinceVer_QLinearGlobalAveragePool = 1;
} // namespace MsftOperatorSet1
} // namespace OperatorHelper

View file

@ -66,6 +66,9 @@ void RunQLinearGlobalAveragePool(
test.AddInput<float>("y_scale", {}, {y_scale});
test.AddInput<T8Bits>("y_zero_point", {}, {y_zero_point});
test.AddOutput<T8Bits>("Y", y_dims, y_data);
if (channels_last) {
test.AddAttribute("channels_last", (int64_t)1LL);
}
auto q8checker = [&](const std::vector<OrtValue>& fetches, const std::string& provider_type) {
const OrtValue& ort_value = fetches[0];