[quant][pt2e][xnnpack_quantizer] add support for linear_relu (#117052)

Add support for linear_relu annotation for XNNPACKQuantizer; this allows the input to linear and the output of relu to share the same quantization parameters.

Differential Revision: [D52574086](https://our.internmc.facebook.com/intern/diff/D52574086/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117052
Approved by: https://github.com/jerryzh168, https://github.com/digantdesai
This commit is contained in:
Max Ren 2024-01-09 12:11:53 -08:00 committed by PyTorch MergeBot
parent 4f3d698cac
commit d2033a0639
4 changed files with 108 additions and 0 deletions

View file

@@ -153,6 +153,40 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
qconfig_mapping,
)
def test_linear_relu(self):
    """Linear followed by ReLU should be annotated as one fused unit:
    no extra quantize/dequantize pairs are inserted between them."""
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
    m_eager = TestHelperModules.LinearReluModel().eval()
    # Exercise 2d, 3d and 4d activations through the same model.
    all_example_inputs = (
        (torch.randn(1, 5),),
        (torch.randn(1, 2, 5),),
        (torch.randn(1, 2, 3, 5),),
    )
    node_occurrence = {
        # input and output are using quantize_per_tensor and weight is using quantize_per_channel
        # There should not be extra quantize_per_tensor or dequantize_per_tensors for relu
        torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
        # quantize_per_channel for weights are const propagated
        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
        torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
    }
    qconfig_mapping = QConfigMapping().set_global(
        default_per_channel_symmetric_qnnpack_qconfig
    )
    for example_inputs in all_example_inputs:
        self._test_quantizer(
            m_eager,
            example_inputs,
            quantizer,
            node_occurrence,
            [],  # node_list
            False,  # executorch_backend_config() does not fuse linear-relu
            qconfig_mapping,
        )
def test_conv_linear_no_permute(self):
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(is_per_channel=True)

View file

@@ -260,7 +260,9 @@ class XNNPACKQuantizer(Quantizer):
]
# static quantization ops (both PTQ and QAT)
# Preserve the order that fusions come before singular ops
STATIC_OPS = [
"linear_relu",
"linear",
"conv_relu",
"conv",

View file

@@ -218,6 +218,68 @@ def _annotate_linear(
return annotated_partitions
@register_annotator("linear_relu")
def _annotate_linear_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Annotate every ``aten.linear`` -> ``aten.relu`` pair in ``gm`` as one
    fused quantization partition.

    The linear node carries the input-activation/weight/bias qspecs and the
    relu node carries the output qspec, so no intermediate quant/dequant is
    inserted between the two ops.  Already-annotated pairs and pairs rejected
    by ``filter_fn`` are skipped.  Returns the list of annotated partitions
    (each ``[relu, linear, weight, (bias)]``).
    """
    act_qspec = get_input_act_qspec(quantization_config)
    out_qspec = get_output_act_qspec(quantization_config)
    w_qspec = get_weight_qspec(quantization_config)
    b_qspec = get_bias_qspec(quantization_config)
    relu_targets = (torch.ops.aten.relu.default, torch.ops.aten.relu_.default)
    annotated_partitions = []
    for relu_node in gm.graph.nodes:
        if relu_node.op != "call_function" or relu_node.target not in relu_targets:
            continue
        linear_node = relu_node.args[0]
        # The relu's producer must be a plain aten.linear call.
        if not (
            isinstance(linear_node, Node)
            and linear_node.op == "call_function"
            and linear_node.target == torch.ops.aten.linear.default
        ):
            continue
        input_act = linear_node.args[0]
        weight = linear_node.args[1]
        assert isinstance(input_act, Node)
        assert isinstance(weight, Node)
        input_qspec_map = {input_act: act_qspec, weight: w_qspec}
        # The weight node belongs to the partition alongside the two ops.
        partition = [relu_node, linear_node, weight]
        bias = linear_node.args[2] if len(linear_node.args) > 2 else None
        if isinstance(bias, Node):
            input_qspec_map[bias] = b_qspec
            partition.append(bias)
        if _is_annotated(partition):
            continue
        if filter_fn and any(not filter_fn(n) for n in partition):
            continue
        linear_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            _annotated=True,
        )
        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=out_qspec,
            _annotated=True,
        )
        _mark_nodes_as_annotated(partition)
        annotated_partitions.append(partition)
    return annotated_partitions
@register_annotator("conv")
def _annotate_conv(
gm: torch.fx.GraphModule,

View file

@@ -2801,3 +2801,13 @@ class TestHelperModules:
def example_inputs(self):
return (torch.randn(2, 4, 10, 10),)
class LinearReluModel(torch.nn.Module):
    """Minimal float32 Linear(5, 5) -> ReLU model for linear_relu fusion tests."""

    def __init__(self):
        super().__init__()
        # 5 -> 5 affine projection, explicitly kept in float32.
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.fc(x))