[BC-Breaking] Remove capture_pre_autograd_graph references in quantization (#139505)

Summary:
As title

This is a BC-breaking change because a graph produced by "capture_pre_autograd_graph" can no longer be used as input to quantization. This is acceptable, since the API has been deprecated for a while and is about to be deleted; we have already removed all of its call sites.

We remove the deprecated API references in code, docs, and tests.

We also remove two tests that are specific to the capture_pre_autograd_graph API.
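For callers that still capture graphs with the deprecated API, a minimal migration sketch follows; it mirrors the doc change in this PR, and the toy module `M` and `example_inputs` below are illustrative, not taken from the PR itself:

```python
import torch
from torch.export import export_for_training

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

m = M()
example_inputs = (torch.randn(1, 3),)

# Before (no longer accepted by the PT2E quantization flow):
#   from torch._export import capture_pre_autograd_graph
#   m = capture_pre_autograd_graph(m, *example_inputs)

# After: capture with torch.export.export_for_training and unwrap the GraphModule,
# which can then be passed to prepare_pt2e / prepare_qat_pt2e as before.
m = export_for_training(m, example_inputs).module()
```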

Test Plan: CI

Differential Revision: D65351887

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139505
Approved by: https://github.com/tugsbayasgalan, https://github.com/andrewor14, https://github.com/jerryzh168
Author: Shangdi Yu (2024-12-13 22:26:22 +00:00), committed by PyTorch MergeBot
Commit: bb574abe73 (parent d25e6e623f)
4 changed files with 6 additions and 123 deletions

@@ -508,7 +508,7 @@ API Example::
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch._export import capture_pre_autograd_graph
from torch.export import export_for_training
from torch.ao.quantization.quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
@@ -535,7 +535,7 @@ API Example::
# Step 1. program capture
# NOTE: this API will be updated to torch.export API in the future, but the captured
# result should mostly stay the same
m = capture_pre_autograd_graph(m, *example_inputs)
m = export_for_training(m, *example_inputs).module()
# we get a model with aten ops
# Step 2. quantization

@@ -568,84 +568,6 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
        m = M(self.conv_class)
        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    def test_prepare_qat_conv_bn_fusion_getitem_placeholder(self):
        """
        Test the case where the placeholder node for the [conv - bn - getitem] pattern
        is also a getitem node:
        some_op -> unrelated_getitem -> conv -> bn -> conv_bn_getitem
        We want the metadata to be copied from the `conv_bn_getitem` node, not from
        the `unrelated_getitem` node, which is not part of the conv-bn pattern but
        is returned as part of the match anyway (as a placeholder).
        """
        from torch._utils_internal import capture_pre_autograd_graph_using_training_ir

        # T199018392
        # remove this test after we kill capture_pre_autograd_graph()
        if capture_pre_autograd_graph_using_training_ir():
            self.skipTest("Not applicable to training IR")

        class M(torch.nn.Module):
            def __init__(self, conv_class, bn_class):
                super().__init__()
                self.bn1 = bn_class(3)
                self.conv = conv_class(3, 3, 3)
                self.bn2 = bn_class(3)

            def forward(self, x):
                x = self.bn1(x)
                x = self.conv(x)
                x = self.bn2(x)
                return x

        def _get_getitem_nodes(m: torch.fx.GraphModule):
            """
            Return a 2-tuple of (unrelated_getitem_node, conv_bn_getitem_node) from the graph.
            """
            unrelated_getitem_node, conv_bn_getitem_node = None, None
            for node in m.graph.nodes:
                if (
                    node.target != operator.getitem
                    or node.args[0].target
                    != torch.ops.aten._native_batch_norm_legit.default
                ):
                    continue
                if node.args[0].args[0].op == "placeholder":
                    unrelated_getitem_node = node
                else:
                    conv_bn_getitem_node = node
            assert (
                unrelated_getitem_node is not None
            ), "did not find unrelated getitem node, bad test setup"
            assert (
                conv_bn_getitem_node is not None
            ), "did not find conv bn getitem node, bad test setup"
            return (unrelated_getitem_node, conv_bn_getitem_node)

        # Program capture
        m = M(self.conv_class, self.bn_class)
        m = torch._export.capture_pre_autograd_graph(m, self.example_inputs)
        m.graph.eliminate_dead_code()
        m.recompile()
        (_, original_conv_bn_getitem_node) = _get_getitem_nodes(m)

        # Prepare QAT
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
            get_symmetric_quantization_config(is_per_channel=False, is_qat=True)
        )
        m = prepare_qat_pt2e(m, quantizer)
        (unrelated_getitem_node, conv_bn_getitem_node) = _get_getitem_nodes(m)

        # Verify that the metadata was copied from `conv_bn_getitem`, not `unrelated_getitem`
        original_conv_bn_getitem_meta = original_conv_bn_getitem_node.meta[
            "quantization_annotation"
        ]
        conv_bn_getitem_meta = conv_bn_getitem_node.meta["quantization_annotation"]
        self.assertEqual(conv_bn_getitem_meta, original_conv_bn_getitem_meta)
        self.assertTrue("quantization_annotation" not in unrelated_getitem_node.meta)

    def test_qat_update_shared_qspec(self):
        """
        Test the case where nodes used in SharedQuantizationSpec were replaced
@@ -926,39 +848,6 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
        self.assertTrue(conv_node is not None)
        self.assertTrue(bn_node is None)

    def test_preserve_capture_pre_autograd_graph_tag(self):
        """
        Ensure the capture_pre_autograd_graph_tag node meta is preserved.
        TODO: Remove this test after training IR migration.
        T199018392
        """
        from torch._export import capture_pre_autograd_graph
        from torch._utils_internal import capture_pre_autograd_graph_using_training_ir

        if capture_pre_autograd_graph_using_training_ir():
            self.skipTest(
                "test doesn't apply when capture_pre_autograd_graph is using training IR"
            )

        m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
        m = capture_pre_autograd_graph(m, self.example_inputs)
        for node in m.graph.nodes:
            self.assertTrue(node.meta.get("capture_pre_autograd_graph_tag", False))
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
            get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
        )
        m = prepare_qat_pt2e(m, quantizer)
        m = convert_pt2e(m)
        has_tag = False
        for node in m.graph.nodes:
            if not node.meta.get("capture_pre_autograd_graph_tag", False):
                has_tag = True
                break
        self.assertTrue(has_tag)
        torch.export.export(m, self.example_inputs)


@skipIfNoQNNPACK
class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):

@@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch._export import capture_pre_autograd_graph
# Makes sure that quantized_decomposed ops are registered
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
@@ -381,10 +380,9 @@ def _get_aten_graph_module_for_pattern(
            kwargs,
        ).module()
    else:
        aten_pattern = capture_pre_autograd_graph(
            pattern,  # type: ignore[arg-type]
            example_inputs,
            kwargs,
        raise RuntimeError(
            "capture_pre_autograd_graph is deprecated and will be deleted soon."
            "Please use torch.export.export_for_training instead."
        )
    aten_pattern.graph.eliminate_dead_code()  # type: ignore[operator, union-attr]
    aten_pattern.recompile()  # type: ignore[operator]

@@ -35,9 +35,7 @@ def prepare_pt2e(
    """Prepare a model for post training quantization

    Args:
      * `model` (torch.fx.GraphModule): a model captured by `torch.export` API
        in the short term we are using `torch._export.capture_pre_autograd_graph`,
        in the long term we'll migrate to some `torch.export` API
      * `model` (torch.fx.GraphModule): a model captured by `torch.export.export_for_training` API.
      * `quantizer`: A backend specific quantizer that conveys how user want the
        model to be quantized. Tutorial for how to write a quantizer can be found here:
        https://pytorch.org/tutorials/prototype/pt2e_quantizer.html
@@ -49,7 +47,6 @@ def prepare_pt2e(
        import torch
        from torch.ao.quantization.quantize_pt2e import prepare_pt2e
        from torch._export import capture_pre_autograd_graph
        from torch.ao.quantization.quantizer import (
            XNNPACKQuantizer,
            get_symmetric_quantization_config,
@@ -127,7 +124,6 @@ def prepare_qat_pt2e(
    Example::
        import torch
        from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
        from torch._export import capture_pre_autograd_graph
        from torch.ao.quantization.quantizer import (
            XNNPACKQuantizer,
            get_symmetric_quantization_config,