From bb574abe73aa84fd9b895b4c7c39ad79bd53f037 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Fri, 13 Dec 2024 22:26:22 +0000 Subject: [PATCH] [BC-Breaking]Remove capture_pre_autograd_graph references in quantization (#139505) Summary: As title This is a BC-breaking change because graph produced by "capture_pre_autograd_graph" cannot be input to quantization anymore. But this is ok, since this API is deprecated for a while and is going to be deleted. We have removed all call sites of it. We remove the deprecated API references in code, docs, and tests. We also removed two tests that specific to capture_pre_autograd_graph API. Test Plan: CI Differential Revision: D65351887 Pull Request resolved: https://github.com/pytorch/pytorch/pull/139505 Approved by: https://github.com/tugsbayasgalan, https://github.com/andrewor14, https://github.com/jerryzh168 --- docs/source/quantization.rst | 4 +- .../pt2e/test_quantize_pt2e_qat.py | 111 ------------------ torch/ao/quantization/pt2e/utils.py | 8 +- torch/ao/quantization/quantize_pt2e.py | 6 +- 4 files changed, 6 insertions(+), 123 deletions(-) diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index 96ce8c05fbb..1b808136ef1 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -508,7 +508,7 @@ API Example:: import torch from torch.ao.quantization.quantize_pt2e import prepare_pt2e - from torch._export import capture_pre_autograd_graph + from torch.export import export_for_training from torch.ao.quantization.quantizer import ( XNNPACKQuantizer, get_symmetric_quantization_config, @@ -535,7 +535,7 @@ API Example:: # Step 1. program capture # NOTE: this API will be updated to torch.export API in the future, but the captured # result should mostly stay the same - m = capture_pre_autograd_graph(m, *example_inputs) + m = export_for_training(m, *example_inputs).module() # we get a model with aten ops # Step 2. quantization diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index e400e3a6b68..3ecc1bef17b 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -568,84 +568,6 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase): m = M(self.conv_class) self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs) - def test_prepare_qat_conv_bn_fusion_getitem_placeholder(self): - """ - Test the case where the placeholder node for the [conv - bn - getitem] pattern - is also a getitem node: - - some_op -> unrelated_getitem -> conv -> bn -> conv_bn_getitem - - We want the metadata to be copied from the `conv_bn_getitem` node, not from - the `unrelated_getitem` node, which is not part of the conv-bn pattern but - is returned as part of the match anyway (as a placeholder). - """ - from torch._utils_internal import capture_pre_autograd_graph_using_training_ir - - # T199018392 - # remove this test after we kill capture_pre_autograd_graph() - if capture_pre_autograd_graph_using_training_ir(): - self.skipTest("Not applicable to training IR") - - class M(torch.nn.Module): - def __init__(self, conv_class, bn_class): - super().__init__() - self.bn1 = bn_class(3) - self.conv = conv_class(3, 3, 3) - self.bn2 = bn_class(3) - - def forward(self, x): - x = self.bn1(x) - x = self.conv(x) - x = self.bn2(x) - return x - - def _get_getitem_nodes(m: torch.fx.GraphModule): - """ - Return a 2-tuple of (unrelated_getitem_node, conv_bn_getitem_node) from the graph. - """ - unrelated_getitem_node, conv_bn_getitem_node = None, None - for node in m.graph.nodes: - if ( - node.target != operator.getitem - or node.args[0].target - != torch.ops.aten._native_batch_norm_legit.default - ): - continue - if node.args[0].args[0].op == "placeholder": - unrelated_getitem_node = node - else: - conv_bn_getitem_node = node - assert ( - unrelated_getitem_node is not None - ), "did not find unrelated getitem node, bad test setup" - assert ( - conv_bn_getitem_node is not None - ), "did not find conv bn getitem node, bad test setup" - return (unrelated_getitem_node, conv_bn_getitem_node) - - # Program capture - m = M(self.conv_class, self.bn_class) - m = torch._export.capture_pre_autograd_graph(m, self.example_inputs) - m.graph.eliminate_dead_code() - m.recompile() - (_, original_conv_bn_getitem_node) = _get_getitem_nodes(m) - - # Prepare QAT - quantizer = XNNPACKQuantizer() - quantizer.set_global( - get_symmetric_quantization_config(is_per_channel=False, is_qat=True) - ) - m = prepare_qat_pt2e(m, quantizer) - (unrelated_getitem_node, conv_bn_getitem_node) = _get_getitem_nodes(m) - - # Verify that the metadata was copied from `conv_bn_getitem`, not `unrelated_getitem` - original_conv_bn_getitem_meta = original_conv_bn_getitem_node.meta[ - "quantization_annotation" - ] - conv_bn_getitem_meta = conv_bn_getitem_node.meta["quantization_annotation"] - self.assertEqual(conv_bn_getitem_meta, original_conv_bn_getitem_meta) - self.assertTrue("quantization_annotation" not in unrelated_getitem_node.meta) - def test_qat_update_shared_qspec(self): """ Test the case where nodes used in SharedQuantizationSpec were replaced @@ -926,39 +848,6 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase): self.assertTrue(conv_node is not None) self.assertTrue(bn_node is None) - def test_preserve_capture_pre_autograd_graph_tag(self): - """ - Ensure the capture_pre_autograd_graph_tag node meta is preserved. - TODO: Remove this test after training IR migration. - T199018392 - """ - from torch._export import capture_pre_autograd_graph - from torch._utils_internal import capture_pre_autograd_graph_using_training_ir - - if capture_pre_autograd_graph_using_training_ir(): - self.skipTest( - "test doesn't apply when capture_pre_autograd_graph is using training IR" - ) - - m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False) - m = capture_pre_autograd_graph(m, self.example_inputs) - - for node in m.graph.nodes: - self.assertTrue(node.meta.get("capture_pre_autograd_graph_tag", False)) - quantizer = XNNPACKQuantizer() - quantizer.set_global( - get_symmetric_quantization_config(is_per_channel=False, is_qat=True), - ) - m = prepare_qat_pt2e(m, quantizer) - m = convert_pt2e(m) - has_tag = False - for node in m.graph.nodes: - if not node.meta.get("capture_pre_autograd_graph_tag", False): - has_tag = True - break - self.assertTrue(has_tag) - torch.export.export(m, self.example_inputs) - @skipIfNoQNNPACK class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base): diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index d9223fefb12..7b22bacbe57 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F -from torch._export import capture_pre_autograd_graph # Makes sure that quantized_decomposed ops are registered from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 @@ -381,10 +380,9 @@ def _get_aten_graph_module_for_pattern( kwargs, ).module() else: - aten_pattern = capture_pre_autograd_graph( - pattern, # type: ignore[arg-type] - example_inputs, - kwargs, + raise RuntimeError( + "capture_pre_autograd_graph is deprecated and will be deleted soon." + "Please use torch.export.export_for_training instead." ) aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr] aten_pattern.recompile() # type: ignore[operator] diff --git a/torch/ao/quantization/quantize_pt2e.py b/torch/ao/quantization/quantize_pt2e.py index 5760c07ccc0..c25d7c5b951 100644 --- a/torch/ao/quantization/quantize_pt2e.py +++ b/torch/ao/quantization/quantize_pt2e.py @@ -35,9 +35,7 @@ def prepare_pt2e( """Prepare a model for post training quantization Args: - * `model` (torch.fx.GraphModule): a model captured by `torch.export` API - in the short term we are using `torch._export.capture_pre_autograd_graph`, - in the long term we'll migrate to some `torch.export` API + * `model` (torch.fx.GraphModule): a model captured by `torch.export.export_for_training` API. * `quantizer`: A backend specific quantizer that conveys how user want the model to be quantized. Tutorial for how to write a quantizer can be found here: https://pytorch.org/tutorials/prototype/pt2e_quantizer.html @@ -49,7 +47,6 @@ def prepare_pt2e( import torch from torch.ao.quantization.quantize_pt2e import prepare_pt2e - from torch._export import capture_pre_autograd_graph from torch.ao.quantization.quantizer import ( XNNPACKQuantizer, get_symmetric_quantization_config, @@ -127,7 +124,6 @@ def prepare_qat_pt2e( Example:: import torch from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e - from torch._export import capture_pre_autograd_graph from torch.ao.quantization.quantizer import ( XNNPACKQuantizer, get_symmetric_quantization_config,