[Quant][Inductor][X86] Separate unary post op fusion and lowering for qconv (#144312)

**Summary**
The current implementation fuses quantized ops with their post ops and lowers the fused op to the cpp backend in the same pass. It is better to separate post op fusion from lowering because
- it is a cleaner design
- the post op fusion pass is also needed for PT2E quantization eager mode

As one of a series of PRs that perform this separation, this PR moves the unary post op fusion of qconv out of the lowering pass so that it runs after the weight-prepack pass. The workflow is now (a sketch follows the list):
1. Weight prepack for qconv so that `dq - conv` patterns are replaced by `onednn.qconv2d_pointwise`
2. Fuse `onednn.qconv2d_pointwise` with its post ops
3. Lower the fused op to the cpp backend
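
A minimal sketch of the intended graph evolution for a conv + relu model, assuming PT2E quantization and Inductor freezing; op and attribute names are taken from the diff below, and the comments are illustrative rather than exact graph dumps:

```python
# Toy model used only to illustrate the three steps above.
import torch

class ConvReLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3)

    def forward(self, x):
        return torch.nn.functional.relu(self.conv(x))

# After PT2E quantization and freezing, Inductor sees roughly:
#   dq(x) -> aten.convolution -> aten.relu -> q(y)
# Step 1 (weight prepack): dq + convolution become
#   onednn.qconv2d_pointwise(..., post op "none")
# Step 2 (post op fusion, this PR): relu and the output quant are folded in, giving
#   onednn.qconv2d_pointwise(..., post op "relu", int8 output)
# Step 3 (lowering): the single fused node is lowered to the cpp backend as an ExternKernel.
```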

This PR adds additional `PatternMatcherPass`es to handle the post op fusion; the existing pattern matchers used for fusion are reused.
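
For context, a simplified sketch of how one such fusion is registered, based on the shape of the code in the diff below (the handler name is illustrative, and the real implementation also extracts the qparams and conv params before rebuilding the node):

```python
# Simplified sketch, not the verbatim implementation.
@register_freezing_graph_pattern(
    generate_pattern_with_unary(get_qconv2d_pt2e_pattern(1), aten.relu.default),
    extra_check=_is_valid_quantized_conv2d_optimization_pattern(),
    pass_number=3,
)
def _qconv_relu_fusion(match, *args, **kwargs):
    out_node = match.output_node()
    with match.graph.inserting_before(out_node):
        # Re-emit onednn.qconv2d_pointwise with the post op set to "relu" instead of
        # "none", replace all uses of the matched output, then erase the matched nodes
        # (see _register_qconv_post_op_fusion_pass in the diff for the full argument list).
        ...
```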

**Test plan**
Post op fusion is covered by the existing UTs in `test_mkldnn_pattern_matcher.py`.
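
The added assertions check the new lowering counter alongside the existing matcher counter; a hedged sketch of what they verify (the counter and flag names come from the diff, the helper itself is hypothetical):

```python
from torch._dynamo.utils import counters

def check_qconv_unary_lower_count(expected: int, test_acl: bool = False) -> None:
    # The new counter stays at zero when oneDNN dispatches to ACL (ARM), where these
    # x86-specific lowerings are skipped; otherwise it counts one per lowered fused qconv.
    assert counters["inductor"]["qconv2d_unary_lower_count"] == (0 if test_acl else expected)
```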

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144312
Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168
ghstack dependencies: #144224
Commit 9199c79a9c (parent 825fe15024)
Author: Xia, Weiwen, 2025-01-13 18:17:35 -08:00; committed by PyTorch MergeBot
3 changed files with 282 additions and 151 deletions


@ -1004,6 +1004,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"],
12 if int8_mixed_bf16 else 8,
)
self.assertEqual(
counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 2
)
self._test_common(
mod,
@ -1082,6 +1085,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
counters["inductor"]["qconv2d_unary_matcher_count"],
0 if TEST_ACL else 2,
)
self.assertEqual(
counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 2
)
if qconv2d_unary_matcher_nodes:
self.assertEqual(
counters["inductor"]["qconv2d_unary_matcher_nodes"],
@ -1508,6 +1514,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
counters["inductor"]["qconv2d_unary_matcher_count"],
0 if TEST_ACL else 3,
)
self.assertEqual(
counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 4
)
self._test_common(
mod,
@ -1683,6 +1692,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
counters["inductor"]["qconv2d_unary_matcher_nodes"],
0 if TEST_ACL else 2,
)
self.assertEqual(
counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 1
)
self._test_common(
mod,
@ -1728,6 +1740,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
counters["inductor"]["qconv2d_unary_matcher_count"],
0 if TEST_ACL else 2,
)
self.assertEqual(
counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 2
)
self._test_common(
mod,
@ -2654,6 +2669,10 @@ class TestPatternMatcher(TestPatternMatcherBase):
counters["inductor"]["qconv2d_unary_matcher_count"],
0 if TEST_ACL else 1,
)
self.assertEqual(
counters["inductor"]["qconv2d_unary_lower_count"],
0 if TEST_ACL else 1,
)
self._test_common(
mod,
@ -2749,6 +2768,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
counters["inductor"]["qconv2d_unary_matcher_count"],
0 if TEST_ACL else 2,
)
self.assertEqual(
counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 2
)
self._test_common(
mod,


@ -163,26 +163,26 @@ dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
)
def get_dequantize_qconv_pt2e_pattern(users=1):
def get_qconv2d_pt2e_pattern(users=1):
return CallFunction(
torch.ops.onednn.qconv2d_pointwise.default,
KeywordArg("x"),
KeywordArg("x_scale"), # x_scale
KeywordArg("x_zp"), # x_zp
KeywordArg("packed_weight"), # packed_weight
KeywordArg("w_scale"), # w_scale
KeywordArg("w_zp"), # w_zp
KeywordArg("b"), # bias
KeywordArg("x_scale"),
KeywordArg("x_zp"),
KeywordArg("packed_weight"),
KeywordArg("w_scale"),
KeywordArg("w_zp"),
KeywordArg("b"),
KeywordArg("stride"),
KeywordArg("padding"),
KeywordArg("dilation"),
KeywordArg("groups"),
KeywordArg("output_scale"), # output_scale = 1.0
KeywordArg("output_zero_point"), # output_zero_point = 0
KeywordArg("output_dtype"), # output_dtype = None
KeywordArg("attr"), # attr = "none"
Arg(), # scalars
Arg(), # algorithm
KeywordArg("output_scale"),
KeywordArg("output_zero_point"),
KeywordArg("output_dtype"),
KeywordArg("postop_name"),
KeywordArg("postop_args"),
KeywordArg("postop_algorithm"),
_users=users,
)
@ -332,15 +332,28 @@ def _is_valid_quantized_conv2d_optimization_pattern():
return fn
def _is_valid_qconv_lowering_pattern():
def fn(match):
if len(match.nodes) != 1:
return False
return match.nodes[0].target in (
torch.ops.onednn.qconv2d_pointwise.default,
torch.ops.onednn.qconv2d_pointwise.tensor,
torch.ops.onednn.qconv2d_pointwise.binary,
torch.ops.onednn.qconv2d_pointwise.binary_tensor,
)
return fn
def _register_quantized_conv_lowering(
pattern,
pass_number,
computation_op,
unary_attr,
):
@register_lowering_pattern(
pattern,
extra_check=_is_valid_quantized_conv2d_optimization_pattern(),
extra_check=_is_valid_qconv_lowering_pattern(),
pass_number=pass_number,
)
def qconv(match: Match, *args, **kwargs):
@ -367,23 +380,13 @@ def _register_quantized_conv_lowering(
output_dtype = _get_pattern_output_dtype(match)
assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16]
# Output QParams
o_inv_scale = (
kwargs["o_inv_scale"]
if (output_dtype == torch.uint8 or output_dtype == torch.int8)
else 1.0
)
o_zero_point = (
kwargs["o_zp"]
if (output_dtype == torch.uint8 or output_dtype == torch.int8)
else 0
)
assert (
kwargs["attr"] == "none"
) # Expected no post op fused in weight prepack phase
if unary_attr.op_name == "hardtanh":
min_value = kwargs.get("min_value")
max_value = kwargs.get("max_value")
unary_attr.scalars_attr = [min_value, max_value]
o_inv_scale = kwargs["output_scale"]
o_zero_point = kwargs["output_zero_point"]
output_dtype = kwargs["output_dtype"]
# post op
postop_name = kwargs["postop_name"]
postop_args = kwargs["postop_args"]
postop_algorithm = kwargs["postop_algorithm"]
computation_args = (
x,
@ -400,12 +403,12 @@ def _register_quantized_conv_lowering(
o_inv_scale,
o_zero_point,
output_dtype,
unary_attr.op_name,
unary_attr.scalars_attr,
unary_attr.algorithm_attr,
postop_name,
postop_args,
postop_algorithm,
)
counters["inductor"]["qconv2d_unary_matcher_count"] += 1
counters["inductor"]["qconv2d_unary_matcher_nodes"] += len(match.nodes)
counters["inductor"]["qconv2d_unary_lower_count"] += 1
counters["inductor"]["qconv2d_unary_lower_nodes"] += len(match.nodes)
return L[computation_op](*computation_args)
return qconv
@ -755,115 +758,15 @@ def _register_quantized_conv_binary_lowering(
return qconv_binary
def _register_quantization_unary_fusion():
from .mkldnn_fusion import _hardswish_fusion, _hardtanh_fusion, _silu_fusion
class UnaryAttr:
def __init__(
self, op_name: str, scalars_attr=None, algorithm_attr=None
) -> None:
self.op_name = op_name
self.scalars_attr = scalars_attr if scalars_attr else []
self.algorithm_attr = algorithm_attr if algorithm_attr else ""
def _register_quantization_unary_lowering():
# QConv2d
for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
# Priority 1 to match: QConv2d Unary pattern with int8 output
# If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly.
# For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant
is_bf16 = original_pattern_output_dtype == torch.bfloat16
conv_unary_replace_patterns = {
UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
get_dequantize_qconv_pt2e_pattern(1),
),
UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
generate_pattern_with_unary(
get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
),
),
UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant(
_unary_fusion_pattern(
_hardtanh_fusion,
get_dequantize_qconv_pt2e_pattern(1),
1,
is_bf16,
),
with_dtype_convert=is_bf16,
),
UnaryAttr("hardswish", [], ""): generate_pattern_with_output_quant(
_unary_fusion_pattern(
_hardswish_fusion,
get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
2,
is_bf16,
),
with_dtype_convert=is_bf16,
),
UnaryAttr("swish", [], ""): generate_pattern_with_output_quant(
_unary_fusion_pattern(
_silu_fusion,
get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
2,
is_bf16,
),
with_dtype_convert=is_bf16,
),
}
for unary_attr, patterns in conv_unary_replace_patterns.items():
# Register qconv2d pattern for ExternKernel Lowering
_register_quantized_conv_lowering(
patterns,
1, # pass_number
torch.ops.onednn.qconv2d_pointwise, # computation_op
unary_attr, # unary_attr
)
# Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
conv_unary_replace_float_out_patterns = {
UnaryAttr("relu", [], ""): generate_pattern_with_unary(
get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
),
UnaryAttr("hardtanh", [], ""): _may_generate_pattern_with_dtype_convert(
_unary_fusion_pattern(
_hardtanh_fusion,
get_dequantize_qconv_pt2e_pattern(1),
1,
is_bf16,
),
Arg(),
is_bf16,
),
UnaryAttr("hardswish", [], ""): _may_generate_pattern_with_dtype_convert(
_unary_fusion_pattern(
_hardswish_fusion,
get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
2,
is_bf16,
),
Arg(),
is_bf16,
),
UnaryAttr("swish", [], ""): _may_generate_pattern_with_dtype_convert(
_unary_fusion_pattern(
_silu_fusion,
get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
2,
is_bf16,
),
Arg(),
is_bf16,
),
}
for unary_attr, patterns in conv_unary_replace_float_out_patterns.items():
# Register qconv2d pattern for ExternKernel Lowering
_register_quantized_conv_lowering(
patterns,
2, # pass_number
torch.ops.onednn.qconv2d_pointwise, # computation_op
unary_attr, # unary_attr
)
for users in [1, 2]:
qconv_pattern = get_qconv2d_pt2e_pattern(users)
_register_quantized_conv_lowering(
qconv_pattern,
2, # pass_number
torch.ops.onednn.qconv2d_pointwise.default, # computation_op
)
# QLinear
for x_scale_zp_are_tensors in (False, True):
@ -908,7 +811,7 @@ def _register_quantization_binary_fusion():
): generate_pattern_with_output_quant(
generate_pattern_with_binary(
aten.add.Tensor,
get_dequantize_qconv_pt2e_pattern(1),
get_qconv2d_pt2e_pattern(1),
dequantize_accum_pattern,
int8_mixed_bf16_with_inplace_add,
swap_inputs=swap_inputs,
@ -920,7 +823,7 @@ def _register_quantization_binary_fusion():
generate_pattern_with_unary(
generate_pattern_with_binary(
aten.add.Tensor,
get_dequantize_qconv_pt2e_pattern(1),
get_qconv2d_pt2e_pattern(1),
dequantize_accum_pattern,
int8_mixed_bf16_with_inplace_add,
swap_inputs=swap_inputs,
@ -949,7 +852,7 @@ def _register_quantization_binary_fusion():
): generate_pattern_with_unary(
generate_pattern_with_binary(
aten.add.Tensor,
get_dequantize_qconv_pt2e_pattern(1),
get_qconv2d_pt2e_pattern(1),
KeywordArg("accum_after_dequant"),
int8_mixed_bf16_with_inplace_add,
swap_inputs=swap_inputs,
@ -987,7 +890,7 @@ def _register_quantization_binary_fusion():
"sum", 1.0, "none", [], ""
): generate_pattern_with_binary(
aten.add.Tensor,
get_dequantize_qconv_pt2e_pattern(1),
get_qconv2d_pt2e_pattern(1),
KeywordArg("accum_after_dequant"),
int8_mixed_bf16_with_inplace_add,
swap_inputs=swap_inputs,
@ -1428,7 +1331,7 @@ def _register_woq_mm_int8_pattern4():
def _register_quantization_lowerings():
_register_quantization_unary_fusion()
_register_quantization_unary_lowering()
_register_quantization_binary_fusion()
_register_quantization_maxpool2d()
_register_quantization_cat()
@ -2891,6 +2794,211 @@ class PostOpAttr:
self.algorithm_attr = algorithm_attr if algorithm_attr else ""
def _register_qconv_post_op_fusion_pass(
pattern,
pass_number,
computation_op,
post_op_attr,
):
@register_freezing_graph_pattern(
pattern,
extra_check=_is_valid_quantized_conv2d_optimization_pattern(),
pass_number=pass_number,
)
def qconv(match: Match, *args, **kwargs):
# Activation QParams
x, x_scale, x_zp = (
kwargs["x"],
kwargs["x_scale"],
kwargs["x_zp"],
)
# Weight QParams
packed_weight, w_scale, w_zp = (
kwargs["packed_weight"],
kwargs["w_scale"],
kwargs["w_zp"],
)
# Conv Params
b, stride, padding, dilation, groups = (
kwargs["b"],
kwargs["stride"],
kwargs["padding"],
kwargs["dilation"],
kwargs["groups"],
)
output_dtype = _get_pattern_output_dtype(match)
assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16]
# Output QParams
o_inv_scale = (
kwargs["o_inv_scale"]
if (output_dtype == torch.uint8 or output_dtype == torch.int8)
else 1.0
)
o_zero_point = (
kwargs["o_zp"]
if (output_dtype == torch.uint8 or output_dtype == torch.int8)
else 0
)
assert (
kwargs["postop_name"] == "none"
) # Expected no post op fused in weight prepack phase
if post_op_attr.unary_op_name == "hardtanh":
min_value = kwargs.get("min_value")
max_value = kwargs.get("max_value")
post_op_attr.scalars_attr = [min_value, max_value]
out_node = match.output_node()
with match.graph.inserting_before(out_node):
computation_args = (
x,
x_scale,
x_zp,
packed_weight,
w_scale,
w_zp,
b,
stride,
padding,
dilation,
groups,
o_inv_scale,
o_zero_point,
output_dtype,
post_op_attr.unary_op_name,
post_op_attr.scalars_attr,
post_op_attr.algorithm_attr,
)
new_conv_node = match.graph.call_function(
computation_op, args=computation_args
)
out_node.replace_all_uses_with(new_conv_node)
new_conv_node.meta.update(out_node.meta)
for node in reversed(match.nodes):
match.graph.erase_node(node)
counters["inductor"]["qconv2d_unary_matcher_count"] += 1
counters["inductor"]["qconv2d_unary_matcher_nodes"] += len(match.nodes)
return qconv
def _register_qconv_unary_fusion():
from .mkldnn_fusion import _hardswish_fusion, _hardtanh_fusion, _silu_fusion
for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
# Priority 1 to match: QConv2d Unary pattern with int8 output
# If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly.
# For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant
is_bf16 = original_pattern_output_dtype == torch.bfloat16
conv_unary_replace_patterns = {
PostOpAttr(
"none", None, "none", [], ""
): generate_pattern_with_output_quant(
get_qconv2d_pt2e_pattern(1),
),
PostOpAttr(
"none", None, "relu", [], ""
): generate_pattern_with_output_quant(
generate_pattern_with_unary(
get_qconv2d_pt2e_pattern(1), aten.relu.default
),
),
PostOpAttr(
"none", None, "hardtanh", [], ""
): generate_pattern_with_output_quant(
_unary_fusion_pattern(
_hardtanh_fusion,
get_qconv2d_pt2e_pattern(1),
1,
is_bf16,
),
with_dtype_convert=is_bf16,
),
PostOpAttr(
"none", None, "hardswish", [], ""
): generate_pattern_with_output_quant(
_unary_fusion_pattern(
_hardswish_fusion,
get_qconv2d_pt2e_pattern(1 if is_bf16 else 2),
2,
is_bf16,
),
with_dtype_convert=is_bf16,
),
PostOpAttr(
"none", None, "swish", [], ""
): generate_pattern_with_output_quant(
_unary_fusion_pattern(
_silu_fusion,
get_qconv2d_pt2e_pattern(1 if is_bf16 else 2),
2,
is_bf16,
),
with_dtype_convert=is_bf16,
),
}
for unary_attr, patterns in conv_unary_replace_patterns.items():
# Register qconv2d pattern for ExternKernel Lowering
_register_qconv_post_op_fusion_pass(
patterns,
3, # pass_number
torch.ops.onednn.qconv2d_pointwise.default, # computation_op
unary_attr, # unary_attr
)
# Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
conv_unary_replace_float_out_patterns = {
PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary(
get_qconv2d_pt2e_pattern(1), aten.relu.default
),
PostOpAttr(
"none", None, "hardtanh", [], ""
): _may_generate_pattern_with_dtype_convert(
_unary_fusion_pattern(
_hardtanh_fusion,
get_qconv2d_pt2e_pattern(1),
1,
is_bf16,
),
Arg(),
is_bf16,
),
PostOpAttr(
"none", None, "hardswish", [], ""
): _may_generate_pattern_with_dtype_convert(
_unary_fusion_pattern(
_hardswish_fusion,
get_qconv2d_pt2e_pattern(1 if is_bf16 else 2),
2,
is_bf16,
),
Arg(),
is_bf16,
),
PostOpAttr(
"none", None, "swish", [], ""
): _may_generate_pattern_with_dtype_convert(
_unary_fusion_pattern(
_silu_fusion,
get_qconv2d_pt2e_pattern(1 if is_bf16 else 2),
2,
is_bf16,
),
Arg(),
is_bf16,
),
}
for unary_attr, patterns in conv_unary_replace_float_out_patterns.items():
# Register qconv2d pattern for ExternKernel Lowering
_register_qconv_post_op_fusion_pass(
patterns,
4, # pass_number
torch.ops.onednn.qconv2d_pointwise.default, # computation_op
unary_attr, # unary_attr
)
def _register_qlinear_post_op_fusion_pass(
pattern,
pass_number,
@ -3338,6 +3446,7 @@ def _register_quantization_weight_pack_pass():
# Step 5: QLinear post op Fusion
if not torch.ops.mkldnn._is_mkldnn_acl_supported():
# skip fusion on ARM
_register_qconv_unary_fusion()
_register_qlinear_unary_fusion()
_register_qlinear_binary_fusion()


@ -2428,7 +2428,7 @@ if torch._C._has_mkldnn:
groups,
None,
)
assert output_dtype in [torch.float32, torch.bfloat16]
assert output_dtype in [torch.float32, torch.bfloat16, torch.uint8]
out = x.new_empty(shape_out, dtype=output_dtype)
out = out.to(memory_format=torch.channels_last)
return out