diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 46d9e217bf..d57394b3e7 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -587,7 +587,8 @@ Do not modify directly.*
|DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
-|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**
or
*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|19+|**T1** = tensor(float8e4m3fn), tensor(float8e5m2), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)|
+|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**
or
*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|21+|**T1** = tensor(float8e4m3fn), tensor(float8e5m2), tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)|
+|||[19, 20]|**T1** = tensor(float8e4m3fn), tensor(float8e5m2), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)|
|||[13, 18]|**T** = tensor(int8), tensor(uint8)|
|||[10, 12]|**T** = tensor(int8), tensor(uint8)|
|Div|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
@@ -718,7 +719,8 @@ Do not modify directly.*
|||[13, 14]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|||[7, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
-|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**
or
*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|19+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e5m2), tensor(int8), tensor(uint8)|
+|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**
or
*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e5m2), tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)|
+|||[19, 20]|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e5m2), tensor(int8), tensor(uint8)|
|||[13, 18]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)|
|||[10, 12]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)|
|RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)|
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
index f1b30da01f..adfa680878 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
@@ -23,7 +23,10 @@ void SplitQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
const std::string action_name{"dropSplitQDQ"};
std::unique_ptr action = std::make_unique();
#if !defined(ORT_MINIMAL_BUILD)
- std::unique_ptr selector = std::make_unique(true /*req_equal_quant_params*/);
+ std::vector providers = {kCpuExecutionProvider, kDmlExecutionProvider};
+ std::unique_ptr selector = std::make_unique(true /*req_equal_quant_params*/,
+ false,
+ providers);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"Split", {}}},
std::move(selector),
@@ -63,14 +66,18 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) {
//
// And cannot eliminate the QDQ for MaxPool if the scale is not positive, as a negative
// scale will change the ordering of the elements between quantized & de-quantized values.
- std::unique_ptr selector_no_16bit = std::make_unique(false);
+ std::vector providers = {kCpuExecutionProvider, kDmlExecutionProvider};
+ std::unique_ptr selector_no_16bit = std::make_unique(false,
+ false,
+ true,
+ providers);
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_name,
{{"Resize", {}}},
std::move(selector_no_16bit),
std::move(drop_action_no_int16));
std::unique_ptr selector_no_16bit_and_positive_scale =
- std::make_unique(false, true, false);
+ std::make_unique(false, true, false, providers);
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_and_positive_scale_name,
{{"MaxPool", {12}},
{"ReduceMax", {}},
@@ -78,7 +85,7 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) {
std::move(selector_no_16bit_and_positive_scale),
std::move(drop_action_no_int16_and_positive_scale));
- std::unique_ptr selector = std::make_unique(true);
+ std::unique_ptr selector = std::make_unique(true, false, true, providers);
// DepthToSpace and SpaceToDepth not included because there are no integer implementations.
// https://github.com/microsoft/onnxruntime/issues/21287
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_name,
@@ -117,7 +124,8 @@ void DropDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) {
#if !defined(ORT_MINIMAL_BUILD)
// TODO: Enable 16-bit types in selector when ArgMax supports 16-bit integer input tensors.
- std::unique_ptr selector = std::make_unique();
+ std::vector providers = {kCpuExecutionProvider, kDmlExecutionProvider};
+ std::unique_ptr selector = std::make_unique(false, false, providers);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"ArgMax", {}}},
std::move(selector),
@@ -200,7 +208,8 @@ void VariadicOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
#if !defined(ORT_MINIMAL_BUILD)
// TODO: Enable 16-bit types in selector when QLinearConcat supports 16-bit.
- std::unique_ptr selector = std::make_unique();
+ std::vector providers = {kCpuExecutionProvider, kDmlExecutionProvider};
+ std::unique_ptr selector = std::make_unique(false, false, providers);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"Concat", {}}},
@@ -222,7 +231,11 @@ void ConvQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool is_
#if !defined(ORT_MINIMAL_BUILD)
// TODO: Enable 16-bit types in selector when QLinearConv supports 16-bit.
- std::unique_ptr selector = std::make_unique(is_int8_allowed);
+ std::vector providers = {kCpuExecutionProvider, kDmlExecutionProvider};
+ std::unique_ptr selector = std::make_unique(is_int8_allowed,
+ false,
+ false,
+ providers);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"Conv", {}}},
@@ -245,7 +258,11 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i
#if !defined(ORT_MINIMAL_BUILD)
// TODO: Enable 16-bit types in selector when QLinearMatMul and MatMulInteger support 16-bit.
- std::unique_ptr selector = std::make_unique(is_int8_allowed);
+ std::vector providers = {kCpuExecutionProvider, kDmlExecutionProvider};
+ std::unique_ptr selector = std::make_unique(is_int8_allowed,
+ false,
+ false,
+ providers);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"MatMul", {}}},
std::move(selector),
@@ -272,7 +289,8 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi
p_buffered_tensors);
#if !defined(ORT_MINIMAL_BUILD)
- std::unique_ptr selector = std::make_unique();
+ std::vector providers = {kCpuExecutionProvider, kCudaExecutionProvider};
+ std::unique_ptr selector = std::make_unique(providers);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"MatMul", {}}},
std::move(selector),
@@ -363,8 +381,9 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer(
CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level,
intra_op_thread_pool, p_buffered_tensors),
apply_context,
- // this transformer is only compatible with the CPU and DML EP
- {kCpuExecutionProvider, kDmlExecutionProvider}} {
+ // this transformer is compatible with CPU, DML and CUDA EP.
+ // There is further EP control on the rule level.
+ {kCpuExecutionProvider, kDmlExecutionProvider, kCudaExecutionProvider}} {
}
} // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
index 7e009da394..0ba5436e69 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
@@ -302,14 +302,20 @@ class BaseSelector : public NodeSelector {
class DropQDQNodesSelector : public BaseSelector {
public:
- explicit DropQDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false, bool allow_nonpositive_scale = true)
- : BaseSelector(std::make_unique(allow_16bit, allow_4bit, allow_nonpositive_scale)) {}
+ explicit DropQDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false,
+ bool allow_nonpositive_scale = true,
+ gsl::span compatible_providers = {})
+ : BaseSelector(std::make_unique(allow_16bit, allow_4bit, allow_nonpositive_scale),
+ compatible_providers) {}
};
class DropDQNodesSelector : public BaseSelector {
public:
- explicit DropDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false)
- : BaseSelector(std::make_unique(allow_16bit, allow_4bit)) {}
+ explicit DropDQNodesSelector(bool allow_16bit = false,
+ bool allow_4bit = false,
+ gsl::span compatible_providers = {})
+ : BaseSelector(std::make_unique(allow_16bit, allow_4bit),
+ compatible_providers) {}
};
class UnarySelector : public BaseSelector {
@@ -329,8 +335,11 @@ class BinarySelector : public BaseSelector {
// Variadic DQ nodes -> node -> Q
class InputVariadicSelector : public BaseSelector {
public:
- explicit InputVariadicSelector(bool allow_16bit = false, bool allow_4bit = false)
- : BaseSelector(std::make_unique(allow_16bit, allow_4bit)) {}
+ explicit InputVariadicSelector(bool allow_16bit = false,
+ bool allow_4bit = false,
+ gsl::span compatible_providers = {})
+ : BaseSelector(std::make_unique(allow_16bit, allow_4bit),
+ compatible_providers) {}
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
};
@@ -338,8 +347,10 @@ class InputVariadicSelector : public BaseSelector {
// DQ -> Split -> variadic Q nodes
class SplitSelector : public BaseSelector {
public:
- SplitSelector(bool req_equal_quant_params = false, bool allow_4bit = false)
- : BaseSelector(std::make_unique(req_equal_quant_params, allow_4bit)) {}
+ SplitSelector(bool req_equal_quant_params = false, bool allow_4bit = false,
+ gsl::span compatible_providers = {})
+ : BaseSelector(std::make_unique(req_equal_quant_params, allow_4bit),
+ compatible_providers) {}
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
};
@@ -347,8 +358,10 @@ class SplitSelector : public BaseSelector {
// DQ nodes for X, W and optionally B -> node -> Q
class ConvSelector : public BaseSelector {
public:
- ConvSelector(bool int8_allowed = false, bool allow_16bit = false, bool allow_4bit_weight = false)
- : BaseSelector(std::make_unique(int8_allowed, allow_16bit, allow_4bit_weight)) {}
+ ConvSelector(bool int8_allowed = false, bool allow_16bit = false, bool allow_4bit_weight = false,
+ gsl::span compatible_providers = {})
+ : BaseSelector(std::make_unique(int8_allowed, allow_16bit, allow_4bit_weight),
+ compatible_providers) {}
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
};
@@ -363,9 +376,11 @@ class WhereSelector : public BaseSelector {
// 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not
class MatMulSelector : public BaseSelector {
public:
- MatMulSelector(bool int8_allowed, bool allow_16bit = false, bool allow_4bit = false)
+ MatMulSelector(bool int8_allowed, bool allow_16bit = false, bool allow_4bit = false,
+ gsl::span compatible_providers = {})
: BaseSelector(std::make_unique(int8_allowed, /*matmulintegertofloat_allowed*/ true,
- allow_16bit, allow_4bit)) {}
+ allow_16bit, allow_4bit),
+ compatible_providers) {}
};
// Convert "1 DQ node for input B -> MatMul" to "MatMulNBits"
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index f74754c3cd..b54c572556 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -4,6 +4,7 @@
#include "core/common/inlined_containers.h"
#include "core/common/parse_string.h"
+#include "core/framework/int4.h"
#include "core/providers/shared_library/provider_api.h"
#include "core/platform/env_var_utils.h"
#include "core/providers/cuda/cuda_execution_provider.h"
@@ -1348,38 +1349,37 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E5M2, Cast);
#endif
-class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint8_t, float, DequantizeLinear);
-class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int8_t, float, DequantizeLinear);
+class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint8_t, float, DequantizeLinear);
+class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, int8_t, float, DequantizeLinear);
#if !defined(DISABLE_FLOAT8_TYPES)
-class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E4M3FN, float, DequantizeLinear);
-class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E5M2, float, DequantizeLinear);
+class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, float, DequantizeLinear);
+class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, float, DequantizeLinear);
#endif
-class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint8_t, MLFloat16, DequantizeLinear);
-class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int8_t, MLFloat16, DequantizeLinear);
+class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint8_t, MLFloat16, DequantizeLinear);
+class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, int8_t, MLFloat16, DequantizeLinear);
#if !defined(DISABLE_FLOAT8_TYPES)
-class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E4M3FN, MLFloat16, DequantizeLinear);
-class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E5M2, MLFloat16, DequantizeLinear);
+class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, MLFloat16, DequantizeLinear);
+class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, MLFloat16, DequantizeLinear);
#endif
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Identity);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, If);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Loop);
-class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint8_t, float, QuantizeLinear);
-class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int8_t, float, QuantizeLinear);
+class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint8_t, float, QuantizeLinear);
+class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, int8_t, float, QuantizeLinear);
#if !defined(DISABLE_FLOAT8_TYPES)
-class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E4M3FN, float, QuantizeLinear);
-class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E5M2, float, QuantizeLinear);
+class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, float, QuantizeLinear);
+class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, float, QuantizeLinear);
#endif
-class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint8_t, MLFloat16, QuantizeLinear);
-class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int8_t, MLFloat16, QuantizeLinear);
+class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint8_t, MLFloat16, QuantizeLinear);
+class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, int8_t, MLFloat16, QuantizeLinear);
#if !defined(DISABLE_FLOAT8_TYPES)
-class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E4M3FN, MLFloat16, QuantizeLinear);
-class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E5M2, MLFloat16, QuantizeLinear);
+class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, MLFloat16, QuantizeLinear);
+class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, MLFloat16, QuantizeLinear);
#endif
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Reshape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Scan);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Shape);
-#endif
// Opset 20
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu);
@@ -1388,6 +1388,40 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN);
+// Opset 21.
+// TODO(fajin): support other quantized types
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, uint8_t, float, DequantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, int8_t, float, DequantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, uint8_t, MLFloat16, DequantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, int8_t, MLFloat16, DequantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, UInt4x2, float, DequantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Int4x2, float, DequantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, UInt4x2, MLFloat16, DequantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Int4x2, MLFloat16, DequantizeLinear);
+#if !defined(DISABLE_FLOAT8_TYPES)
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, float, DequantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, float, DequantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, MLFloat16, DequantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, MLFloat16, DequantizeLinear);
+#endif
+
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, uint8_t, float, QuantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, int8_t, float, QuantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, uint8_t, MLFloat16, QuantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, int8_t, MLFloat16, QuantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, UInt4x2, float, QuantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Int4x2, float, QuantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, UInt4x2, MLFloat16, QuantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Int4x2, MLFloat16, QuantizeLinear);
+#if !defined(DISABLE_FLOAT8_TYPES)
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, float, QuantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, float, QuantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, MLFloat16, QuantizeLinear);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, MLFloat16, QuantizeLinear);
+#endif
+
+#endif
+
template <>
KernelCreateInfo BuildKernelCreateInfo() {
return {};
@@ -2265,34 +2299,34 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
#endif
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
#if !defined(DISABLE_FLOAT8_TYPES)
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
#endif
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
#if !defined(DISABLE_FLOAT8_TYPES)
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
#endif
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
#if !defined(DISABLE_FLOAT8_TYPES)
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
#endif
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
#if !defined(DISABLE_FLOAT8_TYPES)
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
#endif
BuildKernelCreateInfo,
@@ -2305,6 +2339,37 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+
+ // Opset 21
+ // TODO(fajin): support other quantized types
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+#if !defined(DISABLE_FLOAT8_TYPES)
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+#endif
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+#if !defined(DISABLE_FLOAT8_TYPES)
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+#endif
#endif
};
diff --git a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cc b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cc
index d4b6d1bc49..6a5dbc433f 100644
--- a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cc
+++ b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cc
@@ -7,36 +7,181 @@
namespace onnxruntime {
namespace cuda {
+void ValidateBlockQuantizationShapes(const TensorShape& input_shape,
+ const TensorShape& scale_shape,
+ const Tensor* zero_point,
+ size_t axis_no_neg,
+ int64_t block_size_) {
+ ORT_ENFORCE(scale_shape.NumDimensions() == input_shape.NumDimensions(),
+ "scale and input must have the same rank for blocked quantization");
+
+ for (size_t i = 0, ndim = input_shape.NumDimensions(); i < ndim; ++i) {
+ if (i == static_cast(axis_no_neg)) {
+ ORT_ENFORCE(scale_shape[i] == (input_shape[i] + block_size_ - 1) / block_size_,
+ "scale must be ceil(Di/block_size) on the quantize axis i for blocked quantization");
+ } else {
+ ORT_ENFORCE(scale_shape[i] == input_shape[i],
+ "scale and input must have the same shape despite the quantize axis for blocked quantization");
+ }
+ }
+
+ if (zero_point) {
+ ORT_ENFORCE(zero_point->Shape() == scale_shape,
+ "zero_point and scale must have the same shape for blocked quantization");
+ }
+}
+
template
-typename std::enable_if, T>::value, Status>::type
-CudaQuantizeLinear(cudaStream_t stream, const U* input, T* output, const U* scale, const T* zero_point, size_t num_of_element, bool /*saturate*/) {
+Status CudaQuantizeLinear(cudaStream_t stream, const U* input, T* output, const U* scale, const T* zero_point,
+ size_t num_of_element, bool /*saturate*/) {
+ ORT_UNUSED_PARAMETER(stream);
+ ORT_UNUSED_PARAMETER(input);
+ ORT_UNUSED_PARAMETER(output);
+ ORT_UNUSED_PARAMETER(scale);
+ ORT_UNUSED_PARAMETER(zero_point);
+ ORT_UNUSED_PARAMETER(num_of_element);
+ ORT_NOT_IMPLEMENTED("Unsupported quantization type.");
+}
+
+template
+Status CudaQuantizeLinearAxis(cudaStream_t stream, const U* input, T* output, const U* scale, const T* zero_point,
+ size_t num_of_element, size_t batch_size, size_t n_scales, bool /*saturate*/) {
+ ORT_UNUSED_PARAMETER(stream);
+ ORT_UNUSED_PARAMETER(input);
+ ORT_UNUSED_PARAMETER(output);
+ ORT_UNUSED_PARAMETER(scale);
+ ORT_UNUSED_PARAMETER(zero_point);
+ ORT_UNUSED_PARAMETER(num_of_element);
+ ORT_UNUSED_PARAMETER(batch_size);
+ ORT_UNUSED_PARAMETER(n_scales);
+ ORT_NOT_IMPLEMENTED("Unsupported quantization type.");
+}
+
+template
+Status CudaQuantizeLinearBlock(cudaStream_t stream, const U* input, T* output, const U* scale, const T* zero_point,
+ size_t num_of_element, size_t K, size_t N, size_t block_size, bool /*saturate*/) {
+ ORT_UNUSED_PARAMETER(stream);
+ ORT_UNUSED_PARAMETER(input);
+ ORT_UNUSED_PARAMETER(output);
+ ORT_UNUSED_PARAMETER(scale);
+ ORT_UNUSED_PARAMETER(zero_point);
+ ORT_UNUSED_PARAMETER(num_of_element);
+ ORT_UNUSED_PARAMETER(K);
+ ORT_UNUSED_PARAMETER(N);
+ ORT_UNUSED_PARAMETER(block_size);
+ ORT_NOT_IMPLEMENTED("Unsupported quantization type.");
+}
+
+template
+Status CudaQuantizeLinear(cudaStream_t stream, const U* input, int8_t* output, const U* scale,
+ const int8_t* zero_point, size_t num_of_element, bool /*saturate*/) {
return CudaQuantizeLinearStd(stream, input, output, scale, zero_point, num_of_element);
}
+template
+Status CudaQuantizeLinear(cudaStream_t stream, const U* input, uint8_t* output, const U* scale,
+ const uint8_t* zero_point, size_t num_of_element, bool /*saturate*/) {
+ return CudaQuantizeLinearStd(stream, input, output, scale, zero_point, num_of_element);
+}
+
+template
+Status CudaQuantizeLinear(cudaStream_t stream, const U* input, Int4x2* output, const U* scale,
+ const Int4x2* zero_point, size_t num_of_element, bool /*saturate*/) {
+ return CudaQuantizeLinearStdInt4(stream, input, reinterpret_cast(output), scale,
+ zero_point ? reinterpret_cast(zero_point) : nullptr,
+ num_of_element);
+}
+
+template
+Status CudaQuantizeLinear(cudaStream_t stream, const U* input, UInt4x2* output, const U* scale,
+ const UInt4x2* zero_point, size_t num_of_element, bool /*saturate*/) {
+ return CudaQuantizeLinearStdInt4(stream, input, reinterpret_cast(output), scale,
+ zero_point ? reinterpret_cast(zero_point) : nullptr,
+ num_of_element);
+}
+
#if !defined(DISABLE_FLOAT8_TYPES)
-template
-typename std::enable_if, T>::value, Status>::type
-CudaQuantizeLinear(cudaStream_t stream, const U* input, T* output, const U* scale, const T* zero_point, size_t num_of_element, bool saturate) {
+template
+Status CudaQuantizeLinear(cudaStream_t stream, const U* input, Float8E4M3FN* output, const U* scale,
+ const Float8E4M3FN* zero_point, size_t num_of_element, bool saturate) {
return CudaQuantizeLinearSat(stream, input, output, scale, zero_point, num_of_element, saturate);
}
-template
-typename std::enable_if, T>::value, Status>::type
-CudaQuantizeLinearAxis(cudaStream_t stream, const U* input, T* output, const U* scale, const T* zero_point, size_t num_of_element,
- size_t batch_size, size_t n_scales, bool saturate) {
- return CudaQuantizeLinearAxisSat(stream, input, output, scale, zero_point, num_of_element, batch_size, n_scales, saturate);
+template
+Status CudaQuantizeLinear(cudaStream_t stream, const U* input, Float8E5M2* output, const U* scale,
+ const Float8E5M2* zero_point, size_t num_of_element, bool saturate) {
+ return CudaQuantizeLinearSat(stream, input, output, scale, zero_point, num_of_element, saturate);
+}
+
+template
+Status CudaQuantizeLinearAxis(cudaStream_t stream, const U* input, Float8E4M3FN* output, const U* scale,
+ const Float8E4M3FN* zero_point, size_t num_of_element,
+ size_t batch_size, size_t n_scales, bool saturate) {
+ return CudaQuantizeLinearAxisSat(stream, input, output, scale, zero_point, num_of_element, batch_size,
+ n_scales, saturate);
+}
+
+template
+Status CudaQuantizeLinearAxis(cudaStream_t stream, const U* input, Float8E5M2* output, const U* scale,
+ const Float8E5M2* zero_point, size_t num_of_element,
+ size_t batch_size, size_t n_scales, bool saturate) {
+ return CudaQuantizeLinearAxisSat(stream, input, output, scale, zero_point, num_of_element, batch_size,
+ n_scales, saturate);
}
#endif
-template
-typename std::enable_if, T>::value, Status>::type
-CudaQuantizeLinearAxis(cudaStream_t stream, const U* input, T* output, const U* scale, const T* zero_point, size_t num_of_element,
- size_t batch_size, size_t n_scales, bool /*saturate*/) {
+template
+Status CudaQuantizeLinearAxis(cudaStream_t stream, const U* input, int8_t* output, const U* scale,
+ const int8_t* zero_point, size_t num_of_element,
+ size_t batch_size, size_t n_scales, bool /*saturate*/) {
return CudaQuantizeLinearAxisStd(stream, input, output, scale, zero_point, num_of_element, batch_size, n_scales);
}
+template
+Status CudaQuantizeLinearAxis(cudaStream_t stream, const U* input, uint8_t* output, const U* scale,
+ const uint8_t* zero_point, size_t num_of_element,
+ size_t batch_size, size_t n_scales, bool /*saturate*/) {
+ return CudaQuantizeLinearAxisStd(stream, input, output, scale, zero_point, num_of_element, batch_size, n_scales);
+}
+
+template
+Status CudaQuantizeLinearAxis(cudaStream_t stream, const U* input, Int4x2* output, const U* scale,
+ const Int4x2* zero_point, size_t num_of_element,
+ size_t batch_size, size_t n_scales, bool /*saturate*/) {
+ return CudaQuantizeLinearAxisStdInt4(stream, input, reinterpret_cast(output), scale,
+ zero_point ? reinterpret_cast(zero_point) : nullptr,
+ num_of_element, batch_size, n_scales);
+}
+
+template
+Status CudaQuantizeLinearAxis(cudaStream_t stream, const U* input, UInt4x2* output, const U* scale,
+ const UInt4x2* zero_point, size_t num_of_element,
+ size_t batch_size, size_t n_scales, bool /*saturate*/) {
+ return CudaQuantizeLinearAxisStdInt4(stream, input, reinterpret_cast(output), scale,
+ zero_point ? reinterpret_cast(zero_point) : nullptr,
+ num_of_element, batch_size, n_scales);
+}
+
+template
+Status CudaQuantizeLinearBlock(cudaStream_t stream,
+ const U* input, Int4x2* output, const U* scale, const Int4x2* zero_point,
+ size_t num_of_element, size_t K, size_t N, size_t block_size, bool /*saturate*/) {
+ return CudaQuantizeLinearBlockStdInt4(stream, input, reinterpret_cast(output), scale,
+ zero_point ? reinterpret_cast(zero_point) : nullptr,
+ num_of_element, K, N, block_size);
+}
+
+template
+Status CudaQuantizeLinearBlock(cudaStream_t stream,
+ const U* input, UInt4x2* output, const U* scale, const UInt4x2* zero_point,
+ size_t num_of_element, size_t K, size_t N, size_t block_size, bool /*saturate*/) {
+ return CudaQuantizeLinearBlockStdInt4(stream, input, reinterpret_cast(output), scale,
+ zero_point ? reinterpret_cast(zero_point) : nullptr,
+ num_of_element, K, N, block_size);
+}
+
template
Status QuantizeLinear::ComputeInternal(OpKernelContext* ctx) const {
typedef typename ToCudaType::MappedType CudaU;
@@ -48,21 +193,22 @@ Status QuantizeLinear::ComputeInternal(OpKernelContext* ctx) const {
auto& y = *ctx->Output(0, x.Shape());
const auto& x_shape = x.Shape();
+ const auto num_of_elements = x_shape.Size();
const CudaU* input = reinterpret_cast(x.Data());
T* output = y.MutableData();
- if (IsScalarOr1ElementVector(&y_scale)) {
+ if (IsScalarOr1ElementVector(&y_scale)) { // per-tensor quantization
ORT_ENFORCE(y_zero_point == nullptr || IsScalarOr1ElementVector(y_zero_point),
"y_zero_point must be a scalar or 1D tensor of size 1.");
+ ORT_ENFORCE(block_size_ == 0, "block_size must be 0 for per-tensor quantization.");
const T* zero_point = y_zero_point != nullptr ? y_zero_point->Data() : nullptr;
const CudaU* scale = reinterpret_cast(y_scale.Data());
- const auto num_of_elements = x_shape.Size();
ORT_RETURN_IF_ERROR(CudaQuantizeLinear(Stream(ctx), input, output, scale, zero_point, num_of_elements, saturate_));
return Status::OK();
- } else {
+ } else if (block_size_ == 0) { // per-axis quantization
ORT_ENFORCE(y_scale.Shape().NumDimensions() == 1);
ORT_ENFORCE(y_zero_point == nullptr || (y_scale.Shape().Size() == y_zero_point->Shape().Size() &&
y_zero_point->Shape().NumDimensions() == 1),
@@ -73,44 +219,184 @@ Status QuantizeLinear::ComputeInternal(OpKernelContext* ctx) const {
const T* zero_point = y_zero_point != nullptr ? y_zero_point->Data() : nullptr;
const CudaU* scale = reinterpret_cast(y_scale.Data());
- const auto num_of_elements = x_shape.Size();
ORT_RETURN_IF_ERROR(CudaQuantizeLinearAxis(Stream(ctx), input, output, scale, zero_point, num_of_elements,
x_shape.SizeToDimension(axis), y_scale.Shape().Size(), saturate_));
return Status::OK();
+ } else { // blocked quantization
+ // validate shape
+ size_t axis_no_neg = SafeInt(HandleNegativeAxis(axis_, x_shape.NumDimensions()));
+ const auto& y_scale_shape = y_scale.Shape();
+
+ ValidateBlockQuantizationShapes(x_shape,
+ y_scale_shape,
+ y_zero_point,
+ axis_no_neg,
+ block_size_);
+
+ // compute
+ const T* zero_point = y_zero_point ? y_zero_point->Data() : nullptr;
+ const CudaU* scale = reinterpret_cast(y_scale.Data());
+
+ ORT_RETURN_IF_ERROR(CudaQuantizeLinearBlock(Stream(ctx), input, output, scale, zero_point,
+ num_of_elements, x_shape[axis_no_neg],
+ x_shape.SizeFromDimension(axis_no_neg + 1),
+ block_size_, saturate_));
+ return Status::OK();
}
}
template
-typename std::enable_if, T>::value, Status>::type
-CudaDequantizeLinear(cudaStream_t stream, const T* input, U* output, const U* scale, const T* zero_point, size_t num_of_element) {
+Status CudaDequantizeLinear(cudaStream_t stream, const T* input, U* output, const U* scale,
+ const T* zero_point, size_t num_of_element) {
+ ORT_UNUSED_PARAMETER(stream);
+ ORT_UNUSED_PARAMETER(input);
+ ORT_UNUSED_PARAMETER(output);
+ ORT_UNUSED_PARAMETER(scale);
+ ORT_UNUSED_PARAMETER(zero_point);
+ ORT_UNUSED_PARAMETER(num_of_element);
+ ORT_NOT_IMPLEMENTED("Unsupported quantization type.");
+}
+
+template
+Status CudaDequantizeLinearAxis(cudaStream_t stream, const T* input, U* output, const U* scale,
+ const T* zero_point, size_t num_of_element,
+ size_t batch_size, size_t n_scales) {
+ ORT_UNUSED_PARAMETER(stream);
+ ORT_UNUSED_PARAMETER(input);
+ ORT_UNUSED_PARAMETER(output);
+ ORT_UNUSED_PARAMETER(scale);
+ ORT_UNUSED_PARAMETER(zero_point);
+ ORT_UNUSED_PARAMETER(num_of_element);
+ ORT_UNUSED_PARAMETER(batch_size);
+ ORT_UNUSED_PARAMETER(n_scales);
+ ORT_NOT_IMPLEMENTED("Unsupported quantization type.");
+}
+
+template
+Status CudaDequantizeLinearBlockInt4(cudaStream_t stream, const T* input, U* output, const U* scale,
+ const T* zero_point, size_t num_of_element, size_t K, size_t N,
+ size_t block_size) {
+ ORT_UNUSED_PARAMETER(stream);
+ ORT_UNUSED_PARAMETER(input);
+ ORT_UNUSED_PARAMETER(output);
+ ORT_UNUSED_PARAMETER(scale);
+ ORT_UNUSED_PARAMETER(zero_point);
+ ORT_UNUSED_PARAMETER(num_of_element);
+ ORT_UNUSED_PARAMETER(K);
+ ORT_UNUSED_PARAMETER(N);
+ ORT_UNUSED_PARAMETER(block_size);
+ ORT_NOT_IMPLEMENTED("Unsupported quantization type.");
+}
+
+template
+Status CudaDequantizeLinear(cudaStream_t stream, const int8_t* input, U* output, const U* scale,
+ const int8_t* zero_point, size_t num_of_element) {
return CudaDequantizeLinearStd(stream, input, output, scale, zero_point, num_of_element);
}
+template
+Status CudaDequantizeLinear(cudaStream_t stream, const uint8_t* input, U* output, const U* scale,
+ const uint8_t* zero_point, size_t num_of_element) {
+ return CudaDequantizeLinearStd(stream, input, output, scale, zero_point, num_of_element);
+}
+
+template
+Status CudaDequantizeLinear(cudaStream_t stream, const Int4x2* input, U* output, const U* scale,
+ const Int4x2* zero_point, size_t num_of_element) {
+ return CudaDequantizeLinearStdInt4(stream, reinterpret_cast(input), output, scale,
+ zero_point ? reinterpret_cast(zero_point) : nullptr,
+ num_of_element);
+}
+
+template
+Status CudaDequantizeLinear(cudaStream_t stream, const UInt4x2* input, U* output, const U* scale,
+ const UInt4x2* zero_point, size_t num_of_element) {
+ return CudaDequantizeLinearStdInt4(stream, reinterpret_cast(input), output, scale,
+ zero_point ? reinterpret_cast(zero_point) : nullptr,
+ num_of_element);
+}
+
#if !defined(DISABLE_FLOAT8_TYPES)
-template
-typename std::enable_if, T>::value, Status>::type
-CudaDequantizeLinear(cudaStream_t stream, const T* input, U* output, const U* scale, const T* zero_point, size_t num_of_element) {
+template
+Status CudaDequantizeLinear(cudaStream_t stream, const Float8E4M3FN* input, U* output, const U* scale,
+ const Float8E4M3FN* zero_point, size_t num_of_element) {
+ return CudaDequantizeLinearSat(stream, input, output, scale, zero_point, num_of_element);
+}
+
+template
+Status CudaDequantizeLinear(cudaStream_t stream, const Float8E5M2* input, U* output, const U* scale,
+ const Float8E5M2* zero_point, size_t num_of_element) {
return CudaDequantizeLinearSat(stream, input, output, scale, zero_point, num_of_element);
}
#endif
-template
-typename std::enable_if, T>::value, Status>::type
-CudaDequantizeLinearAxis(cudaStream_t stream, const T* input, U* output, const U* scale, const T* zero_point, size_t num_of_element,
- size_t batch_size, size_t n_scales) {
+template
+Status CudaDequantizeLinearAxis(cudaStream_t stream, const int8_t* input, U* output, const U* scale,
+ const int8_t* zero_point, size_t num_of_element,
+ size_t batch_size, size_t n_scales) {
return CudaDequantizeLinearAxisStd(stream, input, output, scale, zero_point, num_of_element, batch_size, n_scales);
}
+template
+Status CudaDequantizeLinearAxis(cudaStream_t stream, const uint8_t* input, U* output, const U* scale,
+ const uint8_t* zero_point, size_t num_of_element,
+ size_t batch_size, size_t n_scales) {
+ return CudaDequantizeLinearAxisStd(stream, input, output, scale, zero_point, num_of_element, batch_size, n_scales);
+}
+
+template
+Status CudaDequantizeLinearAxis(cudaStream_t stream, const Int4x2* input, U* output, const U* scale,
+ const Int4x2* zero_point, size_t num_of_element,
+ size_t batch_size, size_t n_scales) {
+ return CudaDequantizeLinearAxisStdInt4(stream, reinterpret_cast(input), output, scale,
+ zero_point ? reinterpret_cast(zero_point) : nullptr,
+ num_of_element, batch_size, n_scales);
+}
+
+template
+Status CudaDequantizeLinearAxis(cudaStream_t stream, const UInt4x2* input, U* output, const U* scale,
+ const UInt4x2* zero_point, size_t num_of_element,
+ size_t batch_size, size_t n_scales) {
+ return CudaDequantizeLinearAxisStdInt4(stream, reinterpret_cast(input), output, scale,
+ zero_point ? reinterpret_cast(zero_point) : nullptr,
+ num_of_element, batch_size, n_scales);
+}
+
#if !defined(DISABLE_FLOAT8_TYPES)
-template
-typename std::enable_if, T>::value, Status>::type
-CudaDequantizeLinearAxis(cudaStream_t stream, const T* input, U* output, const U* scale, const T* zero_point, size_t num_of_element,
- size_t batch_size, size_t n_scales) {
+template
+Status CudaDequantizeLinearAxis(cudaStream_t stream, const Float8E4M3FN* input, U* output, const U* scale,
+ const Float8E4M3FN* zero_point, size_t num_of_element,
+ size_t batch_size, size_t n_scales) {
+ return CudaDequantizeLinearAxisSat(stream, input, output, scale, zero_point, num_of_element, batch_size, n_scales);
+}
+
+template
+Status CudaDequantizeLinearAxis(cudaStream_t stream, const Float8E5M2* input, U* output, const U* scale,
+ const Float8E5M2* zero_point, size_t num_of_element,
+ size_t batch_size, size_t n_scales) {
return CudaDequantizeLinearAxisSat(stream, input, output, scale, zero_point, num_of_element, batch_size, n_scales);
}
#endif
+template
+Status CudaDequantizeLinearBlockInt4(cudaStream_t stream, const UInt4x2* input, U* output, const U* scale,
+ const UInt4x2* zero_point, size_t num_of_element, size_t K, size_t N,
+ size_t block_size) {
+ return CudaDequantizeLinearBlockStdInt4(stream, reinterpret_cast(input), output, scale,
+ zero_point ? reinterpret_cast(zero_point) : nullptr,
+ num_of_element, K, N, block_size);
+}
+
+template
+Status CudaDequantizeLinearBlockInt4(cudaStream_t stream, const Int4x2* input, U* output, const U* scale,
+ const Int4x2* zero_point, size_t num_of_element, size_t K, size_t N,
+ size_t block_size) {
+ return CudaDequantizeLinearBlockStdInt4(stream, reinterpret_cast(input), output, scale,
+ zero_point ? reinterpret_cast(zero_point) : nullptr,
+ num_of_element, K, N, block_size);
+}
+
template
Status DequantizeLinear::ComputeInternal(OpKernelContext* ctx) const {
typedef typename ToCudaType::MappedType CudaU;
@@ -120,6 +406,7 @@ Status DequantizeLinear::ComputeInternal(OpKernelContext* ctx) const {
auto* y_zero_point = ctx->Input(2);
const auto& x_shape = x.Shape();
+ const auto num_of_elements = x_shape.Size();
auto& y = *ctx->Output(0, x_shape);
@@ -131,12 +418,11 @@ Status DequantizeLinear::ComputeInternal(OpKernelContext* ctx) const {
const T* zero_point = y_zero_point != nullptr ? y_zero_point->Data() : nullptr;
const CudaU* scale = reinterpret_cast(y_scale.Data());
- const auto num_of_elements = x_shape.Size();
ORT_RETURN_IF_ERROR(CudaDequantizeLinear(Stream(ctx), input, output, scale, zero_point, num_of_elements));
return Status::OK();
- } else {
+ } else if (block_size_ == 0) { // per axis quantization
ORT_ENFORCE(y_scale.Shape().NumDimensions() == 1);
ORT_ENFORCE(y_zero_point == nullptr || (y_scale.Shape().Size() == y_zero_point->Shape().Size() && y_zero_point->Shape().NumDimensions() == 1), "scale and zero_point must have the same shape.");
ORT_ENFORCE(x_shape.NumDimensions() > 1);
@@ -145,11 +431,31 @@ Status DequantizeLinear::ComputeInternal(OpKernelContext* ctx) const {
const T* zero_point = y_zero_point != nullptr ? y_zero_point->Data() : nullptr;
const CudaU* scale = reinterpret_cast(y_scale.Data());
- const auto num_of_elements = x_shape.Size();
ORT_RETURN_IF_ERROR(CudaDequantizeLinearAxis(Stream(ctx), input, output, scale, zero_point, num_of_elements,
x_shape.SizeToDimension(axis), y_scale.Shape().Size()));
return Status::OK();
+ } else { // blocked quantization
+ // validate shape
+ auto axis_no_neg = SafeInt(HandleNegativeAxis(axis_, x_shape.NumDimensions()));
+ const auto& y_scale_shape = y_scale.Shape();
+
+ ValidateBlockQuantizationShapes(x_shape,
+ y_scale_shape,
+ y_zero_point,
+ axis_no_neg,
+ block_size_);
+
+ // compute
+ const T* zero_point = y_zero_point ? y_zero_point->Data() : nullptr;
+ const CudaU* scale = reinterpret_cast(y_scale.Data());
+
+ ORT_RETURN_IF_ERROR(CudaDequantizeLinearBlockInt4(Stream(ctx), input, output, scale, zero_point,
+ num_of_elements, x_shape[axis_no_neg],
+ x_shape.SizeFromDimension(axis_no_neg + 1),
+ block_size_));
+
+ return Status::OK();
}
}
@@ -183,33 +489,54 @@ REGISTER_Q_KERNEL_TYPED_10_12(uint8_t)
REGISTER_Q_KERNEL_TYPED_13_18(int8_t)
REGISTER_Q_KERNEL_TYPED_13_18(uint8_t)
-#define REGISTER_Q_KERNEL_TYPED_19(T) \
- ONNX_OPERATOR_TWO_TYPED_KERNEL_EX( \
- QuantizeLinear, \
- kOnnxDomain, \
- 19, \
- T, float, \
- kCudaExecutionProvider, \
- (*KernelDefBuilder::Create()) \
- .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
- .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
- QuantizeLinear); \
- ONNX_OPERATOR_TWO_TYPED_KERNEL_EX( \
- QuantizeLinear, \
- kOnnxDomain, \
- 19, \
- T, MLFloat16, \
- kCudaExecutionProvider, \
- (*KernelDefBuilder::Create()) \
- .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
- .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
- QuantizeLinear);
+#define REGISTER_Q_KERNEL_TWO_TYPED_19_20(T, U) \
+ ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX( \
+ QuantizeLinear, \
+ kOnnxDomain, \
+ 19, 20, \
+ T, U, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
+ QuantizeLinear);
-REGISTER_Q_KERNEL_TYPED_19(int8_t)
-REGISTER_Q_KERNEL_TYPED_19(uint8_t)
+REGISTER_Q_KERNEL_TWO_TYPED_19_20(int8_t, float)
+REGISTER_Q_KERNEL_TWO_TYPED_19_20(uint8_t, float)
+REGISTER_Q_KERNEL_TWO_TYPED_19_20(int8_t, MLFloat16)
+REGISTER_Q_KERNEL_TWO_TYPED_19_20(uint8_t, MLFloat16)
#if !defined(DISABLE_FLOAT8_TYPES)
-REGISTER_Q_KERNEL_TYPED_19(Float8E4M3FN)
-REGISTER_Q_KERNEL_TYPED_19(Float8E5M2)
+REGISTER_Q_KERNEL_TWO_TYPED_19_20(Float8E4M3FN, float)
+REGISTER_Q_KERNEL_TWO_TYPED_19_20(Float8E5M2, float)
+REGISTER_Q_KERNEL_TWO_TYPED_19_20(Float8E4M3FN, MLFloat16)
+REGISTER_Q_KERNEL_TWO_TYPED_19_20(Float8E5M2, MLFloat16)
+#endif
+
+#define REGISTER_Q_KERNEL_TWO_TYPED_21(T, U) \
+ ONNX_OPERATOR_TWO_TYPED_KERNEL_EX( \
+ QuantizeLinear, \
+ kOnnxDomain, \
+ 21, \
+ T, U, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
+ QuantizeLinear);
+
+REGISTER_Q_KERNEL_TWO_TYPED_21(uint8_t, float)
+REGISTER_Q_KERNEL_TWO_TYPED_21(int8_t, float)
+REGISTER_Q_KERNEL_TWO_TYPED_21(uint8_t, MLFloat16)
+REGISTER_Q_KERNEL_TWO_TYPED_21(int8_t, MLFloat16)
+REGISTER_Q_KERNEL_TWO_TYPED_21(UInt4x2, float)
+REGISTER_Q_KERNEL_TWO_TYPED_21(Int4x2, float)
+REGISTER_Q_KERNEL_TWO_TYPED_21(UInt4x2, MLFloat16)
+REGISTER_Q_KERNEL_TWO_TYPED_21(Int4x2, MLFloat16)
+#if !defined(DISABLE_FLOAT8_TYPES)
+REGISTER_Q_KERNEL_TWO_TYPED_21(Float8E4M3FN, float)
+REGISTER_Q_KERNEL_TWO_TYPED_21(Float8E5M2, float)
+REGISTER_Q_KERNEL_TWO_TYPED_21(Float8E4M3FN, MLFloat16)
+REGISTER_Q_KERNEL_TWO_TYPED_21(Float8E5M2, MLFloat16)
#endif
// register DequantizeLinear kernels
@@ -240,33 +567,54 @@ REGISTER_DQ_KERNEL_TYPED_10_12(uint8_t)
REGISTER_DQ_KERNEL_TYPED_13_18(int8_t)
REGISTER_DQ_KERNEL_TYPED_13_18(uint8_t)
-#define REGISTER_DQ_KERNEL_TYPED_19(T) \
- ONNX_OPERATOR_TWO_TYPED_KERNEL_EX( \
- DequantizeLinear, \
- kOnnxDomain, \
- 19, \
- T, float, \
- kCudaExecutionProvider, \
- (*KernelDefBuilder::Create()) \
- .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
- .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
- DequantizeLinear); \
- ONNX_OPERATOR_TWO_TYPED_KERNEL_EX( \
- DequantizeLinear, \
- kOnnxDomain, \
- 19, \
- T, MLFloat16, \
- kCudaExecutionProvider, \
- (*KernelDefBuilder::Create()) \
- .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
- .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
- DequantizeLinear);
+#define REGISTER_DQ_KERNEL_TWO_TYPED_19_20(T, U) \
+ ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX( \
+ DequantizeLinear, \
+ kOnnxDomain, \
+ 19, 20, \
+ T, U, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
+ DequantizeLinear