diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index 96ce8c05fbb..1b808136ef1 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -508,7 +508,7 @@ API Example:: import torch from torch.ao.quantization.quantize_pt2e import prepare_pt2e - from torch._export import capture_pre_autograd_graph + from torch.export import export_for_training from torch.ao.quantization.quantizer import ( XNNPACKQuantizer, get_symmetric_quantization_config, @@ -535,7 +535,7 @@ API Example:: # Step 1. program capture # NOTE: this API will be updated to torch.export API in the future, but the captured # result should mostly stay the same - m = capture_pre_autograd_graph(m, *example_inputs) + m = export_for_training(m, example_inputs).module() # we get a model with aten ops # Step 2. quantization diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index e400e3a6b68..3ecc1bef17b 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -568,84 +568,6 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase): m = M(self.conv_class) self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs) - def test_prepare_qat_conv_bn_fusion_getitem_placeholder(self): - """ - Test the case where the placeholder node for the [conv - bn - getitem] pattern - is also a getitem node: - - some_op -> unrelated_getitem -> conv -> bn -> conv_bn_getitem - - We want the metadata to be copied from the `conv_bn_getitem` node, not from - the `unrelated_getitem` node, which is not part of the conv-bn pattern but - is returned as part of the match anyway (as a placeholder). 
- """ - from torch._utils_internal import capture_pre_autograd_graph_using_training_ir - - # T199018392 - # remove this test after we kill capture_pre_autograd_graph() - if capture_pre_autograd_graph_using_training_ir(): - self.skipTest("Not applicable to training IR") - - class M(torch.nn.Module): - def __init__(self, conv_class, bn_class): - super().__init__() - self.bn1 = bn_class(3) - self.conv = conv_class(3, 3, 3) - self.bn2 = bn_class(3) - - def forward(self, x): - x = self.bn1(x) - x = self.conv(x) - x = self.bn2(x) - return x - - def _get_getitem_nodes(m: torch.fx.GraphModule): - """ - Return a 2-tuple of (unrelated_getitem_node, conv_bn_getitem_node) from the graph. - """ - unrelated_getitem_node, conv_bn_getitem_node = None, None - for node in m.graph.nodes: - if ( - node.target != operator.getitem - or node.args[0].target - != torch.ops.aten._native_batch_norm_legit.default - ): - continue - if node.args[0].args[0].op == "placeholder": - unrelated_getitem_node = node - else: - conv_bn_getitem_node = node - assert ( - unrelated_getitem_node is not None - ), "did not find unrelated getitem node, bad test setup" - assert ( - conv_bn_getitem_node is not None - ), "did not find conv bn getitem node, bad test setup" - return (unrelated_getitem_node, conv_bn_getitem_node) - - # Program capture - m = M(self.conv_class, self.bn_class) - m = torch._export.capture_pre_autograd_graph(m, self.example_inputs) - m.graph.eliminate_dead_code() - m.recompile() - (_, original_conv_bn_getitem_node) = _get_getitem_nodes(m) - - # Prepare QAT - quantizer = XNNPACKQuantizer() - quantizer.set_global( - get_symmetric_quantization_config(is_per_channel=False, is_qat=True) - ) - m = prepare_qat_pt2e(m, quantizer) - (unrelated_getitem_node, conv_bn_getitem_node) = _get_getitem_nodes(m) - - # Verify that the metadata was copied from `conv_bn_getitem`, not `unrelated_getitem` - original_conv_bn_getitem_meta = original_conv_bn_getitem_node.meta[ - "quantization_annotation" - ] - 
conv_bn_getitem_meta = conv_bn_getitem_node.meta["quantization_annotation"] - self.assertEqual(conv_bn_getitem_meta, original_conv_bn_getitem_meta) - self.assertTrue("quantization_annotation" not in unrelated_getitem_node.meta) - def test_qat_update_shared_qspec(self): """ Test the case where nodes used in SharedQuantizationSpec were replaced @@ -926,39 +848,6 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase): self.assertTrue(conv_node is not None) self.assertTrue(bn_node is None) - def test_preserve_capture_pre_autograd_graph_tag(self): - """ - Ensure the capture_pre_autograd_graph_tag node meta is preserved. - TODO: Remove this test after training IR migration. - T199018392 - """ - from torch._export import capture_pre_autograd_graph - from torch._utils_internal import capture_pre_autograd_graph_using_training_ir - - if capture_pre_autograd_graph_using_training_ir(): - self.skipTest( - "test doesn't apply when capture_pre_autograd_graph is using training IR" - ) - - m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False) - m = capture_pre_autograd_graph(m, self.example_inputs) - - for node in m.graph.nodes: - self.assertTrue(node.meta.get("capture_pre_autograd_graph_tag", False)) - quantizer = XNNPACKQuantizer() - quantizer.set_global( - get_symmetric_quantization_config(is_per_channel=False, is_qat=True), - ) - m = prepare_qat_pt2e(m, quantizer) - m = convert_pt2e(m) - has_tag = False - for node in m.graph.nodes: - if not node.meta.get("capture_pre_autograd_graph_tag", False): - has_tag = True - break - self.assertTrue(has_tag) - torch.export.export(m, self.example_inputs) - @skipIfNoQNNPACK class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base): diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index d9223fefb12..7b22bacbe57 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, List, 
Optional, Tuple, Union import torch import torch.nn.functional as F -from torch._export import capture_pre_autograd_graph # Makes sure that quantized_decomposed ops are registered from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 @@ -381,10 +380,9 @@ def _get_aten_graph_module_for_pattern( kwargs, ).module() else: - aten_pattern = capture_pre_autograd_graph( - pattern, # type: ignore[arg-type] - example_inputs, - kwargs, + raise RuntimeError( + "capture_pre_autograd_graph is deprecated and will be deleted soon. " + "Please use torch.export.export_for_training instead." ) aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr] aten_pattern.recompile() # type: ignore[operator] diff --git a/torch/ao/quantization/quantize_pt2e.py b/torch/ao/quantization/quantize_pt2e.py index 5760c07ccc0..c25d7c5b951 100644 --- a/torch/ao/quantization/quantize_pt2e.py +++ b/torch/ao/quantization/quantize_pt2e.py @@ -35,9 +35,7 @@ def prepare_pt2e( """Prepare a model for post training quantization Args: - * `model` (torch.fx.GraphModule): a model captured by `torch.export` API - in the short term we are using `torch._export.capture_pre_autograd_graph`, - in the long term we'll migrate to some `torch.export` API + * `model` (torch.fx.GraphModule): a model captured by `torch.export.export_for_training` API. * `quantizer`: A backend specific quantizer that conveys how user want the model to be quantized. 
Tutorial for how to write a quantizer can be found here: https://pytorch.org/tutorials/prototype/pt2e_quantizer.html @@ -49,7 +47,6 @@ def prepare_pt2e( import torch from torch.ao.quantization.quantize_pt2e import prepare_pt2e - from torch._export import capture_pre_autograd_graph from torch.ao.quantization.quantizer import ( XNNPACKQuantizer, get_symmetric_quantization_config, @@ -127,7 +124,6 @@ def prepare_qat_pt2e( Example:: import torch from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e - from torch._export import capture_pre_autograd_graph from torch.ao.quantization.quantizer import ( XNNPACKQuantizer, get_symmetric_quantization_config,