mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[reland] Kill capture_pre_autograd_graph API (#143426)
Summary: Delete the following APIs:
- capture_pre_autograd_graph()
- capture_pre_autograd_graph_using_training_ir()
- gm_using_training_ir()

Update the XLA pin to include https://github.com/pytorch/xla/pull/8398.

There are no more call sites of `capture_pre_autograd_graph`, except: 1) two test cases in coreml, protected by a version guard (PR to remove them: https://github.com/apple/coremltools/pull/2400); 2) a few call sites protected by a version guard (< 2.5.0).

Test Plan: CI
Differential Revision: D67354440
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143426
Approved by: https://github.com/gmagogsfm
This commit is contained in:
parent
eb67dd3e2d
commit
d8ea4ce631
8 changed files with 6 additions and 250 deletions
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
|
|
@ -1 +1 @@
|
|||
73f54ba5bd7fb83d7ba81fe6f5e05fb6ee815d6f
|
||||
b2b890e962f5fb6f481e5da2eb4a43bb990d0f1b
|
||||
|
|
|
|||
|
|
@ -58,215 +58,6 @@ class ExportDynamoConfig:
|
|||
allow_rnn: bool = True
|
||||
|
||||
|
||||
# We only want to print this once to avoid flooding logs in workflows where capture_pre_autograd_graph
|
||||
# is called multiple times.
|
||||
@lru_cache
|
||||
def capture_pre_autograd_graph_warning():
|
||||
from torch._inductor import config
|
||||
|
||||
log.warning("+============================+")
|
||||
log.warning("| !!! WARNING !!! |")
|
||||
log.warning("+============================+")
|
||||
log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.")
|
||||
log.warning("Please switch to use torch.export.export_for_training instead.")
|
||||
if config.is_fbcode():
|
||||
log.warning("For unittest, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.") # noqa: B950
|
||||
|
||||
@lru_cache
|
||||
def print_export_warning():
|
||||
log.warning("Using torch.export.export_for_training(...,strict=True)")
|
||||
|
||||
def gm_using_training_ir(graph_module: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Returns true if the graph module is detected to use training IR.
|
||||
|
||||
This function checks for two specific conditions within the nodes of the graph module:
|
||||
1. The presence of the `torch.ops.aten.batch_norm.default` operation which indicates the use of training IR.
|
||||
2. The presence of deprecated IR tags on node meta or batch norm ops produced by the deprecated IR.
|
||||
|
||||
The function raises a RuntimeError if both conditions are met, indicating a conflict in the IR.
|
||||
"""
|
||||
# TODO: clean up this code after training IR migration.
|
||||
# T199018392
|
||||
has_training_ir_batch_norm = False
|
||||
has_deprecated_ir_tag = getattr(graph_module, "capture_pre_autograd_graph_tag", False)
|
||||
for node in graph_module.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target == torch.ops.aten.batch_norm.default:
|
||||
has_training_ir_batch_norm = True
|
||||
if node.meta.get("capture_pre_autograd_graph_tag", False):
|
||||
has_deprecated_ir_tag = True
|
||||
if node.target in [
|
||||
torch.ops.aten._native_batch_norm_legit.default,
|
||||
torch.ops.aten.cudnn_batch_norm.default,
|
||||
torch.ops.aten.miopen_batch_norm.default,
|
||||
]:
|
||||
has_deprecated_ir_tag = True
|
||||
|
||||
if has_deprecated_ir_tag and has_training_ir_batch_norm:
|
||||
raise RuntimeError("Conflicting IR detected.")
|
||||
return has_training_ir_batch_norm or not has_deprecated_ir_tag
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def capture_pre_autograd_graph(
|
||||
f: torch.nn.Module,
|
||||
args: Tuple[Any],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
||||
) -> torch.nn.Module:
|
||||
"""
|
||||
A helper function that is intended to trace a module before any pre-autograd
|
||||
decomposition is run. The produced module will be "non-functional" and
|
||||
composed of aten operators. Later this API will be deleted in favor of more general
|
||||
torch.export API.
|
||||
|
||||
Args:
|
||||
f: nn.Module to be traced
|
||||
|
||||
args: example positional inputs.
|
||||
|
||||
kwargs: optional example keyword inputs.
|
||||
|
||||
dynamic_shapes: Should either be:
|
||||
1) a dict from argument names of ``f`` to their dynamic shape specifications,
|
||||
2) a tuple that specifies dynamic shape specifications for each input in original order.
|
||||
If you are specifying dynamism on keyword args, you will need to pass them in the order that
|
||||
is defined in the original function signature.
|
||||
|
||||
The dynamic shape of a tensor argument can be specified as either
|
||||
(1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
|
||||
not required to include static dimension indices in this dict, but when they are,
|
||||
they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
|
||||
where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
|
||||
are denoted by None. Arguments that are dicts or tuples / lists of tensors are
|
||||
recursively specified by using mappings or sequences of contained specifications.
|
||||
|
||||
Returns:
|
||||
An nn.Module containing the traced method.
|
||||
|
||||
"""
|
||||
from torch.export._trace import _extract_fake_inputs, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
|
||||
from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
|
||||
from torch._export.non_strict_utils import make_constraints
|
||||
from torch._subclasses.functional_tensor import FunctionalTensor
|
||||
from torch.export._unlift import _create_stateful_graph_module
|
||||
from torch.export.dynamic_shapes import _combine_args
|
||||
|
||||
capture_pre_autograd_graph_warning()
|
||||
|
||||
if sys.platform == "win32":
|
||||
raise RuntimeError("capture_pre_autograd_graph not yet supported on Windows")
|
||||
|
||||
assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."
|
||||
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
if capture_pre_autograd_graph_using_training_ir():
|
||||
print_export_warning()
|
||||
module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module()
|
||||
else:
|
||||
log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
|
||||
|
||||
# Do not decompose dropout for exported models, because in eval mode the dropout
|
||||
# op disappears from the graph, which makes it difficult to switch to train mode.
|
||||
# See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
|
||||
|
||||
# We force create native_batch_norm because the below materialization logic
|
||||
# only applies to CIA ops.
|
||||
maybe_aliasing_or_mutating_ops = [torch.ops.aten.native_batch_norm.default]
|
||||
|
||||
_materialize_cpp_cia_ops()
|
||||
|
||||
for op in torch.ops.aten:
|
||||
op_obj = getattr(torch.ops.aten, op)
|
||||
for overload in op_obj.overloads():
|
||||
op_overload = getattr(op_obj, overload)
|
||||
if torch.Tag.maybe_aliasing_or_mutating in op_overload.tags:
|
||||
maybe_aliasing_or_mutating_ops.append(op_overload)
|
||||
|
||||
decomp_table = {
|
||||
op: op.decompose
|
||||
for op in maybe_aliasing_or_mutating_ops
|
||||
if op != torch.ops.aten.dropout.default
|
||||
}
|
||||
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
|
||||
m = torch._dynamo.export(
|
||||
f,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
assume_static_by_default=True,
|
||||
tracing_mode="symbolic",
|
||||
decomposition_table=decomp_table,
|
||||
pre_dispatch=True,
|
||||
aten_graph=True,
|
||||
_log_export_usage=False,
|
||||
)(
|
||||
*args,
|
||||
**kwargs,
|
||||
)[0]
|
||||
|
||||
_, _, fake_mode = _extract_fake_inputs(m, args, kwargs)
|
||||
|
||||
m.meta["inline_constraints"] = {
|
||||
k: v
|
||||
for k, v in fake_mode.shape_env.var_to_range.items()
|
||||
if re.match(r"^[if]\d+$", str(k))
|
||||
}
|
||||
|
||||
if isinstance(f, torch.nn.Module):
|
||||
from torch.export._trace import _restore_state_dict
|
||||
_restore_state_dict(f, m)
|
||||
|
||||
combined_args = _combine_args(f, args, kwargs)
|
||||
range_constraints = make_constraints(
|
||||
fake_mode,
|
||||
m,
|
||||
combined_args,
|
||||
dynamic_shapes,
|
||||
0,
|
||||
)
|
||||
|
||||
module = _create_stateful_graph_module(
|
||||
m,
|
||||
range_constraints=range_constraints,
|
||||
)
|
||||
|
||||
setattr(module, "capture_pre_autograd_graph_tag", True) # noqa: B010
|
||||
for node in module.graph.nodes:
|
||||
node.meta["capture_pre_autograd_graph_tag"] = True
|
||||
|
||||
error_message = \
|
||||
"""
|
||||
Calling train() or eval() is not supported for exported models.
|
||||
Alternatively, you may override these methods to do custom user behavior as follows:
|
||||
|
||||
def _my_train(self, mode: bool = True):
|
||||
...
|
||||
|
||||
def _my_eval(self):
|
||||
...
|
||||
|
||||
model.train = types.MethodType(_my_train, model)
|
||||
model.eval = types.MethodType(_my_eval, model)
|
||||
"""
|
||||
|
||||
def _train(self, mode: bool = True):
|
||||
raise NotImplementedError(error_message)
|
||||
|
||||
def _eval(self, mode: bool = True):
|
||||
raise NotImplementedError(error_message)
|
||||
|
||||
module.train = types.MethodType(_train, module) # type: ignore[method-assign]
|
||||
module.eval = types.MethodType(_eval, module) # type: ignore[method-assign]
|
||||
|
||||
# Remove Proxy because they cannot be deepcopied or pickled.
|
||||
if hasattr(module, "_buffers"):
|
||||
torch._export.utils.remove_proxy_from_state_dict(
|
||||
module._buffers, in_place=True
|
||||
)
|
||||
return module
|
||||
|
||||
|
||||
# We only want to print this once to avoid flooding logs in workflows where aot_compile_warning
|
||||
# is called multiple times.
|
||||
@lru_cache
|
||||
|
|
|
|||
|
|
@ -167,10 +167,6 @@ def log_torch_jit_trace_exportability(
|
|||
return
|
||||
|
||||
|
||||
def capture_pre_autograd_graph_using_training_ir() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def justknobs_check(name: str, default: bool = True) -> bool:
|
||||
"""
|
||||
This function can be used to killswitch functionality in FB prod,
|
||||
|
|
|
|||
|
|
@ -55,10 +55,6 @@ def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
|
|||
m.graph.eliminate_dead_code()
|
||||
m.recompile()
|
||||
|
||||
from torch._export import gm_using_training_ir
|
||||
|
||||
using_training_ir = gm_using_training_ir(m)
|
||||
|
||||
for inplace in [False, True]:
|
||||
|
||||
def dropout_train(x):
|
||||
|
|
@ -72,23 +68,19 @@ def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
|
|||
match_pattern = _get_aten_graph_module_for_pattern(
|
||||
_WrapperModule(dropout_train),
|
||||
example_inputs,
|
||||
using_training_ir=using_training_ir,
|
||||
)
|
||||
replacement_pattern = _get_aten_graph_module_for_pattern(
|
||||
_WrapperModule(dropout_eval),
|
||||
example_inputs,
|
||||
using_training_ir=using_training_ir,
|
||||
)
|
||||
else:
|
||||
match_pattern = _get_aten_graph_module_for_pattern(
|
||||
_WrapperModule(dropout_eval),
|
||||
example_inputs,
|
||||
using_training_ir=using_training_ir,
|
||||
)
|
||||
replacement_pattern = _get_aten_graph_module_for_pattern(
|
||||
_WrapperModule(dropout_train),
|
||||
example_inputs,
|
||||
using_training_ir=using_training_ir,
|
||||
)
|
||||
|
||||
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
|
||||
|
|
@ -122,10 +114,6 @@ def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
|
|||
m.graph.eliminate_dead_code()
|
||||
m.recompile()
|
||||
|
||||
from torch._export import gm_using_training_ir
|
||||
|
||||
using_training_ir = gm_using_training_ir(m)
|
||||
|
||||
def bn_train(
|
||||
x: torch.Tensor,
|
||||
bn_weight: torch.Tensor,
|
||||
|
|
@ -162,13 +150,11 @@ def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
|
|||
_WrapperModule(bn_train),
|
||||
example_inputs,
|
||||
is_cuda,
|
||||
using_training_ir=using_training_ir,
|
||||
)
|
||||
bn_eval_aten = _get_aten_graph_module_for_pattern(
|
||||
_WrapperModule(bn_eval),
|
||||
example_inputs,
|
||||
is_cuda,
|
||||
using_training_ir=using_training_ir,
|
||||
)
|
||||
|
||||
if train_to_eval:
|
||||
|
|
|
|||
|
|
@ -667,16 +667,11 @@ def _fuse_conv_bn_qat_helper(
|
|||
m.graph.eliminate_dead_code()
|
||||
m.recompile()
|
||||
|
||||
from torch._export import gm_using_training_ir
|
||||
|
||||
using_training_ir = gm_using_training_ir(m)
|
||||
|
||||
conv_bn_pattern = _get_conv_bn_pattern(conv_fn)
|
||||
match_pattern = _get_aten_graph_module_for_pattern(
|
||||
conv_bn_pattern,
|
||||
example_inputs,
|
||||
is_cuda,
|
||||
using_training_ir=using_training_ir,
|
||||
)
|
||||
|
||||
# Step (1): Replace patterns with conv bias
|
||||
|
|
@ -690,7 +685,6 @@ def _fuse_conv_bn_qat_helper(
|
|||
qat_conv_bn_pattern,
|
||||
example_inputs,
|
||||
is_cuda,
|
||||
using_training_ir=using_training_ir,
|
||||
)
|
||||
replacements_with_conv_bias = replace_pattern_with_filters(
|
||||
m,
|
||||
|
|
@ -708,7 +702,6 @@ def _fuse_conv_bn_qat_helper(
|
|||
qat_conv_bn_pattern_no_conv_bias,
|
||||
example_inputs,
|
||||
is_cuda,
|
||||
using_training_ir=using_training_ir,
|
||||
)
|
||||
replacements_no_conv_bias = replace_pattern_with_filters(
|
||||
m,
|
||||
|
|
@ -922,9 +915,6 @@ def _fold_conv_bn_qat_helper(
|
|||
"""
|
||||
Replace the quantized (conv + bn) pattern with conv with bn weights folded into the weights of conv.
|
||||
"""
|
||||
from torch._export import gm_using_training_ir
|
||||
|
||||
using_training_ir = gm_using_training_ir(m)
|
||||
|
||||
m.graph.eliminate_dead_code()
|
||||
m.recompile()
|
||||
|
|
@ -958,7 +948,6 @@ def _fold_conv_bn_qat_helper(
|
|||
match_pattern,
|
||||
example_inputs,
|
||||
is_cuda,
|
||||
using_training_ir=using_training_ir,
|
||||
**kwargs,
|
||||
)
|
||||
replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern(
|
||||
|
|
@ -968,7 +957,6 @@ def _fold_conv_bn_qat_helper(
|
|||
replacement_pattern,
|
||||
example_inputs,
|
||||
is_cuda,
|
||||
using_training_ir=using_training_ir,
|
||||
**kwargs,
|
||||
)
|
||||
replacements.extend(
|
||||
|
|
|
|||
|
|
@ -797,9 +797,6 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
|
|||
]
|
||||
|
||||
remove_tensor_overload_for_qdq_ops(model)
|
||||
from torch._export import gm_using_training_ir
|
||||
|
||||
using_training_ir = gm_using_training_ir(model)
|
||||
|
||||
for rewrite_info in _REWRITE_INFO_LIST:
|
||||
example_inputs = rewrite_info.example_inputs
|
||||
|
|
@ -807,9 +804,9 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
|
|||
replacement = rewrite_info.replacement
|
||||
pattern_post_trans = rewrite_info.pattern_post_trans
|
||||
replacement_post_trans = rewrite_info.replacement_post_trans
|
||||
pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, using_training_ir=using_training_ir) # type: ignore[arg-type, assignment]
|
||||
pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs) # type: ignore[arg-type, assignment]
|
||||
remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type]
|
||||
replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs, using_training_ir=using_training_ir) # type: ignore[arg-type, assignment]
|
||||
replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs) # type: ignore[arg-type, assignment]
|
||||
remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type]
|
||||
if pattern_post_trans:
|
||||
pattern = pattern_post_trans(pattern)
|
||||
|
|
|
|||
|
|
@ -351,6 +351,8 @@ def _get_aten_graph_module_for_pattern(
|
|||
[x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs]
|
||||
)
|
||||
|
||||
# T199018392
|
||||
# TODO: remove the using_training_ir flag from function
|
||||
if using_training_ir:
|
||||
aten_pattern = torch.export.export_for_training(
|
||||
pattern, # type: ignore[arg-type]
|
||||
|
|
|
|||
|
|
@ -530,10 +530,6 @@ def _do_annotate_conv_bn(
|
|||
gm.graph.eliminate_dead_code()
|
||||
gm.recompile()
|
||||
|
||||
from torch._export import gm_using_training_ir
|
||||
|
||||
using_training_ir = gm_using_training_ir(gm)
|
||||
|
||||
matches = []
|
||||
if is_conv_transpose:
|
||||
combinations = [
|
||||
|
|
@ -556,7 +552,7 @@ def _do_annotate_conv_bn(
|
|||
# Match against all conv dimensions and cuda variants
|
||||
for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations: # type: ignore[misc]
|
||||
pattern = get_pattern(conv_fn, relu_is_inplace) # type: ignore[has-type]
|
||||
pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda, using_training_ir=using_training_ir) # type: ignore[has-type]
|
||||
pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda) # type: ignore[has-type]
|
||||
pattern.graph.eliminate_dead_code()
|
||||
pattern.recompile()
|
||||
matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
|
||||
|
|
|
|||
Loading…
Reference in a new issue