diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 646b6aa497f..9749dad663f 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -312,25 +312,25 @@
 - func: _shape_as_tensor(Tensor self) -> Tensor

 - func: dropout(Tensor input, float p, bool train) -> Tensor
-  tags: nondeterministic_seeded
+  tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]

 - func: dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
   tags: nondeterministic_seeded

 - func: feature_dropout(Tensor input, float p, bool train) -> Tensor
-  tags: nondeterministic_seeded
+  tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]

 - func: feature_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
   tags: nondeterministic_seeded

 - func: alpha_dropout(Tensor input, float p, bool train) -> Tensor
-  tags: nondeterministic_seeded
+  tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]

 - func: alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
   tags: nondeterministic_seeded

 - func: feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor
-  tags: nondeterministic_seeded
+  tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]

 - func: feature_alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
   tags: nondeterministic_seeded
@@ -480,7 +480,7 @@

 - func: conj_physical(Tensor self) -> Tensor
   variants: function, method
-  tags: pointwise
+  tags: [pointwise, maybe_aliasing_or_mutating]

 - func: conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
@@ -1035,17 +1035,20 @@

 - func: atleast_1d(Tensor self) -> Tensor
   variants: function
+  tags: maybe_aliasing_or_mutating

 - func: atleast_1d.Sequence(Tensor[] tensors) -> Tensor[]

 - func: atleast_2d(Tensor self) -> Tensor
   variants: function
+  tags: maybe_aliasing_or_mutating

 - func: atleast_2d.Sequence(Tensor[] tensors) -> Tensor[]
   variants: function

 - func: atleast_3d(Tensor self) -> Tensor
   variants: function
+  tags: maybe_aliasing_or_mutating

 - func: atleast_3d.Sequence(Tensor[] tensors) -> Tensor[]
   variants: function
@@ -1079,6 +1082,7 @@
   autogen: bartlett_window.periodic_out

 - func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
+  tags: maybe_aliasing_or_mutating

 - func: quantized_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor
   dispatch:
@@ -1086,6 +1090,7 @@
   autogen: quantized_batch_norm.out

 - func: _batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)
+  tags: maybe_aliasing_or_mutating

 - func: _batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)

@@ -1468,6 +1473,7 @@
   variants: function, method
   device_check: NoCheck
   device_guard: False
+  tags: maybe_aliasing_or_mutating

 - func: chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]
   variants: function, method
@@ -7758,6 +7764,7 @@

 - func: cartesian_prod(Tensor[] tensors) -> Tensor
   variants: function
+  tags: maybe_aliasing_or_mutating

 - func: combinations(Tensor self, int r=2, bool with_replacement=False) -> Tensor
   variants: function
diff --git a/aten/src/ATen/native/tags.yaml b/aten/src/ATen/native/tags.yaml
index 3544a3cf0b1..62f5367670b 100644
--- a/aten/src/ATen/native/tags.yaml
+++ b/aten/src/ATen/native/tags.yaml
@@ -72,3 +72,9 @@
     Pointwise operators are operators where each element of the output is computed
     only by accessing the corresponding element of all the broadcasted inputs.
     The output shape will be the broadcasted shape of the inputs.
+- tag: maybe_aliasing_or_mutating
+  desc: |
+    For some ops, we cannot statically determine whether the op is functional or not. Note that
+    this is only relevant to CIA ops that decompose before functionalization/autograd. Export
+    needs this information because such ops are unsafe to preserve and should instead be
+    decomposed.
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index 402e07a821f..b0cfe252788 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -31,7 +31,7 @@ __all__ = [
     "register_decomposition",
     "get_decompositions",
     "core_aten_decompositions",
-    "_special_op_to_preserve_cia",
+    "_should_decompose_because_unsafe_op",
 ]

 _T = TypeVar("_T")
@@ -48,6 +48,24 @@ pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
 meta_table = global_decomposition_table["meta"]


+def _should_decompose_because_unsafe_op(op: torch._ops.OperatorBase) -> bool:
+    """
+    Returns True if the op must always be decomposed in the export/compile tracing systems.
+
+    In export, we always decompose CIA ops tagged maybe_aliasing_or_mutating because
+    we need to know statically whether an op mutates or aliases its inputs, and these
+    CIA ops can behave differently at runtime.
+
+    native_batch_norm is a prim op with an incorrect schema; it needs to be replaced
+    with a corrected schema. Until then, we force it to decompose via this check.
+    """
+    if not isinstance(op, torch._ops.OpOverload):
+        return False
+    if torch.Tag.maybe_aliasing_or_mutating in op.tags:
+        return True
+    return op == torch.ops.aten.native_batch_norm.default
+
+
 def _add_op_to_registry(registry, op, fn):
     """
     This is an internal API for adding an op to the decomposition table.
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index 8e67cbb7713..c34d4431a07 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -45,6 +45,7 @@ from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo

 from .wrappers import _wrap_submodules
+from .utils import _materialize_cpp_cia_ops

 log = logging.getLogger(__name__)

@@ -169,9 +170,23 @@ def capture_pre_autograd_graph(
     # Do not decompose dropout for exported models, because in eval mode the dropout
     # op disappears from the graph, which makes it difficult to switch to train mode.
     # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
+
+    # We add native_batch_norm manually because the materialization logic below
+    # only applies to CIA ops.
+    maybe_aliasing_or_mutating_ops = [torch.ops.aten.native_batch_norm.default]
+
+    _materialize_cpp_cia_ops()
+
+    for op in torch.ops.aten:
+        op_obj = getattr(torch.ops.aten, op)
+        for overload in op_obj.overloads():
+            op_overload = getattr(op_obj, overload)
+            if torch.Tag.maybe_aliasing_or_mutating in op_overload.tags:
+                maybe_aliasing_or_mutating_ops.append(op_overload)
+
     decomp_table = {
         op: op.decompose
-        for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
+        for op in maybe_aliasing_or_mutating_ops
         if op != torch.ops.aten.dropout.default
     }
     with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
diff --git a/torch/_export/utils.py b/torch/_export/utils.py
index 9fc616ff1ff..132e25286c6 100644
--- a/torch/_export/utils.py
+++ b/torch/_export/utils.py
@@ -1036,10 +1036,10 @@ def _special_op_to_preserve_cia(*args, **kwargs):
 # 1. The op should be known statically that it is functional
 # 2. If it is maybe aliasing, we decompose because we must know if an op
 #    is mutating or aliasing.
-# TODO (tmanlaibaatar) make this utility function and share it with functional_tensor
-# decomp part. (https://github.com/pytorch/pytorch/issues/129431)
 def _check_valid_to_preserve(op_overload: "OperatorBase"):
-    if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops:
+    from torch._decomp import _should_decompose_because_unsafe_op
+
+    if _should_decompose_because_unsafe_op(op_overload):
         return False
     if op_overload in FunctionalTensor.metadata_fns:
         return False
diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py
index 7a3cab1b095..25eab8acf4c 100644
--- a/torch/_subclasses/functional_tensor.py
+++ b/torch/_subclasses/functional_tensor.py
@@ -91,26 +91,6 @@ class FunctionalTensor(torch.Tensor):
         torch.ops.prim.device.default,  # type: ignore[has-type]
     ]

-    # These are ops that claim to be functional, but actually are maybe-mutating/maybe-aliasing
-    # TODO (tmanlaibaatar) make it a tag
-    maybe_aliasing_or_mutating_ops = [
-        torch.ops.aten.dropout.default,  # type: ignore[has-type]
-        torch.ops.aten.batch_norm.default,  # type: ignore[has-type]
-        torch.ops.aten.native_batch_norm.default,  # type: ignore[has-type]
-        torch.ops.aten._batch_norm_impl_index.default,  # type: ignore[has-type]
-        torch.ops.aten.cudnn_batch_norm.default,  # type: ignore[has-type]
-        torch.ops.aten.miopen_batch_norm.default,  # type: ignore[has-type]
-        torch.ops.aten.atleast_1d.default,  # type: ignore[has-type]
-        torch.ops.aten.atleast_2d.default,  # type: ignore[has-type]
-        torch.ops.aten.atleast_3d.default,  # type: ignore[has-type]
-        torch.ops.aten.cartesian_prod.default,  # type: ignore[has-type]
-        torch.ops.aten.conj_physical.default,  # type: ignore[has-type]
-        torch.ops.aten.alpha_dropout.default,  # type: ignore[has-type]
-        torch.ops.aten.feature_dropout.default,  # type: ignore[has-type]
-        torch.ops.aten.feature_alpha_dropout.default,  # type: ignore[has-type]
-        torch.ops.aten.unsafe_chunk.default,  # type: ignore[has-type]
-    ]
-
     # Used by auto_functionalize to determine base of tensors during inference mode.
     _inference_mode_base: Optional["FunctionalTensor"] = None

@@ -410,7 +390,9 @@ class FunctionalTensorMode(TorchDispatchMode):
                 return False

             # We unconditionally decompose ops that are maybe aliasing or mutating ops
-            if func in FunctionalTensor.maybe_aliasing_or_mutating_ops:
+            from torch._decomp import _should_decompose_because_unsafe_op
+
+            if _should_decompose_because_unsafe_op(func):
                 return True

             # (1) we unconditionally decompose maybe-aliasing or maybe-mutating ops,
diff --git a/torch/export/_trace.py b/torch/export/_trace.py
index dbedefe3b07..4100bf9016f 100644
--- a/torch/export/_trace.py
+++ b/torch/export/_trace.py
@@ -1794,7 +1794,6 @@ def _non_strict_export(
     )


-# TODO (tmanlaibaatar) We need to preserve aten.to here somehow
 @_log_export_wrapper
 @_disable_prexisiting_fake_mode
 def _export_for_training(
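
Illustrative sketch (not part of the patch): the snippet below shows how the new tag and helper could be exercised from Python once this change lands. It simply mirrors the discovery loop added to torch/_export/__init__.py above; the function name collect_maybe_aliasing_or_mutating_ops is made up for this example.

import torch
from torch._export.utils import _materialize_cpp_cia_ops


def collect_maybe_aliasing_or_mutating_ops():
    # native_batch_norm does not carry the tag (it is special-cased in
    # _should_decompose_because_unsafe_op), so the patch seeds the list with it by hand.
    ops = [torch.ops.aten.native_batch_norm.default]
    # Force lazily-registered C++ CompositeImplicitAutograd ops to be created so that
    # iterating over torch.ops.aten actually sees them.
    _materialize_cpp_cia_ops()
    for name in torch.ops.aten:
        packet = getattr(torch.ops.aten, name)
        for overload_name in packet.overloads():
            overload = getattr(packet, overload_name)
            if torch.Tag.maybe_aliasing_or_mutating in overload.tags:
                ops.append(overload)
    return ops


# With this patch, e.g. aten.dropout carries the tag declared in native_functions.yaml,
# and the new helper reports that it must be decomposed during tracing:
from torch._decomp import _should_decompose_because_unsafe_op

assert torch.Tag.maybe_aliasing_or_mutating in torch.ops.aten.dropout.default.tags
assert _should_decompose_because_unsafe_op(torch.ops.aten.dropout.default)

The upshot of the tag-based discovery is that a newly added maybe-aliasing/maybe-mutating CIA op now only needs the maybe_aliasing_or_mutating tag in native_functions.yaml, rather than another entry in the removed FunctionalTensor.maybe_aliasing_or_mutating_ops list.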