mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[reland] Kill capture_pre_autograd_graph API (#143426)
Summary:
Delete the following APIs:
- capture_pre_autograd_graph()
- capture_pre_autograd_graph_using_training_ir()
- gm_using_training_ir()

Update the XLA pin to include https://github.com/pytorch/xla/pull/8398

There are no remaining call sites to `capture_pre_autograd_graph`, except:
1) two test cases in coremltools, guarded by a version guard; PR to remove them: https://github.com/apple/coremltools/pull/2400
2) a few call sites guarded by a version guard (< 2.5.0)

Test Plan: CI

Differential Revision: D67354440

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143426
Approved by: https://github.com/gmagogsfm
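For context, a minimal migration sketch from the removed API to its replacement (the module, inputs, and shapes below are illustrative, not from this PR):

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    example_args = (torch.randn(2, 3),)

    # Before this PR (removed API):
    # gm = torch._export.capture_pre_autograd_graph(M(), example_args)

    # After: export_for_training returns an ExportedProgram; .module()
    # yields a GraphModule comparable to what the old API returned.
    gm = torch.export.export_for_training(M(), example_args, strict=True).module()
    print(gm.graph)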
This commit is contained in:
parent
eb67dd3e2d
commit
d8ea4ce631
8 changed files with 6 additions and 250 deletions
.github/ci_commit_pins/xla.txt (vendored, 2 lines changed)
@@ -1 +1 @@
-73f54ba5bd7fb83d7ba81fe6f5e05fb6ee815d6f
+b2b890e962f5fb6f481e5da2eb4a43bb990d0f1b
@@ -58,215 +58,6 @@ class ExportDynamoConfig:
     allow_rnn: bool = True
 
 
-# We only want to print this once to avoid flooding logs in workflows where capture_pre_autograd_graph
-# is called multiple times.
-@lru_cache
-def capture_pre_autograd_graph_warning():
-    from torch._inductor import config
-
-    log.warning("+============================+")
-    log.warning("|     !!!   WARNING   !!!    |")
-    log.warning("+============================+")
-    log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.")
-    log.warning("Please switch to use torch.export.export_for_training instead.")
-    if config.is_fbcode():
-        log.warning("For unittest, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.")  # noqa: B950
-
-
-@lru_cache
-def print_export_warning():
-    log.warning("Using torch.export.export_for_training(...,strict=True)")
-
-
-def gm_using_training_ir(graph_module: torch.fx.GraphModule) -> bool:
-    """
-    Returns true if the graph module is detected to use training IR.
-
-    This function checks for two specific conditions within the nodes of the graph module:
-    1. The presence of the `torch.ops.aten.batch_norm.default` operation which indicates the use of training IR.
-    2. The presence of deprecated IR tags on node meta or batch norm ops produced by the deprecated IR.
-
-    The function raises a RuntimeError if both conditions are met, indicating a conflict in the IR.
-    """
-    # TODO: clean up this code after training IR migration.
-    # T199018392
-    has_training_ir_batch_norm = False
-    has_deprecated_ir_tag = getattr(graph_module, "capture_pre_autograd_graph_tag", False)
-    for node in graph_module.graph.nodes:
-        if node.op == "call_function":
-            if node.target == torch.ops.aten.batch_norm.default:
-                has_training_ir_batch_norm = True
-            if node.meta.get("capture_pre_autograd_graph_tag", False):
-                has_deprecated_ir_tag = True
-            if node.target in [
-                torch.ops.aten._native_batch_norm_legit.default,
-                torch.ops.aten.cudnn_batch_norm.default,
-                torch.ops.aten.miopen_batch_norm.default,
-            ]:
-                has_deprecated_ir_tag = True
-
-    if has_deprecated_ir_tag and has_training_ir_batch_norm:
-        raise RuntimeError("Conflicting IR detected.")
-    return has_training_ir_batch_norm or not has_deprecated_ir_tag
-
-
-@compatibility(is_backward_compatible=False)
-def capture_pre_autograd_graph(
-    f: torch.nn.Module,
-    args: Tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
-) -> torch.nn.Module:
-    """
-    A helper function that is intended to trace a module before any pre-autograd
-    decomposition is run. The produced module will be "non-functional" and
-    composed of aten operators. Later this API will be deleted in favor of more general
-    torch.export API.
-
-    Args:
-        f: nn.Module to be traced
-
-        args: example positional inputs.
-
-        kwargs: optional example keyword inputs.
-
-        dynamic_shapes: Should either be:
-            1) a dict from argument names of ``f`` to their dynamic shape specifications,
-            2) a tuple that specifies dynamic shape specifications for each input in original order.
-            If you are specifying dynamism on keyword args, you will need to pass them in the order that
-            is defined in the original function signature.
-
-            The dynamic shape of a tensor argument can be specified as either
-            (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
-            not required to include static dimension indices in this dict, but when they are,
-            they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
-            where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
-            are denoted by None. Arguments that are dicts or tuples / lists of tensors are
-            recursively specified by using mappings or sequences of contained specifications.
-
-    Returns:
-        An nn.Module containing the traced method.
-
-    """
-    from torch.export._trace import _extract_fake_inputs, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
-    from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
-    from torch._export.non_strict_utils import make_constraints
-    from torch._subclasses.functional_tensor import FunctionalTensor
-    from torch.export._unlift import _create_stateful_graph_module
-    from torch.export.dynamic_shapes import _combine_args
-
-    capture_pre_autograd_graph_warning()
-
-    if sys.platform == "win32":
-        raise RuntimeError("capture_pre_autograd_graph not yet supported on Windows")
-
-    assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."
-
-    if kwargs is None:
-        kwargs = {}
-
-    if capture_pre_autograd_graph_using_training_ir():
-        print_export_warning()
-        module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module()
-    else:
-        log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
-
-        # Do not decompose dropout for exported models, because in eval mode the dropout
-        # op disappears from the graph, which makes it difficult to switch to train mode.
-        # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
-
-        # We force create native_batch_norm because the below materialization logic
-        # only applies to CIA ops.
-        maybe_aliasing_or_mutating_ops = [torch.ops.aten.native_batch_norm.default]
-
-        _materialize_cpp_cia_ops()
-
-        for op in torch.ops.aten:
-            op_obj = getattr(torch.ops.aten, op)
-            for overload in op_obj.overloads():
-                op_overload = getattr(op_obj, overload)
-                if torch.Tag.maybe_aliasing_or_mutating in op_overload.tags:
-                    maybe_aliasing_or_mutating_ops.append(op_overload)
-
-        decomp_table = {
-            op: op.decompose
-            for op in maybe_aliasing_or_mutating_ops
-            if op != torch.ops.aten.dropout.default
-        }
-        with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
-            m = torch._dynamo.export(
-                f,
-                dynamic_shapes=dynamic_shapes,
-                assume_static_by_default=True,
-                tracing_mode="symbolic",
-                decomposition_table=decomp_table,
-                pre_dispatch=True,
-                aten_graph=True,
-                _log_export_usage=False,
-            )(
-                *args,
-                **kwargs,
-            )[0]
-
-            _, _, fake_mode = _extract_fake_inputs(m, args, kwargs)
-
-            m.meta["inline_constraints"] = {
-                k: v
-                for k, v in fake_mode.shape_env.var_to_range.items()
-                if re.match(r"^[if]\d+$", str(k))
-            }
-
-            if isinstance(f, torch.nn.Module):
-                from torch.export._trace import _restore_state_dict
-
-                _restore_state_dict(f, m)
-
-            combined_args = _combine_args(f, args, kwargs)
-            range_constraints = make_constraints(
-                fake_mode,
-                m,
-                combined_args,
-                dynamic_shapes,
-                0,
-            )
-
-            module = _create_stateful_graph_module(
-                m,
-                range_constraints=range_constraints,
-            )
-
-        setattr(module, "capture_pre_autograd_graph_tag", True)  # noqa: B010
-        for node in module.graph.nodes:
-            node.meta["capture_pre_autograd_graph_tag"] = True
-
-    error_message = \
-        """
-        Calling train() or eval() is not supported for exported models.
-        Alternatively, you may override these methods to do custom user behavior as follows:
-
-            def _my_train(self, mode: bool = True):
-                ...
-
-            def _my_eval(self):
-                ...
-
-            model.train = types.MethodType(_my_train, model)
-            model.eval = types.MethodType(_my_eval, model)
-        """
-
-    def _train(self, mode: bool = True):
-        raise NotImplementedError(error_message)
-
-    def _eval(self, mode: bool = True):
-        raise NotImplementedError(error_message)
-
-    module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
-    module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]
-
-    # Remove Proxy because they cannot be deepcopied or pickled.
-    if hasattr(module, "_buffers"):
-        torch._export.utils.remove_proxy_from_state_dict(
-            module._buffers, in_place=True
-        )
-    return module
-
-
 # We only want to print this once to avoid flooding logs in workflows where aot_compile_warning
 # is called multiple times.
 @lru_cache
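The removed error_message above points users at a types.MethodType override for train()/eval(); a minimal sketch of that pattern (the traced ReLU module is just a stand-in for any exported GraphModule):

    import types

    import torch

    def _my_train(self, mode: bool = True):
        # Exported graphs are fixed in the mode they were traced in; make this a no-op.
        return self

    def _my_eval(self):
        return self

    model = torch.fx.symbolic_trace(torch.nn.ReLU())  # stand-in for an exported module
    model.train = types.MethodType(_my_train, model)
    model.eval = types.MethodType(_my_eval, model)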
@@ -167,10 +167,6 @@ def log_torch_jit_trace_exportability(
     return
 
 
-def capture_pre_autograd_graph_using_training_ir() -> bool:
-    return False
-
-
 def justknobs_check(name: str, default: bool = True) -> bool:
     """
     This function can be used to killswitch functionality in FB prod,
@@ -55,10 +55,6 @@ def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
     m.graph.eliminate_dead_code()
     m.recompile()
 
-    from torch._export import gm_using_training_ir
-
-    using_training_ir = gm_using_training_ir(m)
-
     for inplace in [False, True]:
 
         def dropout_train(x):
@@ -72,23 +68,19 @@ def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
         match_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(dropout_train),
             example_inputs,
-            using_training_ir=using_training_ir,
         )
         replacement_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(dropout_eval),
             example_inputs,
-            using_training_ir=using_training_ir,
         )
     else:
         match_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(dropout_eval),
             example_inputs,
-            using_training_ir=using_training_ir,
         )
         replacement_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(dropout_train),
             example_inputs,
-            using_training_ir=using_training_ir,
         )
 
     from torch.fx.subgraph_rewriter import replace_pattern_with_filters
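The dropout swap above relies on torch.fx subgraph rewriting; a self-contained sketch of the same idea using the public replace_pattern API (toy pattern, not this file's exact helpers):

    import torch
    import torch.nn.functional as F
    from torch.fx import symbolic_trace
    from torch.fx.subgraph_rewriter import replace_pattern

    class M(torch.nn.Module):
        def forward(self, x):
            return F.dropout(x, p=0.5, training=True)

    def pattern(x):
        return F.dropout(x, p=0.5, training=True)

    def replacement(x):
        return F.dropout(x, p=0.5, training=False)

    m = symbolic_trace(M())
    replace_pattern(m, pattern, replacement)  # swap train-mode dropout for eval-mode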
@@ -122,10 +114,6 @@ def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
     m.graph.eliminate_dead_code()
     m.recompile()
 
-    from torch._export import gm_using_training_ir
-
-    using_training_ir = gm_using_training_ir(m)
-
     def bn_train(
         x: torch.Tensor,
         bn_weight: torch.Tensor,
@@ -162,13 +150,11 @@ def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
         _WrapperModule(bn_train),
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
     )
     bn_eval_aten = _get_aten_graph_module_for_pattern(
         _WrapperModule(bn_eval),
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
     )
 
     if train_to_eval:
@@ -667,16 +667,11 @@ def _fuse_conv_bn_qat_helper(
     m.graph.eliminate_dead_code()
     m.recompile()
 
-    from torch._export import gm_using_training_ir
-
-    using_training_ir = gm_using_training_ir(m)
-
     conv_bn_pattern = _get_conv_bn_pattern(conv_fn)
     match_pattern = _get_aten_graph_module_for_pattern(
         conv_bn_pattern,
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
     )
 
     # Step (1): Replace patterns with conv bias
@@ -690,7 +685,6 @@ def _fuse_conv_bn_qat_helper(
         qat_conv_bn_pattern,
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
     )
     replacements_with_conv_bias = replace_pattern_with_filters(
         m,
@@ -708,7 +702,6 @@ def _fuse_conv_bn_qat_helper(
         qat_conv_bn_pattern_no_conv_bias,
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
     )
     replacements_no_conv_bias = replace_pattern_with_filters(
         m,
@@ -922,9 +915,6 @@ def _fold_conv_bn_qat_helper(
     """
     Replace the quantized (conv + bn) pattern with conv with bn weights folded into the weights of conv.
     """
-    from torch._export import gm_using_training_ir
-
-    using_training_ir = gm_using_training_ir(m)
 
     m.graph.eliminate_dead_code()
     m.recompile()
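For reference, the standard batch-norm folding math that "folded into the weights of conv" refers to, as a generic sketch (not this helper's implementation): w' = w * gamma / sqrt(var + eps) and b' = (b - mean) * gamma / sqrt(var + eps) + beta.

    import torch

    def fold_bn_into_conv(conv_w, conv_b, bn_mean, bn_var, bn_w, bn_b, eps=1e-5):
        # Per-output-channel scale: gamma / sqrt(running_var + eps)
        scale = bn_w / torch.sqrt(bn_var + eps)
        folded_w = conv_w * scale.reshape(-1, 1, 1, 1)  # conv2d weight: (out_c, in_c, kh, kw)
        folded_b = (conv_b - bn_mean) * scale + bn_b
        return folded_w, folded_b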
@@ -958,7 +948,6 @@ def _fold_conv_bn_qat_helper(
         match_pattern,
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
         **kwargs,
     )
     replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern(
@@ -968,7 +957,6 @@ def _fold_conv_bn_qat_helper(
         replacement_pattern,
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
         **kwargs,
     )
     replacements.extend(
@@ -797,9 +797,6 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
     ]
 
     remove_tensor_overload_for_qdq_ops(model)
-    from torch._export import gm_using_training_ir
-
-    using_training_ir = gm_using_training_ir(model)
 
     for rewrite_info in _REWRITE_INFO_LIST:
         example_inputs = rewrite_info.example_inputs
@@ -807,9 +804,9 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
         replacement = rewrite_info.replacement
         pattern_post_trans = rewrite_info.pattern_post_trans
         replacement_post_trans = rewrite_info.replacement_post_trans
-        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, using_training_ir=using_training_ir)  # type: ignore[arg-type, assignment]
+        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs)  # type: ignore[arg-type, assignment]
         remove_tensor_overload_for_qdq_ops(pattern)  # type: ignore[arg-type]
-        replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs, using_training_ir=using_training_ir)  # type: ignore[arg-type, assignment]
+        replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs)  # type: ignore[arg-type, assignment]
         remove_tensor_overload_for_qdq_ops(replacement)  # type: ignore[arg-type]
         if pattern_post_trans:
             pattern = pattern_post_trans(pattern)
@@ -351,6 +351,8 @@ def _get_aten_graph_module_for_pattern(
             [x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs]
         )
 
+    # T199018392
+    # TODO: remove the using_training_ir flag from function
     if using_training_ir:
         aten_pattern = torch.export.export_for_training(
             pattern,  # type: ignore[arg-type]
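The hunk above keeps the training-IR path of _get_aten_graph_module_for_pattern; a rough sketch of what that capture looks like in isolation (the pattern module and inputs are illustrative):

    import torch

    class Pattern(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.dropout(x, p=0.5, training=True)

    example_inputs = (torch.randn(1, 3),)
    aten_pattern = torch.export.export_for_training(
        Pattern(), example_inputs, strict=True
    ).module()
    print(aten_pattern.graph)  # aten-level graph suitable for subgraph matching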
@@ -530,10 +530,6 @@ def _do_annotate_conv_bn(
     gm.graph.eliminate_dead_code()
     gm.recompile()
 
-    from torch._export import gm_using_training_ir
-
-    using_training_ir = gm_using_training_ir(gm)
-
     matches = []
     if is_conv_transpose:
         combinations = [
@@ -556,7 +552,7 @@ def _do_annotate_conv_bn(
     # Match against all conv dimensions and cuda variants
     for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:  # type: ignore[misc]
         pattern = get_pattern(conv_fn, relu_is_inplace)  # type: ignore[has-type]
-        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda, using_training_ir=using_training_ir)  # type: ignore[has-type]
+        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)  # type: ignore[has-type]
         pattern.graph.eliminate_dead_code()
         pattern.recompile()
         matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)