mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-12 00:59:23 +00:00
[Cherry Pick Reviewed] DML EP Implementation for [QLinearAveragePool](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAveragePool) ``` Note: Google Test filter = *QLinear*Pool* [==========] Running 72 tests from 2 test suites. [----------] Global test environment set-up. [----------] 36 tests from QLinearGlobalAveragePool [ RUN ] QLinearGlobalAveragePool.Nhwc_1x1x32x32 [ OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32 (410 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x32x32x1 [ OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1 (641 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8 (156 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x256 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256 (134 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7 (160 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x255 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255 (145 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8 (148 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x255 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255 (129 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7 (134 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x256 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256 (131 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8 (159 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x256 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256 (168 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7 (139 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x255 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255 (170 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8 (155 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x255 [ OK ] 
QLinearGlobalAveragePool.Nchw_3x8x8x255 (156 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7 (133 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x256 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256 (149 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8 (131 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8 (127 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8 (153 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8 (129 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8 (133 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8 (135 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8 (129 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8 (152 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8 (140 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8 (133 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8 (135 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8 (147 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8 (156 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8 (155 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8 (138 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8 [ OK ] 
QLinearGlobalAveragePool.Nchw_3x8x8x255_S8 (155 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8 (144 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8 (139 ms) [----------] 36 tests from QLinearGlobalAveragePool (5968 ms total) [----------] 36 tests from QLinearPoolTest [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel (480 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel (481 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel (512 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel (455 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel (463 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel (448 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel (458 ms) [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc (171 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc (169 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc (152 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc (660 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc (150 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc (145 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc [ 
OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc (146 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage [ OK ] QLinearPoolTest.AveragePool2D_BigImage (505 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage_nhwc [ OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc (161 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global [ OK ] QLinearPoolTest.AveragePool2D_Global (481 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global_nhwc [ OK ] QLinearPoolTest.AveragePool2D_Global_nhwc (152 ms) [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8 (461 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8 (448 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8 (471 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8 (473 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_S8 [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_S8 (1507 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8 (477 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8 (493 ms) [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8 (158 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8 (146 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8 (146 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8 (158 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8 (157 ms) [ RUN ] 
QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8 (145 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8 (147 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage_S8 [ OK ] QLinearPoolTest.AveragePool2D_BigImage_S8 (537 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8 (173 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global_S8 [ OK ] QLinearPoolTest.AveragePool2D_Global_S8 (457 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8 (150 ms) [----------] 36 tests from QLinearPoolTest (12914 ms total) [----------] Global test environment tear-down [==========] 72 tests from 2 test suites ran. (18885 ms total) [ PASSED ] 72 tests. memleakdbg: ----- No memory leaks detected ----- ``` ### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
cb7f28a16a
commit
dcfff10f57
12 changed files with 339 additions and 6 deletions
|
|
@ -24,7 +24,7 @@ struct EnumTraits<DML_TENSOR_TYPE>
|
|||
template <>
|
||||
struct EnumTraits<DML_OPERATOR_TYPE>
|
||||
{
|
||||
static constexpr auto ValueCount = 160;
|
||||
static constexpr auto ValueCount = 161;
|
||||
static constexpr size_t ActivationFunctionCount = 24;
|
||||
};
|
||||
|
||||
|
|
@ -495,6 +495,12 @@ struct OperatorDescTraits<DML_ROI_POOLING_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_SLICE_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -1496,6 +1502,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING>
|
|||
using DescType = DML_ROI_POOLING_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING>
|
||||
{
|
||||
using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE>
|
||||
{
|
||||
|
|
@ -2522,6 +2534,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
case DML_OPERATOR_ACTIVATION_GELU:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable: 4063)
|
||||
case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
#pragma warning(pop)
|
||||
|
||||
default:
|
||||
ORT_THROW_HR(E_INVALIDARG);
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
|
|
|
|||
|
|
@ -869,6 +869,31 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA {
|
|||
DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
|
||||
// Field table for the quantized linear average pooling operator schema.
// Entries 0-4 are input tensors, entry 5 is the output tensor, and entries 6-12 are attributes.
// The position of each entry is significant: fields are addressed by index into this table.
// NOTE(review): the trailing bool presumably marks a field as optional (it is true only for the
// two zero-point tensors) - confirm against the DML_SCHEMA_FIELD declaration.
constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false },
|
||||
};
|
||||
|
||||
// Schema record for the quantized linear average pooling operator: its display name, its operator
// type value (converted explicitly to DML_OPERATOR_TYPE), support flags, the number of fields,
// and the field table itself.
constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING",
|
||||
static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING),
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
// Field count; must stay equal to the declared size of the field table passed on the next line.
13,
|
||||
DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
|
|
|
|||
|
|
@ -502,6 +502,24 @@ inline std::vector<OperatorField> GetFields(const DML_ROI_POOLING_OPERATOR_DESC&
|
|||
OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<DML_SIZE_2D>(desc.PooledSize))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputScaleTensor))),
|
||||
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputZeroPointTensor))),
|
||||
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleTensor))),
|
||||
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputZeroPointTensor))),
|
||||
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
|
||||
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const UINT*>(desc.Strides), desc.DimensionCount)),
|
||||
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const UINT*>(desc.WindowSize), desc.DimensionCount)),
|
||||
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<const UINT*>(desc.StartPadding), desc.DimensionCount)),
|
||||
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
|
||||
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast<const UINT*>(desc.Dilations), desc.DimensionCount)),
|
||||
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast<UINT>(desc.IncludePadding))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_SLICE_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -2509,6 +2527,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_ACTIVATION_GELU_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ACTIVATION_GELU_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable: 4063)
|
||||
case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
#pragma warning(pop)
|
||||
|
||||
default:
|
||||
ORT_THROW_HR(E_INVALIDARG);
|
||||
return AbstractOperatorDesc(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,150 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "precomp.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
|
||||
class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelperBase
|
||||
{
|
||||
// For QLinear Avg Pool ORT and DML have same indexing order
|
||||
enum OrtInputTensors : uint32_t
|
||||
{
|
||||
ortInput,
|
||||
ortInputScale,
|
||||
ortInputZeroPoint,
|
||||
ortOutputScale,
|
||||
ortOutputZeroPoint,
|
||||
ortInputCount
|
||||
};
|
||||
|
||||
public:
|
||||
using Self = DmlOperatorQLinearAveragePooling;
|
||||
|
||||
DmlOperatorQLinearAveragePooling(
|
||||
const MLOperatorKernelCreationContext& kernelInfo,
|
||||
bool useGlobalPooling
|
||||
)
|
||||
: DmlOperator(kernelInfo),
|
||||
PoolingHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), useGlobalPooling)
|
||||
{
|
||||
DmlOperator::Initialize(kernelInfo);
|
||||
|
||||
bool isNhwc = m_kernel.channelsLast;
|
||||
std::vector<DimensionType> inputShape = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortInput);
|
||||
std::vector<DimensionType> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
|
||||
|
||||
uint32_t dmlDimSize = m_inputTensorDescs[OrtInputTensors::ortInput].GetDimensionCount();
|
||||
ML_CHECK_VALID_ARGUMENT(dmlDimSize >= 2);
|
||||
|
||||
// DML requires that DimensionCount be equal to Input.dmlDimSize - 2 for Pooling
|
||||
uint32_t expectedSpatialDimCount = m_inputTensorDescs[0].GetDimensionCount() - 2;
|
||||
if (m_kernel.spatialDimensionCount < expectedSpatialDimCount)
|
||||
{
|
||||
size_t shift = expectedSpatialDimCount - m_kernel.spatialDimensionCount;
|
||||
|
||||
for (int i = gsl::narrow_cast<int>(m_kernel.spatialDimensionCount) - 1; i >= 0; i--)
|
||||
{
|
||||
m_kernel.windowSize[i + shift] = m_kernel.windowSize[i];
|
||||
m_kernel.windowSize[i] = 1;
|
||||
|
||||
m_kernel.strides[i + shift] = m_kernel.strides[i];
|
||||
m_kernel.strides[i] = 1;
|
||||
|
||||
m_kernel.startPadding[i + shift] = m_kernel.startPadding[i];
|
||||
m_kernel.startPadding[i] = 0;
|
||||
|
||||
m_kernel.endPadding[i + shift] = m_kernel.endPadding[i];
|
||||
m_kernel.endPadding[i] = 0;
|
||||
|
||||
m_kernel.dilations[i + shift] = m_kernel.dilations[i];
|
||||
m_kernel.dilations[i] = 1;
|
||||
}
|
||||
|
||||
m_kernel.spatialDimensionCount = expectedSpatialDimCount;
|
||||
}
|
||||
|
||||
// Initialize dimensionMapping for NCHW or NHWC layout
|
||||
std::vector<uint32_t> dimensionMapping = {0u, dmlDimSize - 1u};
|
||||
dimensionMapping.resize(dmlDimSize);
|
||||
if (isNhwc)
|
||||
{
|
||||
// Form a remapping for dimensions so C is moved before the spatial dimensions.
|
||||
// e.g. NWC -> {0,2,1} -> NCW
|
||||
// NHWC -> {0,3,1,2} -> NCHW
|
||||
// NDHWC -> {0,4,1,2,3} -> NCDHW
|
||||
std::iota(dimensionMapping.begin() + 2, dimensionMapping.end(), 1u);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Use NCHW {0,1,2,3} format with increasing order of indexs
|
||||
std::iota(dimensionMapping.begin() + 1, dimensionMapping.end(), 1u);
|
||||
}
|
||||
m_inputTensorDescs[OrtInputTensors::ortInput].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
|
||||
|
||||
// Reshape the Input Scale to be the same dimension as the input tensor.
|
||||
// The 1D tensor needs to be moved to the H channel.
|
||||
m_inputTensorDescs[OrtInputTensors::ortInputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
|
||||
|
||||
// Reshape the Input ZeroPoint to be the same dimension as the input tensor.
|
||||
// The 1D tensor needs to be moved to the H channel.
|
||||
if (kernelInfo.IsInputValid(OrtInputTensors::ortInputZeroPoint))
|
||||
{
|
||||
m_inputTensorDescs[OrtInputTensors::ortInputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
|
||||
}
|
||||
|
||||
// Reshape the Output Scale to be the same dimension as the input tensor.
|
||||
// The 1D tensor needs to be moved to the H channel.
|
||||
m_inputTensorDescs[OrtInputTensors::ortOutputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
|
||||
|
||||
// Reshape the Input ZeroPoint to be the same dimension as the input tensor.
|
||||
// The 1D tensor needs to be moved to the H channel.
|
||||
if (kernelInfo.IsInputValid(OrtInputTensors::ortOutputZeroPoint))
|
||||
{
|
||||
m_inputTensorDescs[OrtInputTensors::ortOutputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
|
||||
}
|
||||
|
||||
// Initialize the output description while overriding the shape
|
||||
m_outputTensorDescs[0].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
|
||||
|
||||
assert(m_kernel.spatialDimensionCount <= ARRAYSIZE(m_kernel.windowSize));
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC qLinearAvgPooldesc = {};
|
||||
|
||||
qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput];
|
||||
qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale];
|
||||
qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint];
|
||||
qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];;
|
||||
qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];;
|
||||
qLinearAvgPooldesc.OutputTensor = &outputDescs[0];
|
||||
qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount;
|
||||
qLinearAvgPooldesc.WindowSize = m_kernel.windowSize;
|
||||
qLinearAvgPooldesc.Strides = m_kernel.strides;
|
||||
qLinearAvgPooldesc.StartPadding = m_kernel.startPadding;
|
||||
qLinearAvgPooldesc.EndPadding = m_kernel.endPadding;
|
||||
qLinearAvgPooldesc.Dilations = m_kernel.dilations;
|
||||
qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute<bool>(AttrName::CountIncludePad, false);
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelInfo);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool UseGlobalPooling>
|
||||
class DmlOperatorQuantizedPoolingTemplate : public DmlOperatorQLinearAveragePooling
|
||||
{
|
||||
public:
|
||||
DmlOperatorQuantizedPoolingTemplate(const MLOperatorKernelCreationContext& kernelInfo)
|
||||
: DmlOperatorQLinearAveragePooling(kernelInfo, UseGlobalPooling)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(QLinearAveragePool, DmlOperatorQuantizedPoolingTemplate<false>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(QLinearGlobalAveragePool, DmlOperatorQuantizedPoolingTemplate<true>);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
@ -320,6 +320,8 @@ DML_OP_EXTERN_CREATION_FUNCTION(GlobalMaxPool);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(LpPool);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(QLinearAveragePool);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(QLinearGlobalAveragePool);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization);
|
||||
|
|
@ -634,6 +636,10 @@ constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinea
|
|||
SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32,
|
||||
};
|
||||
|
||||
// Supported data types for the QLinearAveragePool and QLinearGlobalAveragePool registrations:
// a single entry restricting the constrained type to SupportedTensorDataTypes::Ints8Bit.
// NOTE(review): Ints8Bit presumably covers both int8 and uint8 - confirm against the enum definition.
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListQLinearAveragePool = {
|
||||
SupportedTensorDataTypes::Ints8Bit
|
||||
};
|
||||
|
||||
template<typename... Args>
|
||||
constexpr auto requiredConstantCpuInputs(Args... args)
|
||||
{
|
||||
|
|
@ -1040,6 +1046,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 9, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))},
|
||||
{REG_INFO( 11, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, // 11 is identical to 9.
|
||||
|
||||
{REG_INFO_MS( 1, QLinearAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, QLinearGlobalAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)},
|
||||
|
|
|
|||
|
|
@ -315,3 +315,39 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm
|
|||
}
|
||||
m_bufferTensorDesc.DimensionCount = newDimensionCount;
|
||||
}
|
||||
|
||||
// Uses dimensionMapping to reorder m_sizes and m_strides to match specific Tensor layout
|
||||
void TensorDesc::PermuteDimensions(gsl::span<const uint32_t> dimensionMapping, const TensorAxis alignment)
|
||||
{
|
||||
EnsureStridesExist();
|
||||
SetDimensionCount(static_cast<uint32_t>(dimensionMapping.size()), alignment);
|
||||
|
||||
// Shuffle m_sizes and m_strides according to the indexes pointed by dimensionMapping
|
||||
std::vector<uint32_t> tempSizes{m_sizes, m_sizes + MaximumDimensionCount};
|
||||
std::vector<uint32_t> tempStrides{m_strides, m_strides + MaximumDimensionCount};
|
||||
|
||||
for (size_t i = 0; i < dimensionMapping.size(); i++)
|
||||
{
|
||||
m_sizes[i] = tempSizes[dimensionMapping[i]];
|
||||
m_strides[i] = tempStrides[dimensionMapping[i]];
|
||||
}
|
||||
|
||||
m_bufferTensorDesc.Sizes = m_sizes;
|
||||
m_bufferTensorDesc.Strides = m_strides;
|
||||
}
|
||||
|
||||
void TensorDesc::EnsureStridesExist()
|
||||
{
|
||||
if (m_bufferTensorDesc.Strides != nullptr)
|
||||
{
|
||||
// Strides are populated
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t stride = 1;
|
||||
for (uint32_t i = m_bufferTensorDesc.DimensionCount; i-- > 0;)
|
||||
{
|
||||
m_strides[i] = stride;
|
||||
stride *= m_sizes[i];
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ namespace Dml
|
|||
gsl::span<const uint32_t> GetSizes() const { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; }
|
||||
gsl::span<const uint32_t> GetStrides() const;
|
||||
void SetStrides(gsl::span<const uint32_t> strides);
|
||||
void PermuteDimensions(gsl::span<const uint32_t> dimensionMapping, const TensorAxis alignment);
|
||||
|
||||
inline uint64_t GetBufferSizeInBytes() const
|
||||
{
|
||||
|
|
@ -90,6 +91,8 @@ namespace Dml
|
|||
uint32_t m_sizes[MaximumDimensionCount] = {};
|
||||
uint32_t m_strides[MaximumDimensionCount] = {};
|
||||
DML_BUFFER_TENSOR_DESC m_bufferTensorDesc = {};
|
||||
|
||||
void EnsureStridesExist();
|
||||
};
|
||||
|
||||
class TensorDescBuilder
|
||||
|
|
|
|||
|
|
@ -23,8 +23,8 @@ namespace AttrName
|
|||
static constexpr const char* BlockSize = "blocksize";
|
||||
static constexpr const char* Border = "border";
|
||||
static constexpr const char* Broadcast = "broadcast";
|
||||
static constexpr const char* ChannelsLast = "channels_last";
|
||||
static constexpr const char* CeilMode = "ceil_mode";
|
||||
static constexpr const char* ChannelsLast = "channels_last";
|
||||
static constexpr const char* Clip = "clip";
|
||||
static constexpr const char* CoordinateTransformationMode = "coordinate_transformation_mode";
|
||||
static constexpr const char* CountIncludePad = "count_include_pad";
|
||||
|
|
|
|||
|
|
@ -365,13 +365,20 @@ namespace OperatorHelper
|
|||
}
|
||||
|
||||
// Creates a kernel that spans the entire spatial dimensions of the input.
|
||||
KernelArgs InitializeGlobalKernel(gsl::span<const DimensionType> inputDimensions)
|
||||
KernelArgs InitializeGlobalKernel(
|
||||
const MLOperatorAttributes& kernelInfo,
|
||||
gsl::span<const DimensionType> inputDimensions)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() > NonspatialDimensionCount); // Must be at least 1D convolution (in 3D tensor)
|
||||
uint32_t spatialDimensionCount = gsl::narrow_cast<uint32_t>(inputDimensions.size()) - NonspatialDimensionCount;
|
||||
ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); // Support up to 3D convolution (in 5D tensor).
|
||||
|
||||
KernelArgs args(spatialDimensionCount);
|
||||
args.useCeilingOutputShape = kernelInfo.GetOptionalAttribute<bool>(AttrName::CeilMode, 0);
|
||||
args.channelsLast = kernelInfo.GetOptionalAttribute<bool>(AttrName::ChannelsLast, 0);
|
||||
// For global pooling, the kernel size equals the spatial dimensions of the input tensor.
|
||||
// The NHWC layout needs an offset of one dimension to account for the channel being placed at the end.
|
||||
int dimOffset = args.channelsLast ? 1 : 0;
|
||||
|
||||
for (size_t dim = 0; dim < spatialDimensionCount; ++dim)
|
||||
{
|
||||
|
|
@ -379,7 +386,7 @@ namespace OperatorHelper
|
|||
args.dilations[dim] = 1;
|
||||
args.startPadding[dim] = 0;
|
||||
args.endPadding[dim] = 0;
|
||||
args.windowSize[dim] = gsl::narrow_cast<uint32_t>(inputDimensions[inputDimensions.size() - spatialDimensionCount + dim]);
|
||||
args.windowSize[dim] = gsl::narrow_cast<uint32_t>(inputDimensions[inputDimensions.size() - spatialDimensionCount + dim - dimOffset]);
|
||||
}
|
||||
|
||||
return args;
|
||||
|
|
@ -495,6 +502,7 @@ namespace OperatorHelper
|
|||
}
|
||||
|
||||
args.useCeilingOutputShape = kernelInfo.GetOptionalAttribute<bool>(AttrName::CeilMode, 0);
|
||||
args.channelsLast = kernelInfo.GetOptionalAttribute<bool>(AttrName::ChannelsLast, 0);
|
||||
|
||||
return args;
|
||||
}
|
||||
|
|
@ -2012,7 +2020,37 @@ namespace OperatorHelper
|
|||
}
|
||||
return outputShapes;
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> QLinearAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
auto inputShape = shapeInfo.GetInputTensorShape(0);
|
||||
std::vector<DimensionType> outputDimensions = InitializeKernelOutputDimensions(inputShape, m_kernel, m_kernel.channelsLast);
|
||||
|
||||
const uint32_t outputCount = shapeInfo.GetOutputCount();
|
||||
|
||||
std::vector<EdgeShapes> outputShapes;
|
||||
for (uint32_t i = 0; i < outputCount; ++i)
|
||||
{
|
||||
outputShapes.push_back(outputDimensions);
|
||||
}
|
||||
return outputShapes;
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> QLinearGlobalAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
auto inputShape = shapeInfo.GetInputTensorShape(0);
|
||||
std::vector<DimensionType> outputDimensions = InitializeKernelOutputDimensions(inputShape, m_kernel, m_kernel.channelsLast);
|
||||
|
||||
const uint32_t outputCount = shapeInfo.GetOutputCount();
|
||||
|
||||
std::vector<EdgeShapes> outputShapes;
|
||||
for (uint32_t i = 0; i < outputCount; ++i)
|
||||
{
|
||||
outputShapes.push_back(outputDimensions);
|
||||
}
|
||||
return outputShapes;
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> RoiPoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
auto roiShape = shapeInfo.GetInputTensorShape(InputTensors::ROIS);
|
||||
|
|
|
|||
|
|
@ -160,6 +160,7 @@ struct KernelArgs
|
|||
bool autoPad = false;
|
||||
bool autoPadSameUpper = false;
|
||||
bool useCeilingOutputShape = false;
|
||||
bool channelsLast = false;
|
||||
uint32_t spatialDimensionCount = 0;
|
||||
|
||||
KernelArgs(uint32_t spatialDimensionCount) : spatialDimensionCount(spatialDimensionCount)
|
||||
|
|
@ -188,6 +189,7 @@ struct KernelArgs
|
|||
KernelArgs(KernelArgs const& kernelArgs, uint32_t minimumDimensionCount)
|
||||
: autoPad(kernelArgs.autoPad),
|
||||
autoPadSameUpper(kernelArgs.autoPadSameUpper),
|
||||
channelsLast(kernelArgs.channelsLast),
|
||||
spatialDimensionCount(std::max(kernelArgs.spatialDimensionCount, minimumDimensionCount))
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount);
|
||||
|
|
@ -211,7 +213,9 @@ std::vector<DimensionType> InitializeKernelOutputDimsTranspose(
|
|||
gsl::span<const DimensionType> inputDimensions,
|
||||
const KernelArgs& args);
|
||||
|
||||
KernelArgs InitializeGlobalKernel(gsl::span<const DimensionType> inputDimensions);
|
||||
KernelArgs InitializeGlobalKernel(
|
||||
const MLOperatorAttributes& kernelInfo,
|
||||
gsl::span<const DimensionType> inputDimensions);
|
||||
|
||||
KernelArgs InitializeKernel(
|
||||
const MLOperatorAttributes& kernelInfo,
|
||||
|
|
@ -1059,7 +1063,7 @@ public:
|
|||
bool useGlobalPooling
|
||||
)
|
||||
: m_kernel(useGlobalPooling
|
||||
? InitializeGlobalKernel(shape.GetInputTensorShape(0))
|
||||
? InitializeGlobalKernel(info, shape.GetInputTensorShape(0))
|
||||
: InitializeKernel(info, static_cast<uint32_t>(shape.GetInputTensorShape(0).size()), gsl::span<uint32_t>()))
|
||||
{
|
||||
if (!useGlobalPooling)
|
||||
|
|
@ -1161,6 +1165,24 @@ public:
|
|||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
};
|
||||
|
||||
class QLinearAveragePoolingHelper : public PoolingHelperBase
|
||||
{
|
||||
public:
|
||||
template <typename Info_t, typename Shape_t>
|
||||
QLinearAveragePoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, false) {}
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
|
||||
};
|
||||
|
||||
class QLinearGlobalAveragePoolingHelper : public PoolingHelperBase
|
||||
{
|
||||
public:
|
||||
template <typename Info_t, typename Shape_t>
|
||||
QLinearGlobalAveragePoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, true) {}
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
|
||||
};
|
||||
|
||||
class SqueezeHelper
|
||||
{
|
||||
public:
|
||||
|
|
@ -1490,6 +1512,8 @@ using ShapeInferenceHelper_MaxUnpool = UnpoolingHelper;
|
|||
using ShapeInferenceHelper_LpPool = PoolingHelper;
|
||||
using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper;
|
||||
using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper;
|
||||
using ShapeInferenceHelper_QLinearAveragePool = QLinearAveragePoolingHelper;
|
||||
using ShapeInferenceHelper_QLinearGlobalAveragePool = QLinearGlobalAveragePoolingHelper;
|
||||
using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper<RoiAlignHelper, 10>;
|
||||
using ShapeInferenceHelper_RoiAlign16 = VersionedOpsetHelper<RoiAlignHelper, 16>;
|
||||
using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper;
|
||||
|
|
|
|||
|
|
@ -445,6 +445,8 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_GroupNorm = 1;
|
||||
static const int sc_sinceVer_QLinearConcat = 1;
|
||||
static const int sc_sinceVer_RotaryEmbedding = 1;
|
||||
static const int sc_sinceVer_QLinearAveragePool = 1;
|
||||
static const int sc_sinceVer_QLinearGlobalAveragePool = 1;
|
||||
} // namespace MsftOperatorSet1
|
||||
|
||||
} // namespace OperatorHelper
|
||||
|
|
|
|||
|
|
@ -66,6 +66,9 @@ void RunQLinearGlobalAveragePool(
|
|||
test.AddInput<float>("y_scale", {}, {y_scale});
|
||||
test.AddInput<T8Bits>("y_zero_point", {}, {y_zero_point});
|
||||
test.AddOutput<T8Bits>("Y", y_dims, y_data);
|
||||
if (channels_last) {
|
||||
test.AddAttribute("channels_last", (int64_t)1LL);
|
||||
}
|
||||
|
||||
auto q8checker = [&](const std::vector<OrtValue>& fetches, const std::string& provider_type) {
|
||||
const OrtValue& ort_value = fetches[0];
|
||||
|
|
|
|||
Loading…
Reference in a new issue