[Intel GPU] qlinear_pointwise.binary[_tensor] XPU support

ghstack-source-id: d0b52577b4
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135337
Yan Zhiwei 2025-02-10 04:25:53 -08:00 committed by guangyey
parent b91481ba82
commit 691f37d6b3
3 changed files with 180 additions and 23 deletions
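In brief, this change adds XPU implementations of the fused quantized linear ops with a binary post-op, registering them against the existing `onednn::qlinear_pointwise.binary` and `onednn::qlinear_pointwise.binary_tensor` schemas, and extends the Inductor pattern-matcher tests with XPU variants of the linear+add(+relu) and linear+mul fusions. Below is a minimal, hedged sketch of the PT2E flow the new tests exercise; it is not the literal test harness (the tests go through `TestPatternMatcherBase._test_common`), and the helper name `get_default_xpu_inductor_quantization_config` is assumed by analogy with the X86 quantizer.

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xpu_inductor_quantizer import (
    XPUInductorQuantizer,
    get_default_xpu_inductor_quantization_config,  # assumed helper name
)

class LinearAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, other):
        # Linear -> Add -> ReLU, the pattern test_qlinear_add_xpu exercises
        return torch.relu(self.linear(x) + other)

mod = LinearAdd().eval().to("xpu")
example_inputs = (
    torch.randn(4, 4, device="xpu"),
    torch.randn(4, 4, device="xpu"),
)

quantizer = XPUInductorQuantizer()
quantizer.set_global(get_default_xpu_inductor_quantization_config())

exported = torch.export.export_for_training(mod, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass
converted = convert_pt2e(prepared)

# Inductor should now match dequant->linear->add->relu and lower it to the
# onednn::qlinear_pointwise.binary kernel registered below.
out = torch.compile(converted)(*example_inputs)
```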

@@ -125,6 +125,120 @@ Tensor q_linear_pointwise_tensor(
  return qout;
}
Tensor q_linear_pointwise_binary(
    Tensor act,
    double act_scale,
    int64_t act_zero_point,
    Tensor weight,
    Tensor weight_scales,
    Tensor weight_zero_points,
    std::optional<at::Tensor> other,
    std::optional<Tensor> bias,
    double output_scale,
    int64_t output_zero_point,
    std::optional<c10::ScalarType> output_dtype,
    double other_scale,
    int64_t other_zero_point,
    c10::string_view binary_post_op,
    double binary_alpha,
    c10::string_view unary_post_op,
    torch::List<std::optional<at::Scalar>> unary_post_op_args,
    c10::string_view unary_post_op_algorithm) {
  Tensor b_raw = bias.has_value() ? bias.value() : at::Tensor();
  const int64_t dim = act.dim();
  int64_t K = act.size(dim - 1);
  int64_t M = act.numel() / K;
  // [M, K] x [K, N]
  int64_t N = weight.size(1);
  std::vector<int64_t> src_dims = {M, K};
  std::vector<int64_t> dst_dims = {M, N};
  auto out_dtype =
      output_dtype.has_value() ? output_dtype.value() : act.scalar_type();
  Tensor qout = at::empty(dst_dims, device(c10::kXPU).dtype(out_dtype));
  quantized_matmul(
      act.contiguous(),
      act_scale,
      act_zero_point,
      weight.contiguous(),
      weight_scales,
      weight_zero_points,
      b_raw,
      qout,
      output_scale,
      output_zero_point,
      output_dtype,
      /*other*/ other,
      /*other scale*/ other_scale,
      /*other zp*/ other_zero_point,
      /*binary post op*/ binary_post_op,
      /*binary alpha*/ binary_alpha,
      unary_post_op,
      unary_post_op_args,
      unary_post_op_algorithm);
  return qout;
}

Tensor q_linear_pointwise_binary_tensor(
    Tensor act,
    Tensor act_scale,
    Tensor act_zero_point,
    Tensor weight,
    Tensor weight_scales,
    Tensor weight_zero_points,
    std::optional<at::Tensor> other,
    std::optional<Tensor> bias,
    double output_scale,
    int64_t output_zero_point,
    std::optional<c10::ScalarType> output_dtype,
    double other_scale,
    int64_t other_zero_point,
    c10::string_view binary_post_op,
    double binary_alpha,
    c10::string_view unary_post_op,
    torch::List<std::optional<at::Scalar>> unary_post_op_args,
    c10::string_view unary_post_op_algorithm) {
  Tensor b_raw = bias.has_value() ? bias.value() : at::Tensor();
  const int64_t dim = act.dim();
  int64_t K = act.size(dim - 1);
  int64_t M = act.numel() / K;
  // [M, K] x [K, N]
  int64_t N = weight.size(1);
  std::vector<int64_t> src_dims = {M, K};
  std::vector<int64_t> dst_dims = {M, N};
  auto out_dtype =
      output_dtype.has_value() ? output_dtype.value() : act.scalar_type();
  Tensor qout = at::empty(dst_dims, device(c10::kXPU).dtype(out_dtype));
  quantized_matmul(
      act.contiguous(),
      act_scale.item().toDouble(),
      act_zero_point.item().toLong(),
      weight.contiguous(),
      weight_scales,
      weight_zero_points,
      b_raw,
      qout,
      output_scale,
      output_zero_point,
      output_dtype,
      /*other*/ other,
      /*other scale*/ other_scale,
      /*other zp*/ other_zero_point,
      /*binary post op*/ binary_post_op,
      /*binary alpha*/ binary_alpha,
      unary_post_op,
      unary_post_op_args,
      unary_post_op_algorithm);
  return qout;
}

at::Tensor q_linear_prepack_onednn(
    at::Tensor weight,
    std::optional<torch::List<int64_t>> input_shape) {
@@ -142,6 +256,12 @@ TORCH_LIBRARY_IMPL(onednn, XPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("onednn::qlinear_prepack"),
      TORCH_FN(q_linear_prepack_onednn));
  m.impl(
      TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary"),
      TORCH_FN(q_linear_pointwise_binary));
  m.impl(
      TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary_tensor"),
      TORCH_FN(q_linear_pointwise_binary_tensor));
}
} // namespace at::native::xpu
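The two entry points above differ only in how the activation quantization parameters arrive: `q_linear_pointwise_binary` takes `act_scale`/`act_zero_point` as plain scalars, while `q_linear_pointwise_binary_tensor` takes them as tensors and reduces them with `.item()` before calling the same `quantized_matmul` helper. For reference, a hedged sketch of calling the op directly through the dispatcher; the operands below are placeholders rather than the output of a real calibration or prepack flow, so the dtypes and values are illustrative assumptions only.

```python
import torch

# Placeholder quantized operands on XPU; shapes follow the kernel's
# [M, K] x [K, N] layout (act: [M, K], weight: [K, N]).
act = torch.randint(0, 128, (2, 4), dtype=torch.uint8, device="xpu")
weight = torch.randint(-64, 64, (4, 5), dtype=torch.int8, device="xpu")
w_scales = torch.ones(5, device="xpu")
w_zps = torch.zeros(5, dtype=torch.int64, device="xpu")
other = torch.randn(2, 5, device="xpu")  # tensor fused in by the binary post-op

out = torch.ops.onednn.qlinear_pointwise.binary(
    act, 0.02, 0,            # activation scale / zero point as Python scalars
    weight, w_scales, w_zps,
    other, None,             # fused addend, optional bias
    1.0, 0, torch.float32,   # output scale / zero point / dtype
    1.0, 0,                  # other scale / zero point
    "add", 1.0,              # binary post-op and alpha
    "none", [], "",          # unary post-op, its args, and algorithm
)

# The .binary_tensor overload is the same call with the activation scale and
# zero point passed as tensors (reduced with .item() inside the kernel), e.g.:
#   torch.ops.onednn.qlinear_pointwise.binary_tensor(
#       act, torch.tensor(0.02), torch.tensor(0), weight, ...)
```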

@@ -2347,8 +2347,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
            (torch.randn((2, 4)),), gelu, int8_mixed_bf16=True
        )

-    def _qlinear_add_cpu_test_helper(
-        self, use_relu=False, int8_mixed_bf16=False, is_qat=True, is_dynamic=True
+    def _qlinear_add_test_helper(
+        self,
+        device="cpu",
+        use_relu=False,
+        int8_mixed_bf16=False,
+        is_qat=True,
+        is_dynamic=True,
    ):
        r"""
        This testcase will quantize two consecutive Linear->Add(->relu) patterns as:
@@ -2423,8 +2428,10 @@ class TestPatternMatcher(TestPatternMatcherBase):
        fake_quant_x2_list = [False, True] if int8_mixed_bf16 else [False]
        cases = itertools.product(add_fn_list, fake_quant_x2_list)
        for add_fn, fq_x2 in cases:
-            mod = M(add_fn, use_relu, fq_x2).eval()
-            v = torch.randn((4, 4), dtype=torch.float32, requires_grad=False).add(1)
+            mod = M(add_fn, use_relu, fq_x2).eval().to(device=device)
+            v = torch.randn(
+                (4, 4), dtype=torch.float32, requires_grad=False, device=device
+            ).add(1)

            def matcher_check_fn():
                # 1. Dequant-linear pattern matched in quantization weight prepack * 4
@@ -2505,10 +2512,22 @@ class TestPatternMatcher(TestPatternMatcherBase):
    @parametrize("is_qat", [True, False])
    @parametrize("is_dynamic", [True, False])
    def test_qlinear_add_cpu(self, use_relu, is_qat, is_dynamic):
-        self._qlinear_add_cpu_test_helper(
+        self._qlinear_add_test_helper(
            use_relu=use_relu, is_qat=is_qat, is_dynamic=is_dynamic
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfNoXPU
    @config.patch({"fx_graph_cache": False})
    @parametrize("use_relu", [True])
    @parametrize("is_qat", [False])
    @parametrize("is_dynamic", [False])
    def test_qlinear_add_xpu(self, use_relu, is_qat, is_dynamic):
        self._qlinear_add_test_helper(
            device="xpu", use_relu=use_relu, is_qat=is_qat, is_dynamic=is_dynamic
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
@@ -2516,7 +2535,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
    @parametrize("is_qat", [True, False])
    @parametrize("is_dynamic", [True, False])
    def test_qlinear_add_int8_mixed_bf16(self, use_relu, is_qat, is_dynamic):
-        self._qlinear_add_cpu_test_helper(
+        self._qlinear_add_test_helper(
            int8_mixed_bf16=True,
            use_relu=use_relu,
            is_qat=is_qat,
@@ -2679,6 +2698,41 @@ class TestPatternMatcher(TestPatternMatcherBase):
            is_dynamic=True,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfNoXPU
    @config.patch({"fx_graph_cache": False})
    def test_qlinear_mul_xpu(self):
        r"""
        This testcase will quantize a Linear->Mul pattern.
        """

        class M(torch.nn.Module):
            def __init__(self, use_bias):
                super().__init__()
                self.linear = torch.nn.Linear(4, 5, use_bias)

            def forward(self, x1, x2):
                return torch.mul(self.linear(x1), x2)

        bias_list = [True, False]
        for bias in bias_list:
            mod = M(bias).eval().to(device="xpu")
            x1 = torch.randn((2, 4)).to(device="xpu")
            x2 = torch.randn((2, 5)).to(device="xpu")

            def matcher_check_fn():
                self.assertEqual(
                    counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
                )

            self._test_common(
                mod,
                (x1, x2),
                check_quantization=True,
                matcher_check_fn=matcher_check_fn,
            )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qlinear_mul_cpu(self):

@@ -96,23 +96,6 @@ class XPUInductorQuantizer(X86InductorQuantizer):
    ):
        pass

-    def _annotate_linear_fusion_pattern(
-        self,
-        model: torch.fx.GraphModule,
-        quantization_config: Optional[QuantizationConfig],
-        filter_fn: Optional[FilterFn] = None,
-    ):
-        self._annotate_linear_unary(model, quantization_config, filter_fn)
-        self._annotate_linear(model, quantization_config, filter_fn)
-
-    def _annotate_matmul(
-        self,
-        model: torch.fx.GraphModule,
-        quantization_config: Optional[QuantizationConfig],
-        filter_fn: Optional[FilterFn] = None,
-    ):
-        pass
-
    def _annotate_maxpool2d(
        self,
        node: Node,