From d2033a0639284c9e5df20a58a7fa223a28aedb49 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 9 Jan 2024 12:11:53 -0800 Subject: [PATCH] [quant][pt2e][xnnpack_quantizer] add support for linear_relu (#117052) Add support for linear_relu annotation for XNNPACKQuantizer, this allows the input to linear and the output to relu to share the same quantization parameters. Differential Revision: [D52574086](https://our.internmc.facebook.com/intern/diff/D52574086/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/117052 Approved by: https://github.com/jerryzh168, https://github.com/digantdesai --- .../pt2e/test_xnnpack_quantizer.py | 34 ++++++++++ .../quantizer/xnnpack_quantizer.py | 2 + .../quantizer/xnnpack_quantizer_utils.py | 62 +++++++++++++++++++ .../testing/_internal/common_quantization.py | 10 +++ 4 files changed, 108 insertions(+) diff --git a/test/quantization/pt2e/test_xnnpack_quantizer.py b/test/quantization/pt2e/test_xnnpack_quantizer.py index fb95742f71d..021261621c1 100644 --- a/test/quantization/pt2e/test_xnnpack_quantizer.py +++ b/test/quantization/pt2e/test_xnnpack_quantizer.py @@ -153,6 +153,40 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase): qconfig_mapping, ) + def test_linear_relu(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.LinearReluModel().eval() + + # Test with 2d inputs + example_inputs_2d = (torch.randn(1, 5),) + example_inputs_3d = (torch.randn(1, 2, 5),) + example_inputs_4d = (torch.randn(1, 2, 3, 5),) + + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + # There should not be extra quantize_per_tensor or dequantize_per_tensors for relu + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # quantize_per_channel for weights are const 
propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + for example_inputs in [example_inputs_2d, example_inputs_3d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], # node_list + False, # executorch_backend_config() does not fuse linear-relu + qconfig_mapping, + ) + + def test_conv_linear_no_permute(self): quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config(is_per_channel=True) diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py index d01164489cf..c29d4f69d58 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -260,7 +260,9 @@ class XNNPACKQuantizer(Quantizer): ] # static quantization ops (both PTQ and QAT) + # Preserve the order that fusions come before singular ops STATIC_OPS = [ + "linear_relu", "linear", "conv_relu", "conv", diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py index 9763cb436b5..a6831161565 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py @@ -218,6 +218,68 @@ def _annotate_linear( return annotated_partitions +@register_annotator("linear_relu") +def _annotate_linear_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + annotated_partitions = [] + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + weight_qspec = 
get_weight_qspec(quantization_config) + bias_qspec = get_bias_qspec(quantization_config) + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = node + maybe_linear_node = node.args[0] + if ( + not isinstance(maybe_linear_node, Node) + or maybe_linear_node.op != "call_function" + or maybe_linear_node.target != torch.ops.aten.linear.default + ): + continue + + linear_node = maybe_linear_node + input_qspec_map = {} + input_act = linear_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = input_act_qspec + + weight = linear_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = weight_qspec + + # adding weight node to the partition as well + partition = [relu_node, linear_node, weight] + bias = linear_node.args[2] if len(linear_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = bias_qspec + partition.append(bias) + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + linear_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + @register_annotator("conv") def _annotate_conv( gm: torch.fx.GraphModule, diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 66bc4d563aa..94295216211 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -2801,3 +2801,13 @@ class TestHelperModules: def example_inputs(self): return (torch.randn(2, 4, 10, 10),) + + class LinearReluModel(torch.nn.Module): + def 
__init__(self): + super().__init__() + self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc(x)) + return x