[Intel GPU] qlinear_pointwise.binary[_tensor] XPU support
ghstack-source-id: d0b52577b4
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135337
This commit is contained in:
parent b91481ba82
commit 691f37d6b3

3 changed files with 180 additions and 23 deletions
@@ -125,6 +125,120 @@ Tensor q_linear_pointwise_tensor(
   return qout;
 }
 
+Tensor q_linear_pointwise_binary(
+    Tensor act,
+    double act_scale,
+    int64_t act_zero_point,
+    Tensor weight,
+    Tensor weight_scales,
+    Tensor weight_zero_points,
+    std::optional<at::Tensor> other,
+    std::optional<Tensor> bias,
+    double output_scale,
+    int64_t output_zero_point,
+    std::optional<c10::ScalarType> output_dtype,
+    double other_scale,
+    int64_t other_zero_point,
+    c10::string_view binary_post_op,
+    double binary_alpha,
+    c10::string_view unary_post_op,
+    torch::List<std::optional<at::Scalar>> unary_post_op_args,
+    c10::string_view unary_post_op_algorithm) {
+  Tensor b_raw = bias.has_value() ? bias.value() : at::Tensor();
+
+  const int64_t dim = act.dim();
+  int64_t K = act.size(dim - 1);
+  int64_t M = act.numel() / K;
+  // [M, K] x [K, N]
+  int64_t N = weight.size(1);
+
+  std::vector<int64_t> src_dims = {M, K};
+  std::vector<int64_t> dst_dims = {M, N};
+  auto out_dtype =
+      output_dtype.has_value() ? output_dtype.value() : act.scalar_type();
+  Tensor qout = at::empty(dst_dims, device(c10::kXPU).dtype(out_dtype));
+
+  quantized_matmul(
+      act.contiguous(),
+      act_scale,
+      act_zero_point,
+      weight.contiguous(),
+      weight_scales,
+      weight_zero_points,
+      b_raw,
+      qout,
+      output_scale,
+      output_zero_point,
+      output_dtype,
+      /*other*/ other,
+      /*other scale*/ other_scale,
+      /*other zp*/ other_zero_point,
+      /*binary post op*/ binary_post_op,
+      /*binary alpha*/ binary_alpha,
+      unary_post_op,
+      unary_post_op_args,
+      unary_post_op_algorithm);
+
+  return qout;
+}
+
+Tensor q_linear_pointwise_binary_tensor(
+    Tensor act,
+    Tensor act_scale,
+    Tensor act_zero_point,
+    Tensor weight,
+    Tensor weight_scales,
+    Tensor weight_zero_points,
+    std::optional<at::Tensor> other,
+    std::optional<Tensor> bias,
+    double output_scale,
+    int64_t output_zero_point,
+    std::optional<c10::ScalarType> output_dtype,
+    double other_scale,
+    int64_t other_zero_point,
+    c10::string_view binary_post_op,
+    double binary_alpha,
+    c10::string_view unary_post_op,
+    torch::List<std::optional<at::Scalar>> unary_post_op_args,
+    c10::string_view unary_post_op_algorithm) {
+  Tensor b_raw = bias.has_value() ? bias.value() : at::Tensor();
+
+  const int64_t dim = act.dim();
+  int64_t K = act.size(dim - 1);
+  int64_t M = act.numel() / K;
+  // [M, K] x [K, N]
+  int64_t N = weight.size(1);
+
+  std::vector<int64_t> src_dims = {M, K};
+  std::vector<int64_t> dst_dims = {M, N};
+  auto out_dtype =
+      output_dtype.has_value() ? output_dtype.value() : act.scalar_type();
+  Tensor qout = at::empty(dst_dims, device(c10::kXPU).dtype(out_dtype));
+
+  quantized_matmul(
+      act.contiguous(),
+      act_scale.item().toDouble(),
+      act_zero_point.item().toLong(),
+      weight.contiguous(),
+      weight_scales,
+      weight_zero_points,
+      b_raw,
+      qout,
+      output_scale,
+      output_zero_point,
+      output_dtype,
+      /*other*/ other,
+      /*other scale*/ other_scale,
+      /*other zp*/ other_zero_point,
+      /*binary post op*/ binary_post_op,
+      /*binary alpha*/ binary_alpha,
+      unary_post_op,
+      unary_post_op_args,
+      unary_post_op_algorithm);
+
+  return qout;
+}
+
 at::Tensor q_linear_prepack_onednn(
     at::Tensor weight,
     std::optional<torch::List<int64_t>> input_shape) {
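Both overloads reduce the quantized linear to a 2-D matmul: K is the last dim of the activation, M collapses every leading dim, and N is the weight's second dim. The only difference between them is that the binary_tensor variant receives the activation scale and zero-point as tensors and unwraps them with .item() before the same quantized_matmul call. A minimal Python sketch of the shape computation (illustrative only, not part of the commit):

import torch

def matmul_dims(act: torch.Tensor, weight: torch.Tensor):
    # Mirrors the C++ above: [M, K] x [K, N] -> [M, N].
    K = act.size(-1)
    M = act.numel() // K
    N = weight.size(1)
    return M, K, N

# A 3-D activation collapses its leading dims into M = 2 * 7 = 14.
assert matmul_dims(torch.empty(2, 7, 4), torch.empty(4, 5)) == (14, 4, 5)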
@@ -142,6 +256,12 @@ TORCH_LIBRARY_IMPL(onednn, XPU, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("onednn::qlinear_prepack"),
       TORCH_FN(q_linear_prepack_onednn));
+  m.impl(
+      TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary"),
+      TORCH_FN(q_linear_pointwise_binary));
+  m.impl(
+      TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary_tensor"),
+      TORCH_FN(q_linear_pointwise_binary_tensor));
 }
 
 } // namespace at::native::xpu
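With these registrations in place, the XPU kernels are reachable through the existing onednn::qlinear_pointwise overload packet (the schemas themselves are already defined on the CPU side; this commit only adds XPU implementations). A quick sanity check from Python in an XPU-enabled build, as a sketch:

import torch

# OpOverloadPacket for onednn::qlinear_pointwise; .binary and
# .binary_tensor are the overloads implemented for XPU above.
packet = torch.ops.onednn.qlinear_pointwise
print(packet.binary)         # resolves onednn::qlinear_pointwise.binary
print(packet.binary_tensor)  # resolves onednn::qlinear_pointwise.binary_tensor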
@@ -2347,8 +2347,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
             (torch.randn((2, 4)),), gelu, int8_mixed_bf16=True
         )
 
-    def _qlinear_add_cpu_test_helper(
-        self, use_relu=False, int8_mixed_bf16=False, is_qat=True, is_dynamic=True
+    def _qlinear_add_test_helper(
+        self,
+        device="cpu",
+        use_relu=False,
+        int8_mixed_bf16=False,
+        is_qat=True,
+        is_dynamic=True,
     ):
         r"""
         This testcase will quantize two consecutive Linear->Add(->relu) patterns as:
@@ -2423,8 +2428,10 @@ class TestPatternMatcher(TestPatternMatcherBase):
         fake_quant_x2_list = [False, True] if int8_mixed_bf16 else [False]
         cases = itertools.product(add_fn_list, fake_quant_x2_list)
         for add_fn, fq_x2 in cases:
-            mod = M(add_fn, use_relu, fq_x2).eval()
-            v = torch.randn((4, 4), dtype=torch.float32, requires_grad=False).add(1)
+            mod = M(add_fn, use_relu, fq_x2).eval().to(device=device)
+            v = torch.randn(
+                (4, 4), dtype=torch.float32, requires_grad=False, device=device
+            ).add(1)
 
             def matcher_check_fn():
                 # 1. Dequant-linear pattern matched in quantization weight prepack * 4
@@ -2505,10 +2512,22 @@ class TestPatternMatcher(TestPatternMatcherBase):
     @parametrize("is_qat", [True, False])
     @parametrize("is_dynamic", [True, False])
     def test_qlinear_add_cpu(self, use_relu, is_qat, is_dynamic):
-        self._qlinear_add_cpu_test_helper(
+        self._qlinear_add_test_helper(
             use_relu=use_relu, is_qat=is_qat, is_dynamic=is_dynamic
         )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoXPU
+    @config.patch({"fx_graph_cache": False})
+    @parametrize("use_relu", [True])
+    @parametrize("is_qat", [False])
+    @parametrize("is_dynamic", [False])
+    def test_qlinear_add_xpu(self, use_relu, is_qat, is_dynamic):
+        self._qlinear_add_test_helper(
+            device="xpu", use_relu=use_relu, is_qat=is_qat, is_dynamic=is_dynamic
+        )
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
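The new test_qlinear_add_xpu drives the same helper with device="xpu". For orientation, a condensed sketch of the flow the helper exercises, using the PT2E quantization APIs; the config getter name get_default_xpu_inductor_quantization_config is an assumption, and an XPU build plus hardware is required:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xpu_inductor_quantizer import (
    XPUInductorQuantizer,
    get_default_xpu_inductor_quantization_config,  # assumed helper name
)

class LinearAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        # Linear -> Add -> ReLU, the pattern the matcher should fuse
        # into onednn::qlinear_pointwise.binary on XPU.
        return torch.relu(self.linear(x) + x)

mod = LinearAdd().eval().to("xpu")
x = torch.randn(4, 4, device="xpu")

quantizer = XPUInductorQuantizer()
quantizer.set_global(get_default_xpu_inductor_quantization_config())

exported = torch.export.export_for_training(mod, (x,)).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(x)  # calibration pass
converted = convert_pt2e(prepared)
out = torch.compile(converted)(x)  # Inductor matches the fused pattern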
@@ -2516,7 +2535,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
     @parametrize("is_qat", [True, False])
     @parametrize("is_dynamic", [True, False])
     def test_qlinear_add_int8_mixed_bf16(self, use_relu, is_qat, is_dynamic):
-        self._qlinear_add_cpu_test_helper(
+        self._qlinear_add_test_helper(
             int8_mixed_bf16=True,
             use_relu=use_relu,
             is_qat=is_qat,
@@ -2679,6 +2698,41 @@ class TestPatternMatcher(TestPatternMatcherBase):
             is_dynamic=True,
         )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoXPU
+    @config.patch({"fx_graph_cache": False})
+    def test_qlinear_mul_xpu(self):
+        r"""
+        This testcase will quantize a Linear->Mul pattern.
+        """
+
+        class M(torch.nn.Module):
+            def __init__(self, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(4, 5, use_bias)
+
+            def forward(self, x1, x2):
+                return torch.mul(self.linear(x1), x2)
+
+        bias_list = [True, False]
+        for bias in bias_list:
+            mod = M(bias).eval().to(device="xpu")
+            x1 = torch.randn((2, 4)).to(device="xpu")
+            x2 = torch.randn((2, 5)).to(device="xpu")
+
+            def matcher_check_fn():
+                self.assertEqual(
+                    counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
+                )
+
+            self._test_common(
+                mod,
+                (x1, x2),
+                check_quantization=True,
+                matcher_check_fn=matcher_check_fn,
+            )
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     def test_qlinear_mul_cpu(self):
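The Mul variant feeds the linear output and a second operand of matching shape into torch.mul. A trivial eager-mode shape check, as a sketch runnable on CPU and mirroring the shapes the test uses:

import torch

linear = torch.nn.Linear(4, 5)
x1 = torch.randn(2, 4)
x2 = torch.randn(2, 5)  # matches linear(x1).shape, as in the test above
assert torch.mul(linear(x1), x2).shape == (2, 5)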
@@ -96,23 +96,6 @@ class XPUInductorQuantizer(X86InductorQuantizer):
     ):
         pass
 
-    def _annotate_linear_fusion_pattern(
-        self,
-        model: torch.fx.GraphModule,
-        quantization_config: Optional[QuantizationConfig],
-        filter_fn: Optional[FilterFn] = None,
-    ):
-        self._annotate_linear_unary(model, quantization_config, filter_fn)
-        self._annotate_linear(model, quantization_config, filter_fn)
-
-    def _annotate_matmul(
-        self,
-        model: torch.fx.GraphModule,
-        quantization_config: Optional[QuantizationConfig],
-        filter_fn: Optional[FilterFn] = None,
-    ):
-        pass
-
     def _annotate_maxpool2d(
         self,
         node: Node,
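Deleting these overrides means XPUInductorQuantizer now inherits _annotate_linear_fusion_pattern from X86InductorQuantizer, so linear-binary fusions get annotated on XPU the same way as on x86 and the new kernels above can be matched. A sketch of the effect, assuming the usual module paths for the two quantizers and no remaining override in the subclass:

from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
)
from torch.ao.quantization.quantizer.xpu_inductor_quantizer import (
    XPUInductorQuantizer,
)

# With the override gone, attribute lookup on the subclass resolves to the
# base-class method, so the two attributes are the same function object.
assert (
    XPUInductorQuantizer._annotate_linear_fusion_pattern
    is X86InductorQuantizer._annotate_linear_fusion_pattern
)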