diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt
index 0aa7b06f445..a577cdaff77 100644
--- a/.github/ci_commit_pins/xla.txt
+++ b/.github/ci_commit_pins/xla.txt
@@ -1 +1 @@
-73f54ba5bd7fb83d7ba81fe6f5e05fb6ee815d6f
+b2b890e962f5fb6f481e5da2eb4a43bb990d0f1b
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index 0544301d090..2ac01e2e308 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -58,215 +58,6 @@ class ExportDynamoConfig:
     allow_rnn: bool = True
 
 
-# We only want to print this once to avoid flooding logs in workflows where capture_pre_autograd_graph
-# is called multiple times.
-@lru_cache
-def capture_pre_autograd_graph_warning():
-    from torch._inductor import config
-
-    log.warning("+============================+")
-    log.warning("|     !!!   WARNING   !!!    |")
-    log.warning("+============================+")
-    log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.")
-    log.warning("Please switch to use torch.export.export_for_training instead.")
-    if config.is_fbcode():
-        log.warning("For unittest, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.")  # noqa: B950
-
-
-@lru_cache
-def print_export_warning():
-    log.warning("Using torch.export.export_for_training(...,strict=True)")
-
-
-def gm_using_training_ir(graph_module: torch.fx.GraphModule) -> bool:
-    """
-    Returns true if the graph module is detected to use training IR.
-
-    This function checks for two specific conditions within the nodes of the
-    graph module:
-    1. The presence of the `torch.ops.aten.batch_norm.default` operation which
-       indicates the use of training IR.
-    2. The presence of deprecated IR tags on node meta or batch norm ops
-       produced by the deprecated IR.
-
-    The function raises a RuntimeError if both conditions are met, indicating
-    a conflict in the IR.
-    """
-    # TODO: clean up this code after training IR migration.
-    # T199018392
-    has_training_ir_batch_norm = False
-    has_deprecated_ir_tag = getattr(graph_module, "capture_pre_autograd_graph_tag", False)
-    for node in graph_module.graph.nodes:
-        if node.op == "call_function":
-            if node.target == torch.ops.aten.batch_norm.default:
-                has_training_ir_batch_norm = True
-            if node.meta.get("capture_pre_autograd_graph_tag", False):
-                has_deprecated_ir_tag = True
-            if node.target in [
-                torch.ops.aten._native_batch_norm_legit.default,
-                torch.ops.aten.cudnn_batch_norm.default,
-                torch.ops.aten.miopen_batch_norm.default,
-            ]:
-                has_deprecated_ir_tag = True
-
-    if has_deprecated_ir_tag and has_training_ir_batch_norm:
-        raise RuntimeError("Conflicting IR detected.")
-    return has_training_ir_batch_norm or not has_deprecated_ir_tag
-
-
-@compatibility(is_backward_compatible=False)
-def capture_pre_autograd_graph(
-    f: torch.nn.Module,
-    args: Tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
-) -> torch.nn.Module:
-    """
-    A helper function that is intended to trace a module before any pre-autograd
-    decomposition is run. The produced module will be "non-functional" and
-    composed of aten operators. Later this API will be deleted in favor of more general
-    torch.export API.
-
-    Args:
-      f: nn.Module to be traced
-
-      args: example positional inputs.
-
-      kwargs: optional example keyword inputs.
-
-      dynamic_shapes: Should either be:
-         1) a dict from argument names of ``f`` to their dynamic shape specifications,
-         2) a tuple that specifies dynamic shape specifications for each input in original order.
-         If you are specifying dynamism on keyword args, you will need to pass them in the order that
-         is defined in the original function signature.
-
-         The dynamic shape of a tensor argument can be specified as either
-         (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
-         not required to include static dimension indices in this dict, but when they are,
-         they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
-         where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
-         are denoted by None. Arguments that are dicts or tuples / lists of tensors are
-         recursively specified by using mappings or sequences of contained specifications.
-
-    Returns:
-        An nn.Module containing the traced method.
-
-    """
-    from torch.export._trace import _extract_fake_inputs, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
-    from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
-    from torch._export.non_strict_utils import make_constraints
-    from torch._subclasses.functional_tensor import FunctionalTensor
-    from torch.export._unlift import _create_stateful_graph_module
-    from torch.export.dynamic_shapes import _combine_args
-
-    capture_pre_autograd_graph_warning()
-
-    if sys.platform == "win32":
-        raise RuntimeError("capture_pre_autograd_graph not yet supported on Windows")
-
-    assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."
-
-    if kwargs is None:
-        kwargs = {}
-
-    if capture_pre_autograd_graph_using_training_ir():
-        print_export_warning()
-        module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module()
-    else:
-        log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
-
-        # Do not decompose dropout for exported models, because in eval mode the dropout
-        # op disappears from the graph, which makes it difficult to switch to train mode.
-        # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
-
-        # We force create native_batch_norm because the below materialization logic
-        # only applies to CIA ops.
-        maybe_aliasing_or_mutating_ops = [torch.ops.aten.native_batch_norm.default]
-
-        _materialize_cpp_cia_ops()
-
-        for op in torch.ops.aten:
-            op_obj = getattr(torch.ops.aten, op)
-            for overload in op_obj.overloads():
-                op_overload = getattr(op_obj, overload)
-                if torch.Tag.maybe_aliasing_or_mutating in op_overload.tags:
-                    maybe_aliasing_or_mutating_ops.append(op_overload)
-
-        decomp_table = {
-            op: op.decompose
-            for op in maybe_aliasing_or_mutating_ops
-            if op != torch.ops.aten.dropout.default
-        }
-        with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
-            m = torch._dynamo.export(
-                f,
-                dynamic_shapes=dynamic_shapes,
-                assume_static_by_default=True,
-                tracing_mode="symbolic",
-                decomposition_table=decomp_table,
-                pre_dispatch=True,
-                aten_graph=True,
-                _log_export_usage=False,
-            )(
-                *args,
-                **kwargs,
-            )[0]
-
-        _, _, fake_mode = _extract_fake_inputs(m, args, kwargs)
-
-        m.meta["inline_constraints"] = {
-            k: v
-            for k, v in fake_mode.shape_env.var_to_range.items()
-            if re.match(r"^[if]\d+$", str(k))
-        }
-
-        if isinstance(f, torch.nn.Module):
-            from torch.export._trace import _restore_state_dict
-            _restore_state_dict(f, m)
-
-        combined_args = _combine_args(f, args, kwargs)
-        range_constraints = make_constraints(
-            fake_mode,
-            m,
-            combined_args,
-            dynamic_shapes,
-            0,
-        )
-
-        module = _create_stateful_graph_module(
-            m,
-            range_constraints=range_constraints,
-        )
-
-        setattr(module, "capture_pre_autograd_graph_tag", True)  # noqa: B010
-        for node in module.graph.nodes:
-            node.meta["capture_pre_autograd_graph_tag"] = True
-
-    error_message = \
-        """
-        Calling train() or eval() is not supported for exported models.
-        Alternatively, you may override these methods to do custom user behavior as follows:
-
-        def _my_train(self, mode: bool = True):
-            ...
-
-        def _my_eval(self):
-            ...
-
-        model.train = types.MethodType(_my_train, model)
-        model.eval = types.MethodType(_my_eval, model)
-        """
-
-    def _train(self, mode: bool = True):
-        raise NotImplementedError(error_message)
-
-    def _eval(self, mode: bool = True):
-        raise NotImplementedError(error_message)
-
-    module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
-    module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]
-
-    # Remove Proxy because they cannot be deepcopied or pickled.
-    if hasattr(module, "_buffers"):
-        torch._export.utils.remove_proxy_from_state_dict(
-            module._buffers, in_place=True
-        )
-    return module
-
-
 # We only want to print this once to avoid flooding logs in workflows where aot_compile_warning
 # is called multiple times.
 @lru_cache
diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py
index f58eb93d86d..f2b304d5800 100644
--- a/torch/_utils_internal.py
+++ b/torch/_utils_internal.py
@@ -167,10 +167,6 @@ def log_torch_jit_trace_exportability(
     return
 
 
-def capture_pre_autograd_graph_using_training_ir() -> bool:
-    return False
-
-
 def justknobs_check(name: str, default: bool = True) -> bool:
     """
     This function can be used to killswitch functionality in FB prod,
diff --git a/torch/ao/quantization/pt2e/export_utils.py b/torch/ao/quantization/pt2e/export_utils.py
index ca051691c0a..70cca73dd00 100644
--- a/torch/ao/quantization/pt2e/export_utils.py
+++ b/torch/ao/quantization/pt2e/export_utils.py
@@ -55,10 +55,6 @@ def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
     m.graph.eliminate_dead_code()
     m.recompile()
 
-    from torch._export import gm_using_training_ir
-
-    using_training_ir = gm_using_training_ir(m)
-
     for inplace in [False, True]:
 
         def dropout_train(x):
@@ -72,23 +68,19 @@ def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
             match_pattern = _get_aten_graph_module_for_pattern(
                 _WrapperModule(dropout_train),
                 example_inputs,
-                using_training_ir=using_training_ir,
             )
             replacement_pattern = _get_aten_graph_module_for_pattern(
                 _WrapperModule(dropout_eval),
                 example_inputs,
-                using_training_ir=using_training_ir,
             )
         else:
             match_pattern = _get_aten_graph_module_for_pattern(
                 _WrapperModule(dropout_eval),
                 example_inputs,
-                using_training_ir=using_training_ir,
             )
             replacement_pattern = _get_aten_graph_module_for_pattern(
                 _WrapperModule(dropout_train),
                 example_inputs,
-                using_training_ir=using_training_ir,
             )
 
         from torch.fx.subgraph_rewriter import replace_pattern_with_filters
@@ -122,10 +114,6 @@ def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
     m.graph.eliminate_dead_code()
     m.recompile()
 
-    from torch._export import gm_using_training_ir
-
-    using_training_ir = gm_using_training_ir(m)
-
     def bn_train(
         x: torch.Tensor,
         bn_weight: torch.Tensor,
@@ -162,13 +150,11 @@ def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
         _WrapperModule(bn_train),
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
     )
     bn_eval_aten = _get_aten_graph_module_for_pattern(
         _WrapperModule(bn_eval),
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
     )
 
     if train_to_eval:
diff --git a/torch/ao/quantization/pt2e/qat_utils.py b/torch/ao/quantization/pt2e/qat_utils.py
index 7c479550a3e..d09b789b961 100644
--- a/torch/ao/quantization/pt2e/qat_utils.py
+++ b/torch/ao/quantization/pt2e/qat_utils.py
@@ -667,16 +667,11 @@ def _fuse_conv_bn_qat_helper(
     m.graph.eliminate_dead_code()
     m.recompile()
 
-    from torch._export import gm_using_training_ir
-
-    using_training_ir = gm_using_training_ir(m)
-
     conv_bn_pattern = _get_conv_bn_pattern(conv_fn)
     match_pattern = _get_aten_graph_module_for_pattern(
         conv_bn_pattern,
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
     )
 
     # Step (1): Replace patterns with conv bias
@@ -690,7 +685,6 @@ def _fuse_conv_bn_qat_helper(
         qat_conv_bn_pattern,
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
     )
     replacements_with_conv_bias = replace_pattern_with_filters(
         m,
@@ -708,7 +702,6 @@ def _fuse_conv_bn_qat_helper(
         qat_conv_bn_pattern_no_conv_bias,
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
     )
     replacements_no_conv_bias = replace_pattern_with_filters(
         m,
@@ -922,9 +915,6 @@ def _fold_conv_bn_qat_helper(
     """
     Replace the quantized (conv + bn) pattern with conv with bn weights folded
     into the weights of conv.
""" - from torch._export import gm_using_training_ir - - using_training_ir = gm_using_training_ir(m) m.graph.eliminate_dead_code() m.recompile() @@ -958,7 +948,6 @@ def _fold_conv_bn_qat_helper( match_pattern, example_inputs, is_cuda, - using_training_ir=using_training_ir, **kwargs, ) replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern( @@ -968,7 +957,6 @@ def _fold_conv_bn_qat_helper( replacement_pattern, example_inputs, is_cuda, - using_training_ir=using_training_ir, **kwargs, ) replacements.extend( diff --git a/torch/ao/quantization/pt2e/representation/rewrite.py b/torch/ao/quantization/pt2e/representation/rewrite.py index bd35798d239..c76f769ca67 100644 --- a/torch/ao/quantization/pt2e/representation/rewrite.py +++ b/torch/ao/quantization/pt2e/representation/rewrite.py @@ -797,9 +797,6 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule: ] remove_tensor_overload_for_qdq_ops(model) - from torch._export import gm_using_training_ir - - using_training_ir = gm_using_training_ir(model) for rewrite_info in _REWRITE_INFO_LIST: example_inputs = rewrite_info.example_inputs @@ -807,9 +804,9 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule: replacement = rewrite_info.replacement pattern_post_trans = rewrite_info.pattern_post_trans replacement_post_trans = rewrite_info.replacement_post_trans - pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, using_training_ir=using_training_ir) # type: ignore[arg-type, assignment] + pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs) # type: ignore[arg-type, assignment] remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type] - replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs, using_training_ir=using_training_ir) # type: ignore[arg-type, assignment] + replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs) # type: ignore[arg-type, assignment] remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type] if pattern_post_trans: pattern = pattern_post_trans(pattern) diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index d0801142bd1..89a7b6eb7f9 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -351,6 +351,8 @@ def _get_aten_graph_module_for_pattern( [x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs] ) + # T199018392 + # TODO: remove the using_training_ir flag from function if using_training_ir: aten_pattern = torch.export.export_for_training( pattern, # type: ignore[arg-type] diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py index fa86c6f4c30..33c6182986e 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py @@ -530,10 +530,6 @@ def _do_annotate_conv_bn( gm.graph.eliminate_dead_code() gm.recompile() - from torch._export import gm_using_training_ir - - using_training_ir = gm_using_training_ir(gm) - matches = [] if is_conv_transpose: combinations = [ @@ -556,7 +552,7 @@ def _do_annotate_conv_bn( # Match against all conv dimensions and cuda variants for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations: # type: ignore[misc] pattern = get_pattern(conv_fn, relu_is_inplace) # type: ignore[has-type] - pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda, using_training_ir=using_training_ir) # type: 
+        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)  # type: ignore[has-type]
         pattern.graph.eliminate_dead_code()
         pattern.recompile()
         matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
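
Migration sketch (illustrative, not part of the patch): the deprecation warning deleted above pointed callers at torch.export.export_for_training, and the removed training-IR fallback already routed through that API with strict=True. A minimal example of the replacement call, using a made-up toy module `M` and example inputs:

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    example_args = (torch.randn(2, 3),)

    # Before (deprecated API removed by this patch):
    #     gm = torch._export.capture_pre_autograd_graph(M(), example_args)
    # After, mirroring the fallback the old API used internally:
    ep = torch.export.export_for_training(M(), example_args, strict=True)
    gm = ep.module()  # stateful, aten-level GraphModule, like the old return value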