pytorch/test/quantization/pt2e/test_metadata_porting.py
Kimish Patel 9e2af971fc [Quantization] Add "quantization_tag" as metadata to fx proxy (#108764)
Summary:
In order to make sure that quantization_tag is preserved through second
stage export, this PR adds it as a special metadata that should be
preserved.

Since quantization in export path will work on top of pre dispatch
graph, subsequent post dispatch op decomposition, will decompose ops
that quant workflow tagged. In order to make sure that the patterns
identified by the quantizer remain identifiable even after decompositions
are applied, we must preserve "quantization_tag".

This enables backend delegates that quantized a model for a specific
backend to identify the "quantized" patterns.

Test Plan:
metadata porting tests

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D49056259](https://our.internmc.facebook.com/intern/diff/D49056259)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108764
Approved by: https://github.com/tugsbayasgalan, https://github.com/jerryzh168
2023-11-01 21:41:58 +00:00

458 lines
18 KiB
Python

# Owner(s): ["oncall: quantization"]
import copy
import unittest
from typing import List
import torch
import torch._export as export
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
from torch.fx import Node
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS
class TestHelperModules:
    class Conv2dWithObsSharingOps(torch.nn.Module):
        """conv2d -> adaptive_avg_pool2d -> hardtanh -> linear.

        Small model used to exercise metadata porting across ops that
        share observers (avg-pool/hardtanh) between quantized producers
        and consumers.
        """

        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            # conv -> pool -> hardtanh, then flatten to (N, 3) for the linear head
            pooled = self.hardtanh(self.adaptive_avg_pool2d(self.conv(x)))
            return self.linear(pooled.view(-1, 3))
def _tag_partitions(
backend_name: str, op_name: str, annotated_partitions: List[List[Node]]
):
for index, partition_nodes in enumerate(annotated_partitions):
tag_name = backend_name + "_" + op_name + "_" + str(index)
for node in partition_nodes:
assert "quantization_tag" not in node.meta, f"{node} is already tagged"
node.meta["quantization_tag"] = tag_name
# Quantize/dequantize/choose_qparams ops emitted by convert_pt2e.
# Only call_function nodes with one of these targets (plus get_attr nodes)
# are inspected for ported "quantization_tag" metadata in the tests below.
_QUANT_OPS = {
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
    torch.ops.quantized_decomposed.choose_qparams.tensor,
}
# TODO: rename to TestPortMetadataPass to align with the util name?
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestMetaDataPorting(QuantizationTestCase):
    """Tests that "quantization_tag" metadata attached by a quantizer is
    ported onto the quantize/dequantize (and folded get_attr) nodes produced
    by the PT2E flow, and survives post-dispatch op decomposition."""

    def _test_quant_tag_preservation_through_decomp(
        self, model, example_inputs, from_node_to_tags
    ):
        """Export ``model`` and check that every node whose "from_node"
        provenance traces back to an op in ``from_node_to_tags`` carries the
        expected "quantization_tag".

        Args:
            model: captured/quantized GraphModule to export.
            example_inputs: tuple of example inputs for export.
            from_node_to_tags: mapping of original aten op -> expected tag.
        """
        ep = export.export(model, example_inputs)
        found_tags = True
        not_found_nodes = ""
        for from_node, tag in from_node_to_tags.items():
            for n in ep.graph_module.graph.nodes:
                from_node_meta = n.meta.get("from_node", None)
                if from_node_meta is None:
                    continue
                if not isinstance(from_node_meta, list):
                    raise ValueError(
                        f"from_node metadata is of type {type(from_node_meta)}, but expected list"
                    )
                for meta in from_node_meta:
                    # meta[1] is the provenance target this node decomposed from
                    node_target = meta[1]
                    if node_target == from_node:
                        node_tag = n.meta.get("quantization_tag", None)
                        if node_tag is None or tag != node_tag:
                            not_found_nodes += str(n.target) + ", "
                            found_tags = False
                            break
                if not found_tags:
                    break
        self.assertTrue(
            found_tags,
            f"Decomposition did not preserve quantization tag for {not_found_nodes}",
        )

    def _test_metadata_porting(
        self,
        model,
        example_inputs,
        quantizer,
        node_tags=None,
    ) -> torch.fx.GraphModule:
        """Run the PT2E prepare/convert flow on ``model`` and assert that the
        recorded "quantization_tag"s on quant ops and get_attr nodes match
        ``node_tags`` exactly.

        Args:
            model: eager model under test.
            example_inputs: tuple of example inputs for capture/calibration.
            quantizer: Quantizer that annotates (and tags) partitions.
            node_tags: mapping of op target (or "get_attr") -> set of expected
                tags; ``None`` means no tagged nodes are expected.

        Returns:
            The converted GraphModule.

        Raises:
            ValueError: if two call_function nodes of the same op share a tag.
        """
        # Fix: the default used to be dereferenced as None.keys() below;
        # treat "no expectations given" as an empty mapping instead.
        if node_tags is None:
            node_tags = {}
        m_eager = model.eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = export.capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m, fold_quantize=True)
        # Smoke-run the converted model; the output itself is not checked here.
        m(*example_inputs)

        # Collect tags per op target (quant ops) and for folded get_attr nodes.
        recorded_node_tags = {}
        for n in m.graph.nodes:
            if "quantization_tag" not in n.meta:
                continue
            if n.op == "call_function" and n.target in _QUANT_OPS:
                key = n.target
            elif n.op == "get_attr":
                key = "get_attr"
            else:
                continue

            if key not in recorded_node_tags:
                recorded_node_tags[key] = set()

            # Two quant ops of the same type must never share a tag; get_attr
            # nodes (e.g. scale/zero-point constants) may legitimately repeat.
            if (
                n.op == "call_function"
                and n.meta["quantization_tag"] in recorded_node_tags[key]
            ):
                raise ValueError(
                    f"{key} {n.format_node()} has tag {n.meta['quantization_tag']} that "
                    "is associated with another node of the same type"
                )
            recorded_node_tags[key].add(n.meta["quantization_tag"])

        self.assertEqual(set(recorded_node_tags.keys()), set(node_tags.keys()))
        for k, v in recorded_node_tags.items():
            self.assertEqual(v, node_tags[k])
        return m

    def test_simple_metadata_porting(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Check quantization tags on conv2d, avgpool and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config
                )
                _tag_partitions(backend_string, "linear", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
                    gm, quantization_config
                )
                _tag_partitions(
                    backend_string, "adaptive_avg_pool2d", annotated_partitions
                )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {
            "BackendA_conv2d_0",
            "BackendA_linear_0",
        }
        quantize_per_tensor_tags = {
            "BackendA_conv2d_0",
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_linear_0",
        }
        dequantize_per_tensor_tags = {
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_conv2d_0",
            "BackendA_linear_0",
        }
        dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
        }
        m = self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

        from_node_to_tags = {
            torch.ops.aten.adaptive_avg_pool2d.default: "BackendA_adaptive_avg_pool2d_0",
            torch.ops.aten.linear.default: "BackendA_linear_0",
        }
        self._test_quant_tag_preservation_through_decomp(
            m, example_inputs, from_node_to_tags
        )

    def test_metadata_porting_with_no_quant_inbetween(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Don't quantize avgpool
        Check quantization tags on conv2d and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config
                )
                _tag_partitions(backend_string, "linear", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        quantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        dequantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    @unittest.skip("Temporarily disabled")
    def test_metadata_porting_for_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Quantize all except linear.
        Quantize linear with dynamic quantization
        Check quantization tags on conv2d, avgpool and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                # static quantization
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
                    gm, quantization_config
                )
                _tag_partitions(
                    backend_string, "adaptive_avg_pool2d", annotated_partitions
                )

                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        # TODO: add get_attr_tags when the test is re-enabled
        # Fix: was `{}` (an empty dict); expectations are sets everywhere else.
        get_attr_tags = set()
        quantize_per_tensor_tags = {
            "BackendA_conv2d_0",
            "BackendA_adaptive_avg_pool2d_0",
        }
        quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        choose_qparams_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_tensor_tags = {
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_conv2d_0",
        }
        dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_channel_tags = {
            "BackendA_conv2d_0",
            "BackendA_linear_dynamic_0",
        }
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_metadata_porting_for_two_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Quantize linear and conv with dynamic quantization
        Check quantization tags on conv2d, avgpool and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"

                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["conv"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "conv2d_dynamic", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        choose_qparams_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        quantize_per_tensor_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        dequantize_per_tensor_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        dequantize_per_channel_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_metadata_porting_for_dq_no_static_q(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Don't quantize anything except linear.
        Quantize linear with dynamic quantization
        Check quantization tags on conv2d, avgpool and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {"BackendA_linear_dynamic_0"}
        choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"}
        quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_channel_tags = {"BackendA_linear_dynamic_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_no_metadata_porting(self):
        """Quantize without tagging any partition: no node should end up with
        a "quantization_tag" after convert, and export preserves that."""

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        node_tags = {}
        m = self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

        from_node_to_tags = {}
        self._test_quant_tag_preservation_through_decomp(
            m, example_inputs, from_node_to_tags
        )