mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Quant][Inductor][X86] Separate unary post op fusion and lowering for qconv (#144312)
**Summary** The current implementation fuses quantized ops and their post ops and lowers the fused the op to cpp backend in the same pass. It is better to separate post op fusion and lowering because - it looks better in terms of design - we need the post op fusion pass for PT2E quantization eager mode As one of a series of PRs which do the separation, this PR moves unary post op fusion of qconv out of the lowering pass to after the weight-prepack pass. The workflow is 1. Weight prepack for qlinear so that `dq - conv` patterns are replaced by `onednn.qconv2d_pointwise` 2. Fuse `onednn.qconv2d_pointwise` and post ops 3. Lower to cpp backend This PR adds additional `PatternMatcherPass`'s to handle the post op fusion. Pattern matchers used for fusion are reused. **Test plan** It is covered by existing UTs in `test_mkldnn_pattern_matcher.py` for post op fusion. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144312 Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168 ghstack dependencies: #144224
This commit is contained in:
parent
825fe15024
commit
9199c79a9c
3 changed files with 282 additions and 151 deletions
|
|
@ -1004,6 +1004,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"],
|
||||
12 if int8_mixed_bf16 else 8,
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 2
|
||||
)
|
||||
|
||||
self._test_common(
|
||||
mod,
|
||||
|
|
@ -1082,6 +1085,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
counters["inductor"]["qconv2d_unary_matcher_count"],
|
||||
0 if TEST_ACL else 2,
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 2
|
||||
)
|
||||
if qconv2d_unary_matcher_nodes:
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_unary_matcher_nodes"],
|
||||
|
|
@ -1508,6 +1514,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
counters["inductor"]["qconv2d_unary_matcher_count"],
|
||||
0 if TEST_ACL else 3,
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 4
|
||||
)
|
||||
|
||||
self._test_common(
|
||||
mod,
|
||||
|
|
@ -1683,6 +1692,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
counters["inductor"]["qconv2d_unary_matcher_nodes"],
|
||||
0 if TEST_ACL else 2,
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 1
|
||||
)
|
||||
|
||||
self._test_common(
|
||||
mod,
|
||||
|
|
@ -1728,6 +1740,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
counters["inductor"]["qconv2d_unary_matcher_count"],
|
||||
0 if TEST_ACL else 2,
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 2
|
||||
)
|
||||
|
||||
self._test_common(
|
||||
mod,
|
||||
|
|
@ -2654,6 +2669,10 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
counters["inductor"]["qconv2d_unary_matcher_count"],
|
||||
0 if TEST_ACL else 1,
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_unary_lower_count"],
|
||||
0 if TEST_ACL else 1,
|
||||
)
|
||||
|
||||
self._test_common(
|
||||
mod,
|
||||
|
|
@ -2749,6 +2768,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
counters["inductor"]["qconv2d_unary_matcher_count"],
|
||||
0 if TEST_ACL else 2,
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 2
|
||||
)
|
||||
|
||||
self._test_common(
|
||||
mod,
|
||||
|
|
|
|||
|
|
@ -163,26 +163,26 @@ dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
|
|||
)
|
||||
|
||||
|
||||
def get_dequantize_qconv_pt2e_pattern(users=1):
|
||||
def get_qconv2d_pt2e_pattern(users=1):
|
||||
return CallFunction(
|
||||
torch.ops.onednn.qconv2d_pointwise.default,
|
||||
KeywordArg("x"),
|
||||
KeywordArg("x_scale"), # x_scale
|
||||
KeywordArg("x_zp"), # x_zp
|
||||
KeywordArg("packed_weight"), # packed_weight
|
||||
KeywordArg("w_scale"), # w_scale
|
||||
KeywordArg("w_zp"), # w_zp
|
||||
KeywordArg("b"), # bias
|
||||
KeywordArg("x_scale"),
|
||||
KeywordArg("x_zp"),
|
||||
KeywordArg("packed_weight"),
|
||||
KeywordArg("w_scale"),
|
||||
KeywordArg("w_zp"),
|
||||
KeywordArg("b"),
|
||||
KeywordArg("stride"),
|
||||
KeywordArg("padding"),
|
||||
KeywordArg("dilation"),
|
||||
KeywordArg("groups"),
|
||||
KeywordArg("output_scale"), # output_scale = 1.0
|
||||
KeywordArg("output_zero_point"), # output_zero_point = 0
|
||||
KeywordArg("output_dtype"), # output_dtype = None
|
||||
KeywordArg("attr"), # attr = "none"
|
||||
Arg(), # scalars
|
||||
Arg(), # algorithm
|
||||
KeywordArg("output_scale"),
|
||||
KeywordArg("output_zero_point"),
|
||||
KeywordArg("output_dtype"),
|
||||
KeywordArg("postop_name"),
|
||||
KeywordArg("postop_args"),
|
||||
KeywordArg("postop_algorithm"),
|
||||
_users=users,
|
||||
)
|
||||
|
||||
|
|
@ -332,15 +332,28 @@ def _is_valid_quantized_conv2d_optimization_pattern():
|
|||
return fn
|
||||
|
||||
|
||||
def _is_valid_qconv_lowering_pattern():
|
||||
def fn(match):
|
||||
if len(match.nodes) != 1:
|
||||
return False
|
||||
return match.nodes[0].target in (
|
||||
torch.ops.onednn.qconv2d_pointwise.default,
|
||||
torch.ops.onednn.qconv2d_pointwise.tensor,
|
||||
torch.ops.onednn.qconv2d_pointwise.binary,
|
||||
torch.ops.onednn.qconv2d_pointwise.binary_tensor,
|
||||
)
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def _register_quantized_conv_lowering(
|
||||
pattern,
|
||||
pass_number,
|
||||
computation_op,
|
||||
unary_attr,
|
||||
):
|
||||
@register_lowering_pattern(
|
||||
pattern,
|
||||
extra_check=_is_valid_quantized_conv2d_optimization_pattern(),
|
||||
extra_check=_is_valid_qconv_lowering_pattern(),
|
||||
pass_number=pass_number,
|
||||
)
|
||||
def qconv(match: Match, *args, **kwargs):
|
||||
|
|
@ -367,23 +380,13 @@ def _register_quantized_conv_lowering(
|
|||
output_dtype = _get_pattern_output_dtype(match)
|
||||
assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16]
|
||||
# Output QParams
|
||||
o_inv_scale = (
|
||||
kwargs["o_inv_scale"]
|
||||
if (output_dtype == torch.uint8 or output_dtype == torch.int8)
|
||||
else 1.0
|
||||
)
|
||||
o_zero_point = (
|
||||
kwargs["o_zp"]
|
||||
if (output_dtype == torch.uint8 or output_dtype == torch.int8)
|
||||
else 0
|
||||
)
|
||||
assert (
|
||||
kwargs["attr"] == "none"
|
||||
) # Expected no post op fused in weight prepack phase
|
||||
if unary_attr.op_name == "hardtanh":
|
||||
min_value = kwargs.get("min_value")
|
||||
max_value = kwargs.get("max_value")
|
||||
unary_attr.scalars_attr = [min_value, max_value]
|
||||
o_inv_scale = kwargs["output_scale"]
|
||||
o_zero_point = kwargs["output_zero_point"]
|
||||
output_dtype = kwargs["output_dtype"]
|
||||
# post op
|
||||
postop_name = kwargs["postop_name"]
|
||||
postop_args = kwargs["postop_args"]
|
||||
postop_algorithm = kwargs["postop_algorithm"]
|
||||
|
||||
computation_args = (
|
||||
x,
|
||||
|
|
@ -400,12 +403,12 @@ def _register_quantized_conv_lowering(
|
|||
o_inv_scale,
|
||||
o_zero_point,
|
||||
output_dtype,
|
||||
unary_attr.op_name,
|
||||
unary_attr.scalars_attr,
|
||||
unary_attr.algorithm_attr,
|
||||
postop_name,
|
||||
postop_args,
|
||||
postop_algorithm,
|
||||
)
|
||||
counters["inductor"]["qconv2d_unary_matcher_count"] += 1
|
||||
counters["inductor"]["qconv2d_unary_matcher_nodes"] += len(match.nodes)
|
||||
counters["inductor"]["qconv2d_unary_lower_count"] += 1
|
||||
counters["inductor"]["qconv2d_unary_lower_nodes"] += len(match.nodes)
|
||||
return L[computation_op](*computation_args)
|
||||
|
||||
return qconv
|
||||
|
|
@ -755,115 +758,15 @@ def _register_quantized_conv_binary_lowering(
|
|||
return qconv_binary
|
||||
|
||||
|
||||
def _register_quantization_unary_fusion():
|
||||
from .mkldnn_fusion import _hardswish_fusion, _hardtanh_fusion, _silu_fusion
|
||||
|
||||
class UnaryAttr:
|
||||
def __init__(
|
||||
self, op_name: str, scalars_attr=None, algorithm_attr=None
|
||||
) -> None:
|
||||
self.op_name = op_name
|
||||
self.scalars_attr = scalars_attr if scalars_attr else []
|
||||
self.algorithm_attr = algorithm_attr if algorithm_attr else ""
|
||||
|
||||
def _register_quantization_unary_lowering():
|
||||
# QConv2d
|
||||
for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
|
||||
# Priority 1 to match: QConv2d Unary pattern with int8 output
|
||||
# If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly.
|
||||
# For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant
|
||||
is_bf16 = original_pattern_output_dtype == torch.bfloat16
|
||||
conv_unary_replace_patterns = {
|
||||
UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
|
||||
get_dequantize_qconv_pt2e_pattern(1),
|
||||
),
|
||||
UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
|
||||
generate_pattern_with_unary(
|
||||
get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
|
||||
),
|
||||
),
|
||||
UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant(
|
||||
_unary_fusion_pattern(
|
||||
_hardtanh_fusion,
|
||||
get_dequantize_qconv_pt2e_pattern(1),
|
||||
1,
|
||||
is_bf16,
|
||||
),
|
||||
with_dtype_convert=is_bf16,
|
||||
),
|
||||
UnaryAttr("hardswish", [], ""): generate_pattern_with_output_quant(
|
||||
_unary_fusion_pattern(
|
||||
_hardswish_fusion,
|
||||
get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
|
||||
2,
|
||||
is_bf16,
|
||||
),
|
||||
with_dtype_convert=is_bf16,
|
||||
),
|
||||
UnaryAttr("swish", [], ""): generate_pattern_with_output_quant(
|
||||
_unary_fusion_pattern(
|
||||
_silu_fusion,
|
||||
get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
|
||||
2,
|
||||
is_bf16,
|
||||
),
|
||||
with_dtype_convert=is_bf16,
|
||||
),
|
||||
}
|
||||
|
||||
for unary_attr, patterns in conv_unary_replace_patterns.items():
|
||||
# Register qconv2d pattern for ExternKernel Lowering
|
||||
_register_quantized_conv_lowering(
|
||||
patterns,
|
||||
1, # pass_number
|
||||
torch.ops.onednn.qconv2d_pointwise, # computation_op
|
||||
unary_attr, # unary_attr
|
||||
)
|
||||
|
||||
# Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
|
||||
conv_unary_replace_float_out_patterns = {
|
||||
UnaryAttr("relu", [], ""): generate_pattern_with_unary(
|
||||
get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
|
||||
),
|
||||
UnaryAttr("hardtanh", [], ""): _may_generate_pattern_with_dtype_convert(
|
||||
_unary_fusion_pattern(
|
||||
_hardtanh_fusion,
|
||||
get_dequantize_qconv_pt2e_pattern(1),
|
||||
1,
|
||||
is_bf16,
|
||||
),
|
||||
Arg(),
|
||||
is_bf16,
|
||||
),
|
||||
UnaryAttr("hardswish", [], ""): _may_generate_pattern_with_dtype_convert(
|
||||
_unary_fusion_pattern(
|
||||
_hardswish_fusion,
|
||||
get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
|
||||
2,
|
||||
is_bf16,
|
||||
),
|
||||
Arg(),
|
||||
is_bf16,
|
||||
),
|
||||
UnaryAttr("swish", [], ""): _may_generate_pattern_with_dtype_convert(
|
||||
_unary_fusion_pattern(
|
||||
_silu_fusion,
|
||||
get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
|
||||
2,
|
||||
is_bf16,
|
||||
),
|
||||
Arg(),
|
||||
is_bf16,
|
||||
),
|
||||
}
|
||||
|
||||
for unary_attr, patterns in conv_unary_replace_float_out_patterns.items():
|
||||
# Register qconv2d pattern for ExternKernel Lowering
|
||||
_register_quantized_conv_lowering(
|
||||
patterns,
|
||||
2, # pass_number
|
||||
torch.ops.onednn.qconv2d_pointwise, # computation_op
|
||||
unary_attr, # unary_attr
|
||||
)
|
||||
for users in [1, 2]:
|
||||
qconv_pattern = get_qconv2d_pt2e_pattern(users)
|
||||
_register_quantized_conv_lowering(
|
||||
qconv_pattern,
|
||||
2, # pass_number
|
||||
torch.ops.onednn.qconv2d_pointwise.default, # computation_op
|
||||
)
|
||||
|
||||
# QLinear
|
||||
for x_scale_zp_are_tensors in (False, True):
|
||||
|
|
@ -908,7 +811,7 @@ def _register_quantization_binary_fusion():
|
|||
): generate_pattern_with_output_quant(
|
||||
generate_pattern_with_binary(
|
||||
aten.add.Tensor,
|
||||
get_dequantize_qconv_pt2e_pattern(1),
|
||||
get_qconv2d_pt2e_pattern(1),
|
||||
dequantize_accum_pattern,
|
||||
int8_mixed_bf16_with_inplace_add,
|
||||
swap_inputs=swap_inputs,
|
||||
|
|
@ -920,7 +823,7 @@ def _register_quantization_binary_fusion():
|
|||
generate_pattern_with_unary(
|
||||
generate_pattern_with_binary(
|
||||
aten.add.Tensor,
|
||||
get_dequantize_qconv_pt2e_pattern(1),
|
||||
get_qconv2d_pt2e_pattern(1),
|
||||
dequantize_accum_pattern,
|
||||
int8_mixed_bf16_with_inplace_add,
|
||||
swap_inputs=swap_inputs,
|
||||
|
|
@ -949,7 +852,7 @@ def _register_quantization_binary_fusion():
|
|||
): generate_pattern_with_unary(
|
||||
generate_pattern_with_binary(
|
||||
aten.add.Tensor,
|
||||
get_dequantize_qconv_pt2e_pattern(1),
|
||||
get_qconv2d_pt2e_pattern(1),
|
||||
KeywordArg("accum_after_dequant"),
|
||||
int8_mixed_bf16_with_inplace_add,
|
||||
swap_inputs=swap_inputs,
|
||||
|
|
@ -987,7 +890,7 @@ def _register_quantization_binary_fusion():
|
|||
"sum", 1.0, "none", [], ""
|
||||
): generate_pattern_with_binary(
|
||||
aten.add.Tensor,
|
||||
get_dequantize_qconv_pt2e_pattern(1),
|
||||
get_qconv2d_pt2e_pattern(1),
|
||||
KeywordArg("accum_after_dequant"),
|
||||
int8_mixed_bf16_with_inplace_add,
|
||||
swap_inputs=swap_inputs,
|
||||
|
|
@ -1428,7 +1331,7 @@ def _register_woq_mm_int8_pattern4():
|
|||
|
||||
|
||||
def _register_quantization_lowerings():
|
||||
_register_quantization_unary_fusion()
|
||||
_register_quantization_unary_lowering()
|
||||
_register_quantization_binary_fusion()
|
||||
_register_quantization_maxpool2d()
|
||||
_register_quantization_cat()
|
||||
|
|
@ -2891,6 +2794,211 @@ class PostOpAttr:
|
|||
self.algorithm_attr = algorithm_attr if algorithm_attr else ""
|
||||
|
||||
|
||||
def _register_qconv_post_op_fusion_pass(
|
||||
pattern,
|
||||
pass_number,
|
||||
computation_op,
|
||||
post_op_attr,
|
||||
):
|
||||
@register_freezing_graph_pattern(
|
||||
pattern,
|
||||
extra_check=_is_valid_quantized_conv2d_optimization_pattern(),
|
||||
pass_number=pass_number,
|
||||
)
|
||||
def qconv(match: Match, *args, **kwargs):
|
||||
# Activation QParams
|
||||
x, x_scale, x_zp = (
|
||||
kwargs["x"],
|
||||
kwargs["x_scale"],
|
||||
kwargs["x_zp"],
|
||||
)
|
||||
# Weight QParams
|
||||
packed_weight, w_scale, w_zp = (
|
||||
kwargs["packed_weight"],
|
||||
kwargs["w_scale"],
|
||||
kwargs["w_zp"],
|
||||
)
|
||||
# Conv Params
|
||||
b, stride, padding, dilation, groups = (
|
||||
kwargs["b"],
|
||||
kwargs["stride"],
|
||||
kwargs["padding"],
|
||||
kwargs["dilation"],
|
||||
kwargs["groups"],
|
||||
)
|
||||
output_dtype = _get_pattern_output_dtype(match)
|
||||
assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16]
|
||||
# Output QParams
|
||||
o_inv_scale = (
|
||||
kwargs["o_inv_scale"]
|
||||
if (output_dtype == torch.uint8 or output_dtype == torch.int8)
|
||||
else 1.0
|
||||
)
|
||||
o_zero_point = (
|
||||
kwargs["o_zp"]
|
||||
if (output_dtype == torch.uint8 or output_dtype == torch.int8)
|
||||
else 0
|
||||
)
|
||||
assert (
|
||||
kwargs["postop_name"] == "none"
|
||||
) # Expected no post op fused in weight prepack phase
|
||||
if post_op_attr.unary_op_name == "hardtanh":
|
||||
min_value = kwargs.get("min_value")
|
||||
max_value = kwargs.get("max_value")
|
||||
post_op_attr.scalars_attr = [min_value, max_value]
|
||||
|
||||
out_node = match.output_node()
|
||||
with match.graph.inserting_before(out_node):
|
||||
computation_args = (
|
||||
x,
|
||||
x_scale,
|
||||
x_zp,
|
||||
packed_weight,
|
||||
w_scale,
|
||||
w_zp,
|
||||
b,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
o_inv_scale,
|
||||
o_zero_point,
|
||||
output_dtype,
|
||||
post_op_attr.unary_op_name,
|
||||
post_op_attr.scalars_attr,
|
||||
post_op_attr.algorithm_attr,
|
||||
)
|
||||
new_conv_node = match.graph.call_function(
|
||||
computation_op, args=computation_args
|
||||
)
|
||||
out_node.replace_all_uses_with(new_conv_node)
|
||||
new_conv_node.meta.update(out_node.meta)
|
||||
for node in reversed(match.nodes):
|
||||
match.graph.erase_node(node)
|
||||
counters["inductor"]["qconv2d_unary_matcher_count"] += 1
|
||||
counters["inductor"]["qconv2d_unary_matcher_nodes"] += len(match.nodes)
|
||||
|
||||
return qconv
|
||||
|
||||
|
||||
def _register_qconv_unary_fusion():
|
||||
from .mkldnn_fusion import _hardswish_fusion, _hardtanh_fusion, _silu_fusion
|
||||
|
||||
for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
|
||||
# Priority 1 to match: QConv2d Unary pattern with int8 output
|
||||
# If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly.
|
||||
# For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant
|
||||
is_bf16 = original_pattern_output_dtype == torch.bfloat16
|
||||
conv_unary_replace_patterns = {
|
||||
PostOpAttr(
|
||||
"none", None, "none", [], ""
|
||||
): generate_pattern_with_output_quant(
|
||||
get_qconv2d_pt2e_pattern(1),
|
||||
),
|
||||
PostOpAttr(
|
||||
"none", None, "relu", [], ""
|
||||
): generate_pattern_with_output_quant(
|
||||
generate_pattern_with_unary(
|
||||
get_qconv2d_pt2e_pattern(1), aten.relu.default
|
||||
),
|
||||
),
|
||||
PostOpAttr(
|
||||
"none", None, "hardtanh", [], ""
|
||||
): generate_pattern_with_output_quant(
|
||||
_unary_fusion_pattern(
|
||||
_hardtanh_fusion,
|
||||
get_qconv2d_pt2e_pattern(1),
|
||||
1,
|
||||
is_bf16,
|
||||
),
|
||||
with_dtype_convert=is_bf16,
|
||||
),
|
||||
PostOpAttr(
|
||||
"none", None, "hardswish", [], ""
|
||||
): generate_pattern_with_output_quant(
|
||||
_unary_fusion_pattern(
|
||||
_hardswish_fusion,
|
||||
get_qconv2d_pt2e_pattern(1 if is_bf16 else 2),
|
||||
2,
|
||||
is_bf16,
|
||||
),
|
||||
with_dtype_convert=is_bf16,
|
||||
),
|
||||
PostOpAttr(
|
||||
"none", None, "swish", [], ""
|
||||
): generate_pattern_with_output_quant(
|
||||
_unary_fusion_pattern(
|
||||
_silu_fusion,
|
||||
get_qconv2d_pt2e_pattern(1 if is_bf16 else 2),
|
||||
2,
|
||||
is_bf16,
|
||||
),
|
||||
with_dtype_convert=is_bf16,
|
||||
),
|
||||
}
|
||||
|
||||
for unary_attr, patterns in conv_unary_replace_patterns.items():
|
||||
# Register qconv2d pattern for ExternKernel Lowering
|
||||
_register_qconv_post_op_fusion_pass(
|
||||
patterns,
|
||||
3, # pass_number
|
||||
torch.ops.onednn.qconv2d_pointwise.default, # computation_op
|
||||
unary_attr, # unary_attr
|
||||
)
|
||||
|
||||
# Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
|
||||
conv_unary_replace_float_out_patterns = {
|
||||
PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary(
|
||||
get_qconv2d_pt2e_pattern(1), aten.relu.default
|
||||
),
|
||||
PostOpAttr(
|
||||
"none", None, "hardtanh", [], ""
|
||||
): _may_generate_pattern_with_dtype_convert(
|
||||
_unary_fusion_pattern(
|
||||
_hardtanh_fusion,
|
||||
get_qconv2d_pt2e_pattern(1),
|
||||
1,
|
||||
is_bf16,
|
||||
),
|
||||
Arg(),
|
||||
is_bf16,
|
||||
),
|
||||
PostOpAttr(
|
||||
"none", None, "hardswish", [], ""
|
||||
): _may_generate_pattern_with_dtype_convert(
|
||||
_unary_fusion_pattern(
|
||||
_hardswish_fusion,
|
||||
get_qconv2d_pt2e_pattern(1 if is_bf16 else 2),
|
||||
2,
|
||||
is_bf16,
|
||||
),
|
||||
Arg(),
|
||||
is_bf16,
|
||||
),
|
||||
PostOpAttr(
|
||||
"none", None, "swish", [], ""
|
||||
): _may_generate_pattern_with_dtype_convert(
|
||||
_unary_fusion_pattern(
|
||||
_silu_fusion,
|
||||
get_qconv2d_pt2e_pattern(1 if is_bf16 else 2),
|
||||
2,
|
||||
is_bf16,
|
||||
),
|
||||
Arg(),
|
||||
is_bf16,
|
||||
),
|
||||
}
|
||||
|
||||
for unary_attr, patterns in conv_unary_replace_float_out_patterns.items():
|
||||
# Register qconv2d pattern for ExternKernel Lowering
|
||||
_register_qconv_post_op_fusion_pass(
|
||||
patterns,
|
||||
4, # pass_number
|
||||
torch.ops.onednn.qconv2d_pointwise.default, # computation_op
|
||||
unary_attr, # unary_attr
|
||||
)
|
||||
|
||||
|
||||
def _register_qlinear_post_op_fusion_pass(
|
||||
pattern,
|
||||
pass_number,
|
||||
|
|
@ -3338,6 +3446,7 @@ def _register_quantization_weight_pack_pass():
|
|||
# Step 5: QLinear post op Fusion
|
||||
if not torch.ops.mkldnn._is_mkldnn_acl_supported():
|
||||
# skip fusion on ARM
|
||||
_register_qconv_unary_fusion()
|
||||
_register_qlinear_unary_fusion()
|
||||
_register_qlinear_binary_fusion()
|
||||
|
||||
|
|
|
|||
|
|
@ -2428,7 +2428,7 @@ if torch._C._has_mkldnn:
|
|||
groups,
|
||||
None,
|
||||
)
|
||||
assert output_dtype in [torch.float32, torch.bfloat16]
|
||||
assert output_dtype in [torch.float32, torch.bfloat16, torch.uint8]
|
||||
out = x.new_empty(shape_out, dtype=output_dtype)
|
||||
out = out.to(memory_format=torch.channels_last)
|
||||
return out
|
||||
|
|
|
|||
Loading…
Reference in a new issue