mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add prepare_obs_or_fq_callback to quantizer (#140863)
Test Plan: CI. Differential Revision: D65982003 Pull Request resolved: https://github.com/pytorch/pytorch/pull/140863 Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
c79e78b503
commit
efe8482c0d
4 changed files with 107 additions and 3 deletions
|
|
@ -1,5 +1,5 @@
|
|||
# Owner(s): ["oncall: quantization"]
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
|
@ -18,6 +18,7 @@ from torch.ao.quantization.quantize_pt2e import (
|
|||
)
|
||||
from torch.ao.quantization.quantizer import (
|
||||
DerivedQuantizationSpec,
|
||||
EdgeOrNode,
|
||||
FixedQParamsQuantizationSpec,
|
||||
QuantizationAnnotation,
|
||||
QuantizationSpec,
|
||||
|
|
@ -2339,6 +2340,76 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
|||
m = convert_pt2e(m)
|
||||
m(*example_inputs)
|
||||
|
||||
    def test_prepare_obs_or_fq_callback(self):
        """Check that `Quantizer.prepare_obs_or_fq_callback` can override the
        observers chosen for a sharing group before they are inserted.

        The quantizer below annotates every op with a shared qspec, then the
        callback swaps the whole sharing group's observer for a
        FixedQParamsObserver with scale=0.125 / zero_point=42; the converted
        graph's q/dq nodes must carry exactly those qparams.
        """

        class Model(torch.nn.Module):
            def forward(self, x):
                x = torch.nn.functional.max_pool2d(x, 2, 2)
                x = torch.nn.functional.pixel_shuffle(x, 2)
                return x.permute(0, 2, 3, 1)

        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                # uint8 affine activation spec used for every annotated input.
                act_qspec = QuantizationSpec(
                    dtype=torch.uint8,
                    quant_min=0,
                    quant_max=255,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_observer,
                )
                for node in model.graph.nodes:
                    if node.op == "call_function" and node.target in (
                        torch.ops.aten.max_pool2d.default,
                        torch.ops.aten.permute.default,
                        torch.ops.aten.pixel_shuffle.default,
                    ):
                        # Output shares the input's observer, chaining all three
                        # ops (and the model output) into one sharing group.
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                node.args[0]: act_qspec,
                            },
                            output_qspec=SharedQuantizationSpec((node.args[0], node)),
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

            def prepare_obs_or_fq_callback(
                self,
                model: torch.fx.GraphModule,
                edge_or_node_to_obs_or_fq: Dict[EdgeOrNode, ObserverOrFakeQuantize],
            ) -> None:
                # hard code output quant by updating entire sharing group
                output_node = next(n for n in model.graph.nodes if n.op == "output")
                output_value = output_node.args[0][0]
                old_observer = edge_or_node_to_obs_or_fq[output_value]
                # Identity comparison: members of a sharing group map to the
                # same observer instance.
                sharing_group = [
                    k for k, v in edge_or_node_to_obs_or_fq.items() if v is old_observer
                ]
                new_observer = observer.FixedQParamsObserver(
                    scale=0.125,
                    zero_point=42,
                    dtype=torch.uint8,
                    quant_min=0,
                    quant_max=255,
                    qscheme=torch.per_tensor_affine,
                )
                # Replace every member's observer so the group stays consistent.
                for x in sharing_group:
                    edge_or_node_to_obs_or_fq[x] = new_observer

        example_inputs = (torch.rand(1, 32, 16, 16),)
        gm = export_for_training(Model().eval(), example_inputs).module()
        gm = prepare_pt2e(gm, BackendAQuantizer())
        gm = convert_pt2e(gm)
        for n in gm.graph.nodes:
            if n.op == "call_function" and n.target in (
                torch.ops.quantized_decomposed.quantize_per_tensor.default,
                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            ):
                # Entire graph shares the same qspec which was overridden by FixedQParamsObserver
                self.assertEqual(n.args[1], 0.125)
                self.assertEqual(n.args[2], 42)
|
||||
|
||||
instantiate_parametrized_tests(TestQuantizePT2E)
|
||||
|
||||
|
|
|
|||
|
|
@ -535,6 +535,7 @@ def prepare(
|
|||
model: GraphModule,
|
||||
node_name_to_scope: Dict[str, Tuple[str, type]],
|
||||
is_qat: bool,
|
||||
obs_or_fq_callback=None,
|
||||
) -> GraphModule:
|
||||
# Since we are mutating the graph as we go, we iterate over the original
|
||||
# nodes before observer insertion, instead of model.graph.nodes.
|
||||
|
|
@ -549,6 +550,8 @@ def prepare(
|
|||
obs_or_fq_map = _get_obs_or_fq_map(
|
||||
edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat
|
||||
)
|
||||
if obs_or_fq_callback:
|
||||
obs_or_fq_callback(model, obs_or_fq_map)
|
||||
|
||||
for node in nodes_before_observation:
|
||||
# TODO: simplify logic for inserting observers
|
||||
|
|
|
|||
|
|
@ -99,7 +99,12 @@ def prepare_pt2e(
|
|||
model = quantizer.transform_for_annotation(model)
|
||||
quantizer.annotate(model)
|
||||
quantizer.validate(model)
|
||||
model = prepare(model, node_name_to_scope, is_qat=False)
|
||||
model = prepare(
|
||||
model,
|
||||
node_name_to_scope,
|
||||
is_qat=False,
|
||||
obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback,
|
||||
)
|
||||
model.meta.update(original_graph_meta)
|
||||
model = _disallow_eval_train(model)
|
||||
return model
|
||||
|
|
@ -172,7 +177,12 @@ def prepare_qat_pt2e(
|
|||
# subgraph that don't need to be quantized
|
||||
# TODO: only fuse if conv and bn are both configured to be quantized
|
||||
_fuse_conv_bn_qat(model)
|
||||
model = prepare(model, node_name_to_scope, is_qat=True)
|
||||
model = prepare(
|
||||
model,
|
||||
node_name_to_scope,
|
||||
is_qat=True,
|
||||
obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback,
|
||||
)
|
||||
model.meta.update(original_graph_meta)
|
||||
model = _disallow_eval_train(model)
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -159,3 +159,23 @@ class Quantizer(ABC):
|
|||
@abstractmethod
|
||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||
pass
|
||||
|
||||
def prepare_obs_or_fq_callback(
|
||||
self,
|
||||
model: torch.fx.GraphModule,
|
||||
edge_or_node_to_obs_or_fq: Dict[EdgeOrNode, ObserverOrFakeQuantize],
|
||||
) -> None:
|
||||
"""A callback that will be called after the observers or fake quants are created
|
||||
for each sharing group, but before they are inserted into the graph. The
|
||||
callback can be used to make final quantization adjustments, such as enforcing
|
||||
specific scale and zero point on model input or output.
|
||||
|
||||
Args:
|
||||
* `model`: the graph module being prepared.
|
||||
* `edge_or_node_to_obs_or_fq`: a dictionary mapping each annotated edge and
|
||||
node to the corresponding observer or fake quant object. Note that multiple
|
||||
edges and/or nodes can map to the same observer / fake quant instance if
|
||||
they were annotated with SharedQuantizationSpec. This dictionary can be
|
||||
modified by the callback.
|
||||
"""
|
||||
return
|
||||
|
|
|
|||
Loading…
Reference in a new issue