From 691f37d6b354d2d16da4db1104d3c35d8439d585 Mon Sep 17 00:00:00 2001
From: Yan Zhiwei
Date: Mon, 10 Feb 2025 04:25:53 -0800
Subject: [PATCH] [Intel GPU] qlinear_pointwise.binary[_tensor] XPU support

ghstack-source-id: d0b52577b45b22f6c5936396de82f386cc5a4e22
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135337
---
 aten/src/ATen/native/mkldnn/xpu/qlinear.cpp  | 120 ++++++++++++++++++
 test/inductor/test_mkldnn_pattern_matcher.py |  66 +++++++++-
 .../quantizer/xpu_inductor_quantizer.py      |  17 ---
 3 files changed, 180 insertions(+), 23 deletions(-)

diff --git a/aten/src/ATen/native/mkldnn/xpu/qlinear.cpp b/aten/src/ATen/native/mkldnn/xpu/qlinear.cpp
index 457c269276e..0b6dc8bbbfc 100644
--- a/aten/src/ATen/native/mkldnn/xpu/qlinear.cpp
+++ b/aten/src/ATen/native/mkldnn/xpu/qlinear.cpp
@@ -125,6 +125,120 @@ Tensor q_linear_pointwise_tensor(
   return qout;
 }
 
+Tensor q_linear_pointwise_binary(
+    Tensor act,
+    double act_scale,
+    int64_t act_zero_point,
+    Tensor weight,
+    Tensor weight_scales,
+    Tensor weight_zero_points,
+    std::optional<at::Tensor> other,
+    std::optional<at::Tensor> bias,
+    double output_scale,
+    int64_t output_zero_point,
+    std::optional<c10::ScalarType> output_dtype,
+    double other_scale,
+    int64_t other_zero_point,
+    c10::string_view binary_post_op,
+    double binary_alpha,
+    c10::string_view unary_post_op,
+    torch::List<std::optional<at::Scalar>> unary_post_op_args,
+    c10::string_view unary_post_op_algorithm) {
+  Tensor b_raw = bias.has_value() ? bias.value() : at::Tensor();
+
+  const int64_t dim = act.dim();
+  int64_t K = act.size(dim - 1);
+  int64_t M = act.numel() / K;
+  // [M, K] x [K, N]
+  int64_t N = weight.size(1);
+
+  std::vector<int64_t> src_dims = {M, K};
+  std::vector<int64_t> dst_dims = {M, N};
+  auto out_dtype =
+      output_dtype.has_value() ? output_dtype.value() : act.scalar_type();
+  Tensor qout = at::empty(dst_dims, device(c10::kXPU).dtype(out_dtype));
+
+  quantized_matmul(
+      act.contiguous(),
+      act_scale,
+      act_zero_point,
+      weight.contiguous(),
+      weight_scales,
+      weight_zero_points,
+      b_raw,
+      qout,
+      output_scale,
+      output_zero_point,
+      output_dtype,
+      /*other*/ other,
+      /*other scale*/ other_scale,
+      /*other zp*/ other_zero_point,
+      /*binary post op*/ binary_post_op,
+      /*binary alpha*/ binary_alpha,
+      unary_post_op,
+      unary_post_op_args,
+      unary_post_op_algorithm);
+
+  return qout;
+}
+
+Tensor q_linear_pointwise_binary_tensor(
+    Tensor act,
+    Tensor act_scale,
+    Tensor act_zero_point,
+    Tensor weight,
+    Tensor weight_scales,
+    Tensor weight_zero_points,
+    std::optional<at::Tensor> other,
+    std::optional<at::Tensor> bias,
+    double output_scale,
+    int64_t output_zero_point,
+    std::optional<c10::ScalarType> output_dtype,
+    double other_scale,
+    int64_t other_zero_point,
+    c10::string_view binary_post_op,
+    double binary_alpha,
+    c10::string_view unary_post_op,
+    torch::List<std::optional<at::Scalar>> unary_post_op_args,
+    c10::string_view unary_post_op_algorithm) {
+  Tensor b_raw = bias.has_value() ? bias.value() : at::Tensor();
+
+  const int64_t dim = act.dim();
+  int64_t K = act.size(dim - 1);
+  int64_t M = act.numel() / K;
+  // [M, K] x [K, N]
+  int64_t N = weight.size(1);
+
+  std::vector<int64_t> src_dims = {M, K};
+  std::vector<int64_t> dst_dims = {M, N};
+  auto out_dtype =
+      output_dtype.has_value() ? output_dtype.value() : act.scalar_type();
+  Tensor qout = at::empty(dst_dims, device(c10::kXPU).dtype(out_dtype));
+
+  quantized_matmul(
+      act.contiguous(),
+      act_scale.item().toDouble(),
+      act_zero_point.item().toLong(),
+      weight.contiguous(),
+      weight_scales,
+      weight_zero_points,
+      b_raw,
+      qout,
+      output_scale,
+      output_zero_point,
+      output_dtype,
+      /*other*/ other,
+      /*other scale*/ other_scale,
+      /*other zp*/ other_zero_point,
+      /*binary post op*/ binary_post_op,
+      /*binary alpha*/ binary_alpha,
+      unary_post_op,
+      unary_post_op_args,
+      unary_post_op_algorithm);
+
+  return qout;
+}
+
 at::Tensor q_linear_prepack_onednn(
     at::Tensor weight,
     std::optional<torch::List<int64_t>> input_shape) {
@@ -142,6 +256,12 @@ TORCH_LIBRARY_IMPL(onednn, XPU, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("onednn::qlinear_prepack"),
       TORCH_FN(q_linear_prepack_onednn));
+  m.impl(
+      TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary"),
+      TORCH_FN(q_linear_pointwise_binary));
+  m.impl(
+      TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary_tensor"),
+      TORCH_FN(q_linear_pointwise_binary_tensor));
 }
 
 } // namespace at::native::xpu
diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py
index 067d16d7894..33fe8ae28e8 100644
--- a/test/inductor/test_mkldnn_pattern_matcher.py
+++ b/test/inductor/test_mkldnn_pattern_matcher.py
@@ -2347,8 +2347,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
             (torch.randn((2, 4)),), gelu, int8_mixed_bf16=True
         )
 
-    def _qlinear_add_cpu_test_helper(
-        self, use_relu=False, int8_mixed_bf16=False, is_qat=True, is_dynamic=True
+    def _qlinear_add_test_helper(
+        self,
+        device="cpu",
+        use_relu=False,
+        int8_mixed_bf16=False,
+        is_qat=True,
+        is_dynamic=True,
     ):
         r"""
         This testcase will quantize two consecutive Linear->Add(->relu) patterns as:
@@ -2423,8 +2428,10 @@ class TestPatternMatcher(TestPatternMatcherBase):
         fake_quant_x2_list = [False, True] if int8_mixed_bf16 else [False]
         cases = itertools.product(add_fn_list, fake_quant_x2_list)
         for add_fn, fq_x2 in cases:
-            mod = M(add_fn, use_relu, fq_x2).eval()
-            v = torch.randn((4, 4), dtype=torch.float32, requires_grad=False).add(1)
+            mod = M(add_fn, use_relu, fq_x2).eval().to(device=device)
+            v = torch.randn(
+                (4, 4), dtype=torch.float32, requires_grad=False, device=device
+            ).add(1)
 
             def matcher_check_fn():
                # 1. Dequant-linear pattern matched in quantization weight prepack * 4
@@ -2505,10 +2512,22 @@ class TestPatternMatcher(TestPatternMatcherBase):
     @parametrize("is_qat", [True, False])
     @parametrize("is_dynamic", [True, False])
     def test_qlinear_add_cpu(self, use_relu, is_qat, is_dynamic):
-        self._qlinear_add_cpu_test_helper(
+        self._qlinear_add_test_helper(
             use_relu=use_relu, is_qat=is_qat, is_dynamic=is_dynamic
         )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoXPU
+    @config.patch({"fx_graph_cache": False})
+    @parametrize("use_relu", [True])
+    @parametrize("is_qat", [False])
+    @parametrize("is_dynamic", [False])
+    def test_qlinear_add_xpu(self, use_relu, is_qat, is_dynamic):
+        self._qlinear_add_test_helper(
+            device="xpu", use_relu=use_relu, is_qat=is_qat, is_dynamic=is_dynamic
+        )
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
@@ -2516,7 +2535,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
     @parametrize("is_qat", [True, False])
     @parametrize("is_dynamic", [True, False])
     def test_qlinear_add_int8_mixed_bf16(self, use_relu, is_qat, is_dynamic):
-        self._qlinear_add_cpu_test_helper(
+        self._qlinear_add_test_helper(
             int8_mixed_bf16=True,
             use_relu=use_relu,
             is_qat=is_qat,
@@ -2679,6 +2698,41 @@ class TestPatternMatcher(TestPatternMatcherBase):
             is_dynamic=True,
         )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoXPU
+    @config.patch({"fx_graph_cache": False})
+    def test_qlinear_mul_xpu(self):
+        r"""
+        This testcase will quantize a Linear->Mul pattern.
+        """
+
+        class M(torch.nn.Module):
+            def __init__(self, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(4, 5, use_bias)
+
+            def forward(self, x1, x2):
+                return torch.mul(self.linear(x1), x2)
+
+        bias_list = [True, False]
+        for bias in bias_list:
+            mod = M(bias).eval().to(device="xpu")
+            x1 = torch.randn((2, 4)).to(device="xpu")
+            x2 = torch.randn((2, 5)).to(device="xpu")
+
+            def matcher_check_fn():
+                self.assertEqual(
+                    counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
+                )
+
+            self._test_common(
+                mod,
+                (x1, x2),
+                check_quantization=True,
+                matcher_check_fn=matcher_check_fn,
+            )
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     def test_qlinear_mul_cpu(self):
diff --git a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py
index f1e3d74448e..68dd42936cf 100644
--- a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py
+++ b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py
@@ -96,23 +96,6 @@ class XPUInductorQuantizer(X86InductorQuantizer):
     ):
         pass
 
-    def _annotate_linear_fusion_pattern(
-        self,
-        model: torch.fx.GraphModule,
-        quantization_config: Optional[QuantizationConfig],
-        filter_fn: Optional[FilterFn] = None,
-    ):
-        self._annotate_linear_unary(model, quantization_config, filter_fn)
-        self._annotate_linear(model, quantization_config, filter_fn)
-
-    def _annotate_matmul(
-        self,
-        model: torch.fx.GraphModule,
-        quantization_config: Optional[QuantizationConfig],
-        filter_fn: Optional[FilterFn] = None,
-    ):
-        pass
-
    def _annotate_maxpool2d(
        self,
        node: Node,
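
Note (not part of the patch): for reviewers who want to exercise the new kernels
outside the test suite, below is a minimal sketch of the PT2E quantization flow
that should lower a Linear->Mul pattern to onednn::qlinear_pointwise.binary on
XPU. It assumes the capture API (torch.export.export_for_training) and the
config helper (get_default_xpu_inductor_quantization_config) available around
the PyTorch version this PR targets; exact names may differ across releases.

    # Illustrative sketch only -- assumed APIs are noted in comments.
    import torch
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xpu_inductor_quantizer import (
        XPUInductorQuantizer,
        get_default_xpu_inductor_quantization_config,  # assumed helper name
    )

    class LinearMul(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x1, x2):
            # Linear -> Mul: the binary post-op pattern this patch fuses on XPU.
            return torch.mul(self.linear(x1), x2)

    mod = LinearMul().eval().to("xpu")
    x1 = torch.randn(2, 4, device="xpu")
    x2 = torch.randn(2, 5, device="xpu")

    quantizer = XPUInductorQuantizer()
    quantizer.set_global(get_default_xpu_inductor_quantization_config())

    # Capture, insert observers, calibrate, then convert to a quantized graph.
    exported = torch.export.export_for_training(mod, (x1, x2)).module()
    prepared = prepare_pt2e(exported, quantizer)
    prepared(x1, x2)  # calibration pass
    converted = convert_pt2e(prepared)

    # Freezing lets Inductor constant-fold the weights and match the
    # onednn::qlinear_pointwise.binary pattern registered by this patch.
    with torch.no_grad(), torch._inductor.config.patch(freezing=True):
        out = torch.compile(converted)(x1, x2)

The test suite drives the same flow through _test_common with
check_quantization=True, which is the authoritative reference for the setup.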