[quant][pt2e][xnnpack_quantizer] add support for linear_relu (#117052)
Add support for linear_relu annotation for XNNPACKQuantizer; this allows the input to linear and the output of relu to share the same quantization parameters.

Differential Revision: [D52574086](https://our.internmc.facebook.com/intern/diff/D52574086/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117052
Approved by: https://github.com/jerryzh168, https://github.com/digantdesai
This commit is contained in:
parent 4f3d698cac
commit d2033a0639

4 changed files with 108 additions and 0 deletions
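Taken together, the change lets a linear followed by relu be annotated as a single pattern. Below is a minimal usage sketch, assuming the PT2E export flow of this era (`capture_pre_autograd_graph` / `prepare_pt2e` / `convert_pt2e`); the toy module is illustrative and not part of the commit:

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class LinearRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(5, 5)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.fc(x))


example_inputs = (torch.randn(1, 5),)
m = capture_pre_autograd_graph(LinearRelu().eval(), example_inputs)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

m = prepare_pt2e(m, quantizer)  # observers land only at the pattern boundary
m(*example_inputs)              # calibrate
m = convert_pt2e(m)             # q/dq at the input and relu output; none between linear and relu
```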
```diff
@@ -153,6 +153,40 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
             qconfig_mapping,
         )

+    def test_linear_relu(self):
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_global(quantization_config)
+        m_eager = TestHelperModules.LinearReluModel().eval()
+
+        # Test with 2d, 3d, and 4d inputs
+        example_inputs_2d = (torch.randn(1, 5),)
+        example_inputs_3d = (torch.randn(1, 2, 5),)
+        example_inputs_4d = (torch.randn(1, 2, 3, 5),)
+
+        node_occurrence = {
+            # input and output use quantize_per_tensor; the weight uses quantize_per_channel
+            # There should be no extra quantize_per_tensor or dequantize_per_tensor for relu
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+            # quantize_per_channel for weights is const propagated
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        qconfig = default_per_channel_symmetric_qnnpack_qconfig
+        qconfig_mapping = QConfigMapping().set_global(qconfig)
+        for example_inputs in [example_inputs_2d, example_inputs_3d, example_inputs_4d]:
+            self._test_quantizer(
+                m_eager,
+                example_inputs,
+                quantizer,
+                node_occurrence,
+                [],  # node_list
+                False,  # executorch_backend_config() does not fuse linear-relu
+                qconfig_mapping,
+            )
+
     def test_conv_linear_no_permute(self):
         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(is_per_channel=True)
```
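The expected counts encode the fused behavior: after convert, one quantize/dequantize pair sits at the pattern input and one at the relu output, and the weight's `quantize_per_channel` is constant-folded away, leaving only its dequantize. A small helper along these lines (hypothetical, not part of the test suite) can reproduce the tally on a converted graph:

```python
from collections import Counter

import torch


def count_call_targets(gm: torch.fx.GraphModule) -> Counter:
    # Tally call_function targets for comparison against node_occurrence.
    return Counter(n.target for n in gm.graph.nodes if n.op == "call_function")


# Expected on the converted linear+relu graph, per node_occurrence above:
#   quantize_per_tensor.default: 2, dequantize_per_tensor.default: 2,
#   quantize_per_channel.default: 0, dequantize_per_channel.default: 1
```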
```diff
@@ -260,7 +260,9 @@ class XNNPACKQuantizer(Quantizer):
     ]

     # static quantization ops (both PTQ and QAT)
+    # Preserve the order so that fused patterns come before singular ops
     STATIC_OPS = [
+        "linear_relu",
         "linear",
         "conv_relu",
         "conv",
```
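The ordering matters because each annotator skips partitions that are already annotated, so running `linear_relu` before `linear` lets the fused pattern claim the nodes first. A rough sketch of the dispatch loop (the `OP_TO_ANNOTATOR` registry name mirrors what `register_annotator` maintains in the quantizer utils, but treat the surrounding names as assumptions):

```python
# Sketch: annotate fused patterns first; the singular annotators then no-op
# on nodes the fusion already marked (via _is_annotated / _annotated=True).
for op in STATIC_OPS:  # ["linear_relu", "linear", "conv_relu", "conv", ...]
    OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
```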
```diff
@@ -218,6 +218,68 @@ def _annotate_linear(
     return annotated_partitions


+@register_annotator("linear_relu")
+def _annotate_linear_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    annotated_partitions = []
+    input_act_qspec = get_input_act_qspec(quantization_config)
+    output_act_qspec = get_output_act_qspec(quantization_config)
+    weight_qspec = get_weight_qspec(quantization_config)
+    bias_qspec = get_bias_qspec(quantization_config)
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target not in [
+            torch.ops.aten.relu.default,
+            torch.ops.aten.relu_.default,
+        ]:
+            continue
+        relu_node = node
+        maybe_linear_node = node.args[0]
+        if (
+            not isinstance(maybe_linear_node, Node)
+            or maybe_linear_node.op != "call_function"
+            or maybe_linear_node.target != torch.ops.aten.linear.default
+        ):
+            continue
+
+        linear_node = maybe_linear_node
+        input_qspec_map = {}
+        input_act = linear_node.args[0]
+        assert isinstance(input_act, Node)
+        input_qspec_map[input_act] = input_act_qspec
+
+        weight = linear_node.args[1]
+        assert isinstance(weight, Node)
+        input_qspec_map[weight] = weight_qspec
+
+        # add the weight node to the partition as well
+        partition = [relu_node, linear_node, weight]
+        bias = linear_node.args[2] if len(linear_node.args) > 2 else None
+        if isinstance(bias, Node):
+            input_qspec_map[bias] = bias_qspec
+            partition.append(bias)
+
+        if _is_annotated(partition):
+            continue
+
+        if filter_fn and any(not filter_fn(n) for n in partition):
+            continue
+
+        linear_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            _annotated=True,
+        )
+        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            output_qspec=output_act_qspec,
+            _annotated=True,
+        )
+        _mark_nodes_as_annotated(partition)
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
 @register_annotator("conv")
 def _annotate_conv(
     gm: torch.fx.GraphModule,
```
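Note how the qspecs sit only at the pattern boundary: the linear node carries the input/weight/bias qspecs and the relu node carries the output qspec, so `prepare_pt2e` places no observer between linear and relu. A hypothetical inspection helper, reading the same `meta` key the annotator writes above:

```python
import torch


def dump_annotations(gm: torch.fx.GraphModule) -> None:
    # Print each annotated node with the qspecs the quantizer attached.
    for n in gm.graph.nodes:
        ann = n.meta.get("quantization_annotation")
        if ann is not None:
            print(n.name, n.target,
                  "inputs:", [a.name for a in (ann.input_qspec_map or {})],
                  "has_output_qspec:", ann.output_qspec is not None)
```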
```diff
@@ -2801,3 +2801,13 @@ class TestHelperModules:

         def example_inputs(self):
            return (torch.randn(2, 4, 10, 10),)
+
+    class LinearReluModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            x = self.relu(self.fc(x))
+            return x
```