diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 22b0d0ce60..8968d499d2 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -289,6 +289,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int32_t, DequantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, QuantizeLinear); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, QLinearMatMul); @@ -446,6 +447,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Ca class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Clip); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Expand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Expand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, Expand); @@ -1066,6 +1068,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { DequantizeLinear)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo // formula is Y = (X - ZeroPoint) * Scale +template Status DequantizeLinear::Compute(OpKernelContext* ctx) const { auto& x = *ctx->Input(0); auto& x_scale = *ctx->Input(1); @@ -78,11 +80,19 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { PrepareForQDQ(x.Shape(), x_scale, x_zero_point, axis_, N, broadcast_dim, block_size); - const T* zero_point = x_zero_point ? x_zero_point->template Data() : nullptr; const float* scale = x_scale.template Data(); const T* input = x.template Data(); float* output = y.template MutableData(); + const T* zero_point = x_zero_point ? x_zero_point->template Data() : nullptr; + if (std::is_same::value) { + ORT_ENFORCE(zero_point == nullptr || + std::all_of(zero_point, + zero_point + x_zero_point->Shape().Size(), + [](int32_t zp) { return zp == 0; }), + "DequantizeLinear with type int32 should have no zero point or all zero points should be 0"); + } + for (size_t n = 0; n < static_cast(N); n++) { for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { auto zp = zero_point ? static_cast(zero_point[bd]) : 0; @@ -123,8 +133,8 @@ REGISTER_QUANTIZELINEAR(uint8_t) REGISTER_QUANTIZELINEAR_VERSIONED(int8_t) REGISTER_QUANTIZELINEAR_VERSIONED(uint8_t) -template // formula is Y = X / Scale + ZeroPoint +template Status QuantizeLinear::Compute(OpKernelContext* ctx) const { auto& x = *ctx->Input(0); auto& y_scale = *ctx->Input(1); diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 8a85330ce9..716b7240de 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -28,6 +28,16 @@ TEST(DequantizeLinearOpTest, Int8) { test.Run(); } +// scalar zero & scale with int8 +TEST(DequantizeLinearOpTest, Int32) { + OpTester test("DequantizeLinear", 10); + std::vector dims{4}; + test.AddInput("x", dims, {-30, -3, 100, 127}); + test.AddInput("x_scale", {}, {2.0f}); + test.AddOutput("y", dims, {-60.f, -6.f, 200.f, 254.f}); + test.Run(); +} + // 2d inputs TEST(DequantizeLinearOpTest, 2D) { OpTester test("DequantizeLinear", 10); @@ -134,7 +144,7 @@ TEST(DequantizeLinearOpTest, Per_Channel_Axis_0) { } // 1d zero & scale with int8 broadcast axis 1 -TEST(DequantizeLinearOpTest, Per_Channel_Axis_1) { +TEST(DequantizeLinearOpTest, Per_Channel_Axis_1_int8) { OpTester test("DequantizeLinear", 13); std::vector dims{3, 4}; test.AddInput("X", dims, @@ -151,6 +161,24 @@ TEST(DequantizeLinearOpTest, Per_Channel_Axis_1) { test.Run(); } +// 1d zero & scale with int32 broadcast axis 1 +TEST(DequantizeLinearOpTest, Per_Channel_Axis_1_int32) { + OpTester test("DequantizeLinear", 13); + std::vector dims{3, 4}; + test.AddInput("X", dims, + {0, 1, 2, 3, + 0, 2, 4, 6, + 0, 10, 20, 30}); + test.AddAttribute("axis", 1); + test.AddInput("scale", {4}, {1, 2, 4, 8}); + test.AddInput("zero_point", {4}, {0, 0, 0, 0}); + test.AddOutput("Y", dims, + {0, 2, 8, 24, + 0, 4, 16, 48, + 0, 20, 80, 240}); + test.Run(); +} + // 1d zero & scale with uint8 broadcast axis -2 (-2 resolves to axis 0) TEST(DequantizeLinearOpTest, Per_Channel_Neg_2) { OpTester test("DequantizeLinear", 13);