mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: To make sure that quantization_tag is preserved through second-stage export, this PR adds it as special metadata that should be preserved. Since quantization in the export path works on top of the pre-dispatch graph, the subsequent post-dispatch op decomposition will decompose ops that the quantization workflow tagged. To make sure that the patterns identified by the quantizer remain identifiable even after decompositions are applied, we must preserve "quantization_tag". This enables backend delegates that quantized a model for a specific backend to identify "quantized" patterns. Test Plan: metadata porting tests Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D49056259](https://our.internmc.facebook.com/intern/diff/D49056259) Pull Request resolved: https://github.com/pytorch/pytorch/pull/108764 Approved by: https://github.com/tugsbayasgalan, https://github.com/jerryzh168
458 lines
18 KiB
Python
458 lines
18 KiB
Python
# Owner(s): ["oncall: quantization"]
|
|
import copy
|
|
|
|
import unittest
|
|
from typing import List
|
|
|
|
import torch
|
|
import torch._export as export
|
|
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
|
|
from torch.ao.quantization.quantizer import Quantizer
|
|
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
|
|
get_symmetric_quantization_config,
|
|
)
|
|
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
|
|
|
|
from torch.fx import Node
|
|
|
|
from torch.testing._internal.common_quantization import QuantizationTestCase
|
|
from torch.testing._internal.common_utils import IS_WINDOWS
|
|
|
|
|
|
class TestHelperModules:
    """Container for the small ``nn.Module`` fixtures used by the tests below."""

    class Conv2dWithObsSharingOps(torch.nn.Module):
        """conv2d -> adaptive_avg_pool2d -> hardtanh -> view(-1, 3) -> linear.

        With a ``(1, 3, 5, 5)`` input the conv yields ``(1, 3, 3, 3)``, pooling
        reduces it to ``(1, 3, 1, 1)``, and the view flattens it to ``(1, 3)``
        before the 3->3 linear layer.
        """

        def __init__(self):
            super().__init__()
            # Registration order is kept as-is: it determines parameter
            # initialization order and state_dict ordering.
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            pooled = self.adaptive_avg_pool2d(self.conv(x))
            clipped = self.hardtanh(pooled)
            flat = clipped.view(-1, 3)
            return self.linear(flat)
|
|
|
|
|
|
def _tag_partitions(
|
|
backend_name: str, op_name: str, annotated_partitions: List[List[Node]]
|
|
):
|
|
for index, partition_nodes in enumerate(annotated_partitions):
|
|
tag_name = backend_name + "_" + op_name + "_" + str(index)
|
|
for node in partition_nodes:
|
|
assert "quantization_tag" not in node.meta, f"{node} is already tagged"
|
|
node.meta["quantization_tag"] = tag_name
|
|
|
|
|
|
_QUANT_OPS = {
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
|
|
torch.ops.quantized_decomposed.quantize_per_channel.default,
|
|
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
|
torch.ops.quantized_decomposed.choose_qparams.tensor,
|
|
}
|
|
|
|
|
|
# TODO: rename to TestPortMetadataPass to align with the util name?
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestMetaDataPorting(QuantizationTestCase):
    """Tests that ``quantization_tag`` metadata is ported through PT2E quantization.

    Each test defines a small ``Quantizer`` that annotates ops with the
    xnnpack annotators and tags every annotated partition via
    ``_tag_partitions``. The tests then check which tags end up on q/dq and
    ``get_attr`` nodes after ``convert_pt2e``. Some tests also check that the
    tags survive the decompositions applied by ``export.export``.
    """

    def _test_quant_tag_preservation_through_decomp(
        self, model, example_inputs, from_node_to_tags
    ):
        """Export ``model`` and check that decomposed nodes kept their source tag.

        For each ``(from_node, tag)`` pair, the first exported node whose
        ``from_node`` metadata references ``from_node`` but does not carry
        ``meta["quantization_tag"] == tag`` is reported. The test fails if any
        such node exists.

        Args:
            model: the converted (quantized) module to export.
            example_inputs: tuple of inputs used for export.
            from_node_to_tags: maps an op target in ``model`` (e.g.
                ``torch.ops.aten.linear.default``) to the tag expected on every
                node decomposed from it.

        Raises:
            ValueError: if a node's ``from_node`` metadata is not a list.
        """
        ep = export.export(model, example_inputs)
        not_found_nodes = ""
        # Each entry is scanned independently, so the failure message names the
        # first mismatching node for every from_node rather than stopping early.
        # NOTE(review): if no exported node references a given from_node, that
        # entry passes vacuously.
        for from_node, tag in from_node_to_tags.items():
            for n in ep.graph_module.graph.nodes:
                from_node_meta = n.meta.get("from_node", None)
                if from_node_meta is None:
                    continue
                if not isinstance(from_node_meta, list):
                    raise ValueError(
                        f"from_node metadata is of type {type(from_node_meta)}, but expected list"
                    )
                # Element 1 of each entry is compared against the op target of
                # the pre-export node.
                if not any(meta[1] == from_node for meta in from_node_meta):
                    continue
                node_tag = n.meta.get("quantization_tag", None)
                if node_tag is None or tag != node_tag:
                    not_found_nodes += str(n.target) + ", "
                    break
        self.assertTrue(
            not not_found_nodes,
            f"Decomposition did not preserve quantization tag for {not_found_nodes}",
        )

    def _test_metadata_porting(
        self,
        model,
        example_inputs,
        quantizer,
        node_tags=None,
    ) -> torch.fx.GraphModule:
        """Quantize ``model`` with ``quantizer`` via PT2E and verify q/dq tags.

        Runs capture -> ``prepare_pt2e`` -> one calibration pass ->
        ``convert_pt2e(fold_quantize=True)``. It then collects the
        ``quantization_tag`` of every tagged ``get_attr`` node and of every
        tagged ``call_function`` node whose target is in ``_QUANT_OPS``.

        Args:
            model: eager module to quantize. It is switched to eval mode in
                place and deep-copied before capture.
            example_inputs: tuple of inputs for capture and calibration.
            quantizer: ``Quantizer`` whose ``annotate`` sets the tags.
            node_tags: expected tags, keyed by op overload (or the string
                ``"get_attr"``), each mapping to the set of tags seen on nodes
                of that kind. ``None`` skips the comparison.

        Returns:
            The converted ``GraphModule``.

        Raises:
            ValueError: if two ``call_function`` nodes with the same target
                share a tag.
        """
        m_eager = model.eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = export.capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m, fold_quantize=True)

        # Sanity check that the converted module still runs.
        m(*example_inputs)

        recorded_node_tags = {}
        for n in m.graph.nodes:
            if "quantization_tag" not in n.meta:
                continue
            if n.op == "call_function" and n.target in _QUANT_OPS:
                key = n.target
            elif n.op == "get_attr":
                key = "get_attr"
            else:
                continue

            tags_for_key = recorded_node_tags.setdefault(key, set())
            tag = n.meta["quantization_tag"]
            # Uniqueness is only enforced for call_function nodes, so several
            # get_attr nodes may share one partition's tag.
            if n.op == "call_function" and tag in tags_for_key:
                raise ValueError(
                    f"{key} {n.format_node()} has tag {n.meta['quantization_tag']} that "
                    "is associated with another node of the same type"
                )
            tags_for_key.add(tag)

        if node_tags is not None:
            self.assertEqual(set(recorded_node_tags.keys()), set(node_tags.keys()))
            for k, v in recorded_node_tags.items():
                self.assertEqual(v, node_tags[k])
        return m

    def test_simple_metadata_porting(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Check quantization tags on conv2d, avgpool and linear are correctly set,
        and that the avgpool and linear tags survive export decomposition.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config
                )
                _tag_partitions(backend_string, "linear", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
                    gm, quantization_config
                )
                _tag_partitions(
                    backend_string, "adaptive_avg_pool2d", annotated_partitions
                )
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {
            "BackendA_conv2d_0",
            "BackendA_linear_0",
        }
        quantize_per_tensor_tags = {
            "BackendA_conv2d_0",
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_linear_0",
        }
        dequantize_per_tensor_tags = {
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_conv2d_0",
            "BackendA_linear_0",
        }
        dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
        }
        m = self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

        from_node_to_tags = {
            torch.ops.aten.adaptive_avg_pool2d.default: "BackendA_adaptive_avg_pool2d_0",
            torch.ops.aten.linear.default: "BackendA_linear_0",
        }
        self._test_quant_tag_preservation_through_decomp(
            m, example_inputs, from_node_to_tags
        )

    def test_metadata_porting_with_no_quant_inbetween(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Don't quantize avgpool
        Check quantization tags on conv2d and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config
                )
                _tag_partitions(backend_string, "linear", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        quantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        dequantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    @unittest.skip("Temporarily disabled")
    def test_metadata_porting_for_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Statically quantize conv2d and avgpool.
        Quantize linear with dynamic quantization.
        Check quantization tags on conv2d, avgpool and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                # static quantization
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
                    gm, quantization_config
                )
                _tag_partitions(
                    backend_string, "adaptive_avg_pool2d", annotated_partitions
                )

                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        # TODO: add get_attr_tags when the test is re-enabled
        get_attr_tags = set()
        quantize_per_tensor_tags = {
            "BackendA_conv2d_0",
            "BackendA_adaptive_avg_pool2d_0",
        }
        quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        choose_qparams_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_tensor_tags = {
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_conv2d_0",
        }
        dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_channel_tags = {
            "BackendA_conv2d_0",
            "BackendA_linear_dynamic_0",
        }
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_metadata_porting_for_two_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Quantize linear and conv with dynamic quantization
        Check quantization tags on conv2d and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"

                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["conv"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "conv2d_dynamic", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        choose_qparams_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        quantize_per_tensor_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        dequantize_per_tensor_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        dequantize_per_channel_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_metadata_porting_for_dq_no_static_q(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Don't quantize anything except linear.
        Quantize linear with dynamic quantization
        Check quantization tags on linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {"BackendA_linear_dynamic_0"}
        choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"}
        quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_channel_tags = {"BackendA_linear_dynamic_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_no_metadata_porting(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Annotate linear, conv and avgpool without tagging any partition.
        Check that no q/dq or get_attr node carries a quantization_tag and
        that the converted model still exports.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        node_tags = {}
        m = self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

        from_node_to_tags = {}
        self._test_quant_tag_preservation_through_decomp(
            m, example_inputs, from_node_to_tags
        )