Add prepare_obs_or_fq_callback to quantizer (#140863)

Test Plan: CI.

Differential Revision: D65982003

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140863
Approved by: https://github.com/jerryzh168
Shen Xu 2024-11-19 01:13:38 +00:00 committed by PyTorch MergeBot
parent c79e78b503
commit efe8482c0d
4 changed files with 107 additions and 3 deletions

test/quantization/pt2e/test_quantize_pt2e.py

@@ -1,5 +1,5 @@
 # Owner(s): ["oncall: quantization"]
-from typing import List, Tuple
+from typing import Dict, List, Tuple

 import torch
 from torch import Tensor
@@ -18,6 +18,7 @@ from torch.ao.quantization.quantize_pt2e import (
 )
 from torch.ao.quantization.quantizer import (
     DerivedQuantizationSpec,
+    EdgeOrNode,
     FixedQParamsQuantizationSpec,
     QuantizationAnnotation,
     QuantizationSpec,
@@ -2339,6 +2340,76 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m = convert_pt2e(m)
         m(*example_inputs)

+    def test_prepare_obs_or_fq_callback(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                x = torch.nn.functional.max_pool2d(x, 2, 2)
+                x = torch.nn.functional.pixel_shuffle(x, 2)
+                return x.permute(0, 2, 3, 1)
+
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                for node in model.graph.nodes:
+                    if node.op == "call_function" and node.target in (
+                        torch.ops.aten.max_pool2d.default,
+                        torch.ops.aten.permute.default,
+                        torch.ops.aten.pixel_shuffle.default,
+                    ):
+                        node.meta["quantization_annotation"] = QuantizationAnnotation(
+                            input_qspec_map={
+                                node.args[0]: act_qspec,
+                            },
+                            output_qspec=SharedQuantizationSpec((node.args[0], node)),
+                            _annotated=True,
+                        )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+            def prepare_obs_or_fq_callback(
+                self,
+                model: torch.fx.GraphModule,
+                edge_or_node_to_obs_or_fq: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+            ) -> None:
+                # Hard-code the output qparams by updating the entire sharing group.
+                output_node = next(n for n in model.graph.nodes if n.op == "output")
+                output_value = output_node.args[0][0]
+                old_observer = edge_or_node_to_obs_or_fq[output_value]
+                sharing_group = [
+                    k for k, v in edge_or_node_to_obs_or_fq.items() if v is old_observer
+                ]
+                new_observer = observer.FixedQParamsObserver(
+                    scale=0.125,
+                    zero_point=42,
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                )
+                for x in sharing_group:
+                    edge_or_node_to_obs_or_fq[x] = new_observer
+
+        example_inputs = (torch.rand(1, 32, 16, 16),)
+        gm = export_for_training(Model().eval(), example_inputs).module()
+        gm = prepare_pt2e(gm, BackendAQuantizer())
+        gm = convert_pt2e(gm)
+        for n in gm.graph.nodes:
+            if n.op == "call_function" and n.target in (
+                torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            ):
+                # The entire graph shares the same qspec, which was overridden
+                # by FixedQParamsObserver.
+                self.assertEqual(n.args[1], 0.125)
+                self.assertEqual(n.args[2], 42)


 instantiate_parametrized_tests(TestQuantizePT2E)

torch/ao/quantization/pt2e/prepare.py

@@ -535,6 +535,7 @@ def prepare(
     model: GraphModule,
     node_name_to_scope: Dict[str, Tuple[str, type]],
     is_qat: bool,
+    obs_or_fq_callback=None,
 ) -> GraphModule:
     # Since we are mutating the graph as we go, we iterate over the original
     # nodes before observer insertion, instead of model.graph.nodes.
@@ -549,6 +550,8 @@ def prepare(
     obs_or_fq_map = _get_obs_or_fq_map(
         edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat
     )
+    if obs_or_fq_callback:
+        obs_or_fq_callback(model, obs_or_fq_map)

     for node in nodes_before_observation:
         # TODO: simplify logic for inserting observers

torch/ao/quantization/quantize_pt2e.py

@@ -99,7 +99,12 @@ def prepare_pt2e(
     model = quantizer.transform_for_annotation(model)
     quantizer.annotate(model)
     quantizer.validate(model)
-    model = prepare(model, node_name_to_scope, is_qat=False)
+    model = prepare(
+        model,
+        node_name_to_scope,
+        is_qat=False,
+        obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback,
+    )
     model.meta.update(original_graph_meta)
     model = _disallow_eval_train(model)
     return model
@@ -172,7 +177,12 @@ def prepare_qat_pt2e(
     # subgraph that don't need to be quantized
     # TODO: only fuse if conv and bn are both configured to be quantized
     _fuse_conv_bn_qat(model)
-    model = prepare(model, node_name_to_scope, is_qat=True)
+    model = prepare(
+        model,
+        node_name_to_scope,
+        is_qat=True,
+        obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback,
+    )
     model.meta.update(original_graph_meta)
     model = _disallow_eval_train(model)
     return model

torch/ao/quantization/quantizer/quantizer.py

@@ -159,3 +159,23 @@ class Quantizer(ABC):
     @abstractmethod
     def validate(self, model: torch.fx.GraphModule) -> None:
         pass
+
+    def prepare_obs_or_fq_callback(
+        self,
+        model: torch.fx.GraphModule,
+        edge_or_node_to_obs_or_fq: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+    ) -> None:
+        """A callback that is called after the observers or fake quants are created
+        for each sharing group, but before they are inserted into the graph. The
+        callback can be used to make final quantization adjustments, such as enforcing
+        a specific scale and zero point on the model input or output.
+
+        Args:
+          * `model`: the graph module being prepared.
+          * `edge_or_node_to_obs_or_fq`: a dictionary mapping each annotated edge and
+            node to the corresponding observer or fake quant object. Note that multiple
+            edges and/or nodes can map to the same observer / fake quant instance if
+            they were annotated with SharedQuantizationSpec. This dictionary can be
+            modified by the callback.
+        """
+        return
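
For illustration, here is a minimal sketch (not part of this commit) of a quantizer that overrides the new hook to pin the quantization parameters of the model inputs, the other use case the docstring mentions. The class name InputPinningQuantizer, the stubbed annotate, and the chosen qparams are hypothetical; FixedQParamsObserver and the sharing-group replacement pattern come from the test above.

import torch
from torch.ao.quantization.observer import FixedQParamsObserver
from torch.ao.quantization.quantizer import Quantizer


class InputPinningQuantizer(Quantizer):  # hypothetical name, for illustration only
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        ...  # annotate edges/nodes with QuantizationSpec as usual
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

    def prepare_obs_or_fq_callback(self, model, edge_or_node_to_obs_or_fq) -> None:
        # Graph inputs show up as FX placeholder nodes.
        placeholders = {n for n in model.graph.nodes if n.op == "placeholder"}
        pinned = FixedQParamsObserver(
            scale=1.0 / 256,  # hypothetical fixed qparams
            zero_point=0,
            dtype=torch.uint8,
            quant_min=0,
            quant_max=255,
            qscheme=torch.per_tensor_affine,
        )
        # Keys are nodes or (producer, consumer) edges. For any entry fed by a
        # graph input, swap out its whole sharing group (as the test above does)
        # so edges that share an observer stay consistent.
        for key in list(edge_or_node_to_obs_or_fq):
            producer = key[0] if isinstance(key, tuple) else key
            if producer in placeholders:
                old = edge_or_node_to_obs_or_fq[key]
                for k, v in edge_or_node_to_obs_or_fq.items():
                    if v is old:
                        edge_or_node_to_obs_or_fq[k] = pinned

Assuming a model captured with export_for_training, prepare_pt2e(m, InputPinningQuantizer()) should then produce quantize/dequantize ops carrying these fixed qparams on the graph inputs, mirroring what the new test asserts for the output.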