From 6949cfaf94f97e93463d77b304a3a5cdb19aa1de Mon Sep 17 00:00:00 2001
From: Jeff Bloomfield <38966965+jeffbloo@users.noreply.github.com>
Date: Thu, 15 Jun 2023 18:21:56 -0700
Subject: [PATCH] =?UTF-8?q?Fix=20MS=20domain=20QuantizeLinear=20and=20Dequ?=
=?UTF-8?q?antizeLinear=20type=20registrations=20=E2=80=A6=20(#16298)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This fixes the type lists used to register DML kernels for Microsoft
domain QuantizeLinear and DequantizeLinear. These previously did not
include FP16 and incorrectly used the same type list for both operators.
The new type lists are the same as opset 19 ONNX which aren't
implemented yet in the DML EP.
---
docs/OperatorKernels.md | 4 ++--
.../src/Operators/OperatorRegistration.cpp | 7 ++++---
2 files changed, 6 insertions(+), 5 deletions(-)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 98d21d85f8..23de8f1504 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -1223,7 +1223,7 @@ Do not modify directly.*
|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasSplitGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)|
+|DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)|
|EmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding:**T**
*in* position_embedding:**T**
*in* segment_embedding:**T**
*in* gamma:**T**
*in* beta:**T**
*in* mask:**T1**
*in* position_ids:**T1**
*out* output:**T**
*out* mask_index:**T1**
*out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1233,7 +1233,7 @@ Do not modify directly.*
|NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
-|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)|
+|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)|
|QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
| |
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index e2156a2ba7..b3533ea60c 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -510,8 +510,9 @@ constexpr static std::array supportedTypeListScatte
constexpr static std::array supportedTypeListScatterGatherND = { SupportedTensorDataTypes::AllScalars };
constexpr static std::array supportedTypeListSlice10 = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
constexpr static std::array supportedTypeListQuantizeLinear = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
+constexpr static std::array supportedTypeListQuantizeLinear19 = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
constexpr static std::array supportedTypeListDequantizeLinear = { SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 };
-constexpr static std::array supportedTypeListQuantize = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 };
+constexpr static std::array supportedTypeListDequantizeLinear19 = { SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::Float16to32 };
constexpr static std::array supportedTypeListIsNan = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
constexpr static std::array supportedTypeListIsInf = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool };
constexpr static std::array supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::AllScalars };
@@ -762,8 +763,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 13, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 10, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 13, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
- {REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)},
- {REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)},
+ {REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)},
+ {REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)},
{REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)},
{REG_INFO( 13, Sign, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)},
{REG_INFO( 9, IsNaN, typeNameListTwo, supportedTypeListIsNan, DmlGraphSupport::Supported)},