diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp index 1c180173aab..b1bf5f4f80d 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp @@ -1254,11 +1254,11 @@ class QLinearOnednn final { Tensor onednn_weight, // int8 tensor from MkldnnCPU Tensor weight_scales, Tensor weight_zero_points, + std::optional other, // extra input for binary post-op std::optional bias, double output_scale, int64_t output_zero_point, std::optional output_dtype, - std::optional other, // extra input for binary post-op double other_scale, int64_t other_zero_point, c10::string_view binary_post_op, // e.g. "none", "sum", "add" @@ -1286,11 +1286,11 @@ class QLinearOnednn final { Tensor onednn_weight, // int8 tensor from MkldnnCPU Tensor weight_scales, Tensor weight_zero_points, + std::optional other, // extra input for binary post-op std::optional bias, double output_scale, int64_t output_zero_point, std::optional output_dtype, - std::optional other, // extra input for binary post-op double other_scale, int64_t other_zero_point, c10::string_view binary_post_op, // e.g. "none", "sum", "add" diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index 21ff400a563..15afd66db0a 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -272,6 +272,6 @@ TORCH_LIBRARY(onednn, m) { m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor")); // Linear with binary postop - m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, Tensor? other, float other_scale, int other_zp, str binary_post_op, float binary_alpha, str unary_post_op, Scalar?[] unary_post_op_args, str unary_post_op_algorithm) -> Tensor")); - m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.binary_tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, Tensor? other, float other_scale, int other_zp, str binary_post_op, float binary_alpha, str unary_post_op, Scalar?[] unary_post_op_args, str unary_post_op_algorithm) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? other, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, float other_scale, int other_zp, str binary_post_op, float binary_alpha, str unary_post_op, Scalar?[] unary_post_op_args, str unary_post_op_algorithm) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.binary_tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? other, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, float other_scale, int other_zp, str binary_post_op, float binary_alpha, str unary_post_op, Scalar?[] unary_post_op_args, str unary_post_op_algorithm) -> Tensor")); } diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 189155f69f8..002fd7691b6 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -141,6 +141,8 @@ ALLOW_LIST = [ ("onednn::qconv2d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv3d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv2d_pointwise.binary", datetime.date(2024, 12, 31)), + ("onednn::qlinear_pointwise.binary", datetime.date(2024, 12, 31)), + ("onednn::qlinear_pointwise.binary_tensor", datetime.date(2024, 12, 31)), ("aten::_scaled_mm.out", datetime.date(2024, 12, 31)), ("aten::_scaled_mm", datetime.date(2024, 12, 31)), # BC-breaking change in can_cast signature: 'from' -> 'from_' diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 25b062a7ab1..92f9ee0988b 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -4332,8 +4332,8 @@ class TestQuantizedLinear(TestCase): accum = accum.bfloat16() qy_cpu = qlinear_op( qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps, - b, used_y_scale, used_y_zp, output_dtype, - accum, x2_scale, x2_zp, "sum", binary_alpha, + accum, b, used_y_scale, used_y_zp, output_dtype, + x2_scale, x2_zp, "sum", binary_alpha, unary_post_op, unary_post_op_args, post_op_algo ) y_ref = y_ref + x2 * binary_alpha @@ -4350,8 +4350,8 @@ class TestQuantizedLinear(TestCase): binary_alpha = 1.0 # we only support alpha=1.0 now qy_cpu = qlinear_op( qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps, - b, used_y_scale, used_y_zp, output_dtype, - x2, 1.0, 0, "add", binary_alpha, + x2, b, used_y_scale, used_y_zp, output_dtype, + 1.0, 0, "add", binary_alpha, unary_post_op, unary_post_op_args, post_op_algo ) y_ref = y_ref + x2 * binary_alpha diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index a31c15a85ee..570fdabb24a 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -505,11 +505,11 @@ def _register_quantized_linear_binary_lowering( packed_weight, w_scale, w_zp, + x2, b, o_inv_scale, o_zero_point, output_dtype, - x2, x2_scale, x2_zp, binary_op_name, diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py index a9259b85c00..49fd7063061 100644 --- a/torch/_inductor/mkldnn_ir.py +++ b/torch/_inductor/mkldnn_ir.py @@ -1374,11 +1374,11 @@ class QLinearPointwiseBinaryPT2E(ExternKernelAlloc): at::Tensor weight, at::Tensor weight_scales, at::Tensor weight_zero_points, + c10::optional other, c10::optional bias, double inv_output_scale, int64_t output_zero_point, c10::optional output_dtype, - c10::optional other, double other_scale, int64_t other_zero_point, c10::string_view binary_post_op, @@ -1436,11 +1436,11 @@ class QLinearPointwiseBinaryPT2E(ExternKernelAlloc): packed_weight, w_scale, w_zp, + other, bias, o_scale, o_zp, output_dtype, - other, other_scale, other_zp, binary_attr, @@ -1470,11 +1470,11 @@ class QLinearPointwiseBinaryPT2E(ExternKernelAlloc): qw: "TensorBox", # packed_weight w_scale: "TensorBox", w_zero_point: "TensorBox", + other: "TensorBox", bias: "TensorBox", output_scale: float, output_zero_point: int, output_dtype, - other: "TensorBox", other_scale, other_zp, binary_post_op, diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 60e1e67c525..105a3eac90c 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -860,11 +860,11 @@ def register_onednn_fusion_ops(): packed_weight: TensorBox, w_scale: TensorBox, w_zp: TensorBox, + x2: TensorBox, bias: TensorBox, o_inv_scale, o_zero_point, output_dtype, - x2: TensorBox, x2_scale, x2_zp, binary_attr, @@ -896,11 +896,11 @@ def register_onednn_fusion_ops(): packed_weight, w_scale, w_zp, + x2, bias, o_inv_scale, o_zero_point, output_dtype, - x2, x2_scale, x2_zp, binary_attr,