mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[BC-Breaking]Remove capture_pre_autograd_graph references in quantization (#139505)
Summary: As title This is a BC-breaking change because graph produced by "capture_pre_autograd_graph" cannot be input to quantization anymore. But this is ok, since this API is deprecated for a while and is going to be deleted. We have removed all call sites of it. We remove the deprecated API references in code, docs, and tests. We also removed two tests that specific to capture_pre_autograd_graph API. Test Plan: CI Differential Revision: D65351887 Pull Request resolved: https://github.com/pytorch/pytorch/pull/139505 Approved by: https://github.com/tugsbayasgalan, https://github.com/andrewor14, https://github.com/jerryzh168
This commit is contained in:
parent
d25e6e623f
commit
bb574abe73
4 changed files with 6 additions and 123 deletions
|
|
@ -508,7 +508,7 @@ API Example::
|
|||
|
||||
import torch
|
||||
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
|
||||
from torch._export import capture_pre_autograd_graph
|
||||
from torch.export import export_for_training
|
||||
from torch.ao.quantization.quantizer import (
|
||||
XNNPACKQuantizer,
|
||||
get_symmetric_quantization_config,
|
||||
|
|
@ -535,7 +535,7 @@ API Example::
|
|||
# Step 1. program capture
|
||||
# NOTE: this API will be updated to torch.export API in the future, but the captured
|
||||
# result should mostly stay the same
|
||||
m = capture_pre_autograd_graph(m, *example_inputs)
|
||||
m = export_for_training(m, *example_inputs).module()
|
||||
# we get a model with aten ops
|
||||
|
||||
# Step 2. quantization
|
||||
|
|
|
|||
|
|
@ -568,84 +568,6 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
|
|||
m = M(self.conv_class)
|
||||
self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)
|
||||
|
||||
def test_prepare_qat_conv_bn_fusion_getitem_placeholder(self):
|
||||
"""
|
||||
Test the case where the placeholder node for the [conv - bn - getitem] pattern
|
||||
is also a getitem node:
|
||||
|
||||
some_op -> unrelated_getitem -> conv -> bn -> conv_bn_getitem
|
||||
|
||||
We want the metadata to be copied from the `conv_bn_getitem` node, not from
|
||||
the `unrelated_getitem` node, which is not part of the conv-bn pattern but
|
||||
is returned as part of the match anyway (as a placeholder).
|
||||
"""
|
||||
from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
|
||||
|
||||
# T199018392
|
||||
# remove this test after we kill capture_pre_autograd_graph()
|
||||
if capture_pre_autograd_graph_using_training_ir():
|
||||
self.skipTest("Not applicable to training IR")
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, conv_class, bn_class):
|
||||
super().__init__()
|
||||
self.bn1 = bn_class(3)
|
||||
self.conv = conv_class(3, 3, 3)
|
||||
self.bn2 = bn_class(3)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.bn1(x)
|
||||
x = self.conv(x)
|
||||
x = self.bn2(x)
|
||||
return x
|
||||
|
||||
def _get_getitem_nodes(m: torch.fx.GraphModule):
|
||||
"""
|
||||
Return a 2-tuple of (unrelated_getitem_node, conv_bn_getitem_node) from the graph.
|
||||
"""
|
||||
unrelated_getitem_node, conv_bn_getitem_node = None, None
|
||||
for node in m.graph.nodes:
|
||||
if (
|
||||
node.target != operator.getitem
|
||||
or node.args[0].target
|
||||
!= torch.ops.aten._native_batch_norm_legit.default
|
||||
):
|
||||
continue
|
||||
if node.args[0].args[0].op == "placeholder":
|
||||
unrelated_getitem_node = node
|
||||
else:
|
||||
conv_bn_getitem_node = node
|
||||
assert (
|
||||
unrelated_getitem_node is not None
|
||||
), "did not find unrelated getitem node, bad test setup"
|
||||
assert (
|
||||
conv_bn_getitem_node is not None
|
||||
), "did not find conv bn getitem node, bad test setup"
|
||||
return (unrelated_getitem_node, conv_bn_getitem_node)
|
||||
|
||||
# Program capture
|
||||
m = M(self.conv_class, self.bn_class)
|
||||
m = torch._export.capture_pre_autograd_graph(m, self.example_inputs)
|
||||
m.graph.eliminate_dead_code()
|
||||
m.recompile()
|
||||
(_, original_conv_bn_getitem_node) = _get_getitem_nodes(m)
|
||||
|
||||
# Prepare QAT
|
||||
quantizer = XNNPACKQuantizer()
|
||||
quantizer.set_global(
|
||||
get_symmetric_quantization_config(is_per_channel=False, is_qat=True)
|
||||
)
|
||||
m = prepare_qat_pt2e(m, quantizer)
|
||||
(unrelated_getitem_node, conv_bn_getitem_node) = _get_getitem_nodes(m)
|
||||
|
||||
# Verify that the metadata was copied from `conv_bn_getitem`, not `unrelated_getitem`
|
||||
original_conv_bn_getitem_meta = original_conv_bn_getitem_node.meta[
|
||||
"quantization_annotation"
|
||||
]
|
||||
conv_bn_getitem_meta = conv_bn_getitem_node.meta["quantization_annotation"]
|
||||
self.assertEqual(conv_bn_getitem_meta, original_conv_bn_getitem_meta)
|
||||
self.assertTrue("quantization_annotation" not in unrelated_getitem_node.meta)
|
||||
|
||||
def test_qat_update_shared_qspec(self):
|
||||
"""
|
||||
Test the case where nodes used in SharedQuantizationSpec were replaced
|
||||
|
|
@ -926,39 +848,6 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
|
|||
self.assertTrue(conv_node is not None)
|
||||
self.assertTrue(bn_node is None)
|
||||
|
||||
def test_preserve_capture_pre_autograd_graph_tag(self):
|
||||
"""
|
||||
Ensure the capture_pre_autograd_graph_tag node meta is preserved.
|
||||
TODO: Remove this test after training IR migration.
|
||||
T199018392
|
||||
"""
|
||||
from torch._export import capture_pre_autograd_graph
|
||||
from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
|
||||
|
||||
if capture_pre_autograd_graph_using_training_ir():
|
||||
self.skipTest(
|
||||
"test doesn't apply when capture_pre_autograd_graph is using training IR"
|
||||
)
|
||||
|
||||
m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
|
||||
m = capture_pre_autograd_graph(m, self.example_inputs)
|
||||
|
||||
for node in m.graph.nodes:
|
||||
self.assertTrue(node.meta.get("capture_pre_autograd_graph_tag", False))
|
||||
quantizer = XNNPACKQuantizer()
|
||||
quantizer.set_global(
|
||||
get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
|
||||
)
|
||||
m = prepare_qat_pt2e(m, quantizer)
|
||||
m = convert_pt2e(m)
|
||||
has_tag = False
|
||||
for node in m.graph.nodes:
|
||||
if not node.meta.get("capture_pre_autograd_graph_tag", False):
|
||||
has_tag = True
|
||||
break
|
||||
self.assertTrue(has_tag)
|
||||
torch.export.export(m, self.example_inputs)
|
||||
|
||||
|
||||
@skipIfNoQNNPACK
|
||||
class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch._export import capture_pre_autograd_graph
|
||||
|
||||
# Makes sure that quantized_decomposed ops are registered
|
||||
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
|
||||
|
|
@ -381,10 +380,9 @@ def _get_aten_graph_module_for_pattern(
|
|||
kwargs,
|
||||
).module()
|
||||
else:
|
||||
aten_pattern = capture_pre_autograd_graph(
|
||||
pattern, # type: ignore[arg-type]
|
||||
example_inputs,
|
||||
kwargs,
|
||||
raise RuntimeError(
|
||||
"capture_pre_autograd_graph is deprecated and will be deleted soon."
|
||||
"Please use torch.export.export_for_training instead."
|
||||
)
|
||||
aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr]
|
||||
aten_pattern.recompile() # type: ignore[operator]
|
||||
|
|
|
|||
|
|
@ -35,9 +35,7 @@ def prepare_pt2e(
|
|||
"""Prepare a model for post training quantization
|
||||
|
||||
Args:
|
||||
* `model` (torch.fx.GraphModule): a model captured by `torch.export` API
|
||||
in the short term we are using `torch._export.capture_pre_autograd_graph`,
|
||||
in the long term we'll migrate to some `torch.export` API
|
||||
* `model` (torch.fx.GraphModule): a model captured by `torch.export.export_for_training` API.
|
||||
* `quantizer`: A backend specific quantizer that conveys how user want the
|
||||
model to be quantized. Tutorial for how to write a quantizer can be found here:
|
||||
https://pytorch.org/tutorials/prototype/pt2e_quantizer.html
|
||||
|
|
@ -49,7 +47,6 @@ def prepare_pt2e(
|
|||
|
||||
import torch
|
||||
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
|
||||
from torch._export import capture_pre_autograd_graph
|
||||
from torch.ao.quantization.quantizer import (
|
||||
XNNPACKQuantizer,
|
||||
get_symmetric_quantization_config,
|
||||
|
|
@ -127,7 +124,6 @@ def prepare_qat_pt2e(
|
|||
Example::
|
||||
import torch
|
||||
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
|
||||
from torch._export import capture_pre_autograd_graph
|
||||
from torch.ao.quantization.quantizer import (
|
||||
XNNPACKQuantizer,
|
||||
get_symmetric_quantization_config,
|
||||
|
|
|
|||
Loading…
Reference in a new issue