[BE] Make maybe_aliasing_or_mutating proper tag (#131990)

For better tracking, we need to mark maybe-aliasing/mutating ops with a proper tag. We need to special-case native_batch_norm because it is not a CIA op but has an incorrect schema. native_batch_norm will likely be removed at some point; until then, we keep this special case around.

D60347117
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131990
Approved by: https://github.com/bdhirsh
This commit is contained in:
Tugsbayasgalan Manlaibaatar 2024-11-21 12:17:36 -08:00 committed by PyTorch MergeBot
parent c513f01516
commit 11c786dcb5
7 changed files with 59 additions and 32 deletions

View file

@ -312,25 +312,25 @@
- func: _shape_as_tensor(Tensor self) -> Tensor
- func: dropout(Tensor input, float p, bool train) -> Tensor
tags: nondeterministic_seeded
tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]
- func: dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
tags: nondeterministic_seeded
- func: feature_dropout(Tensor input, float p, bool train) -> Tensor
tags: nondeterministic_seeded
tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]
- func: feature_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
tags: nondeterministic_seeded
- func: alpha_dropout(Tensor input, float p, bool train) -> Tensor
tags: nondeterministic_seeded
tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]
- func: alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
tags: nondeterministic_seeded
- func: feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor
tags: nondeterministic_seeded
tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]
- func: feature_alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
tags: nondeterministic_seeded
@ -480,7 +480,7 @@
- func: conj_physical(Tensor self) -> Tensor
variants: function, method
tags: pointwise
tags: [pointwise, maybe_aliasing_or_mutating]
- func: conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
@ -1035,17 +1035,20 @@
- func: atleast_1d(Tensor self) -> Tensor
variants: function
tags: maybe_aliasing_or_mutating
- func: atleast_1d.Sequence(Tensor[] tensors) -> Tensor[]
- func: atleast_2d(Tensor self) -> Tensor
variants: function
tags: maybe_aliasing_or_mutating
- func: atleast_2d.Sequence(Tensor[] tensors) -> Tensor[]
variants: function
- func: atleast_3d(Tensor self) -> Tensor
variants: function
tags: maybe_aliasing_or_mutating
- func: atleast_3d.Sequence(Tensor[] tensors) -> Tensor[]
variants: function
@ -1079,6 +1082,7 @@
autogen: bartlett_window.periodic_out
- func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
tags: maybe_aliasing_or_mutating
- func: quantized_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor
dispatch:
@ -1086,6 +1090,7 @@
autogen: quantized_batch_norm.out
- func: _batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)
tags: maybe_aliasing_or_mutating
- func: _batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)
@ -1468,6 +1473,7 @@
variants: function, method
device_check: NoCheck
device_guard: False
tags: maybe_aliasing_or_mutating
- func: chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]
variants: function, method
@ -7758,6 +7764,7 @@
- func: cartesian_prod(Tensor[] tensors) -> Tensor
variants: function
tags: maybe_aliasing_or_mutating
- func: combinations(Tensor self, int r=2, bool with_replacement=False) -> Tensor
variants: function

View file

@ -72,3 +72,9 @@
Pointwise operators are operators where each element of the output is computed only by accessing
the corresponding element of all the broadcasted inputs. The output shape will be the broadcasted
shape of the inputs.
- tag: maybe_aliasing_or_mutating
desc: |
For some ops, we can't statically determine whether the op is functional or not. Note that this is only
relevant to CIA ops that decompose before functionalization/autograd. This information is
useful for export, where we want to decompose these ops because they are unsafe to
preserve.

View file

@ -31,7 +31,7 @@ __all__ = [
"register_decomposition",
"get_decompositions",
"core_aten_decompositions",
"_special_op_to_preserve_cia",
"_should_decompose_because_unsafe_op",
]
_T = TypeVar("_T")
@ -48,6 +48,24 @@ pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
meta_table = global_decomposition_table["meta"]
def _should_decompose_because_unsafe_op(op: torch._ops.OperatorBase) -> bool:
"""
Returns True if the op must always decompose in export/compile tracing system
In export, we always decompose certain CIA ops that are tagged with
maybe_aliasing_or_mutating because we statically need to know if the op is
mutating or not. But these CIA ops could have different behaviour in runtime.
native_batch_norm is a prim op which has a wrong schema and it needs to be replaced
with correct schema. But until then, we will force decompose it via this tag.
"""
if not isinstance(op, torch._ops.OpOverload):
return False
if torch.Tag.maybe_aliasing_or_mutating in op.tags:
return True
return op == torch.ops.aten.native_batch_norm.default
def _add_op_to_registry(registry, op, fn):
"""
This is an internal API for adding an op to the decomposition table.

View file

@ -45,6 +45,7 @@ from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from .wrappers import _wrap_submodules
from .utils import _materialize_cpp_cia_ops
log = logging.getLogger(__name__)
@ -169,9 +170,23 @@ def capture_pre_autograd_graph(
# Do not decompose dropout for exported models, because in eval mode the dropout
# op disappears from the graph, which makes it difficult to switch to train mode.
# See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
# We force create native_batch_norm because the below materialization logic
# only applies to CIA ops.
maybe_aliasing_or_mutating_ops = [torch.ops.aten.native_batch_norm.default]
_materialize_cpp_cia_ops()
for op in torch.ops.aten:
op_obj = getattr(torch.ops.aten, op)
for overload in op_obj.overloads():
op_overload = getattr(op_obj, overload)
if torch.Tag.maybe_aliasing_or_mutating in op_overload.tags:
maybe_aliasing_or_mutating_ops.append(op_overload)
decomp_table = {
op: op.decompose
for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
for op in maybe_aliasing_or_mutating_ops
if op != torch.ops.aten.dropout.default
}
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():

View file

@ -1036,10 +1036,10 @@ def _special_op_to_preserve_cia(*args, **kwargs):
# 1. The op should be known statically that it is functional
# 2. If it is maybe aliasing, we decompose because we must know if an op
# is mutating or aliasing.
# TODO (tmanlaibaatar) make this utility function and share it with functional_tensor
# decomp part. (https://github.com/pytorch/pytorch/issues/129431)
def _check_valid_to_preserve(op_overload: "OperatorBase"):
if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops:
from torch._decomp import _should_decompose_because_unsafe_op
if _should_decompose_because_unsafe_op(op_overload):
return False
if op_overload in FunctionalTensor.metadata_fns:
return False

View file

@ -91,26 +91,6 @@ class FunctionalTensor(torch.Tensor):
torch.ops.prim.device.default, # type: ignore[has-type]
]
# These are ops that claim to be functional, but actually are maybe-mutating/maybe-aliasing
# TODO (tmanlaibaatar) make it a tag
maybe_aliasing_or_mutating_ops = [
torch.ops.aten.dropout.default, # type: ignore[has-type]
torch.ops.aten.batch_norm.default, # type: ignore[has-type]
torch.ops.aten.native_batch_norm.default, # type: ignore[has-type]
torch.ops.aten._batch_norm_impl_index.default, # type: ignore[has-type]
torch.ops.aten.cudnn_batch_norm.default, # type: ignore[has-type]
torch.ops.aten.miopen_batch_norm.default, # type: ignore[has-type]
torch.ops.aten.atleast_1d.default, # type: ignore[has-type]
torch.ops.aten.atleast_2d.default, # type: ignore[has-type]
torch.ops.aten.atleast_3d.default, # type: ignore[has-type]
torch.ops.aten.cartesian_prod.default, # type: ignore[has-type]
torch.ops.aten.conj_physical.default, # type: ignore[has-type]
torch.ops.aten.alpha_dropout.default, # type: ignore[has-type]
torch.ops.aten.feature_dropout.default, # type: ignore[has-type]
torch.ops.aten.feature_alpha_dropout.default, # type: ignore[has-type]
torch.ops.aten.unsafe_chunk.default, # type: ignore[has-type]
]
# Used by auto_functionalize to determine base of tensors during inference mode.
_inference_mode_base: Optional["FunctionalTensor"] = None
@ -410,7 +390,9 @@ class FunctionalTensorMode(TorchDispatchMode):
return False
# We unconditionally decompose ops that are maybe aliasing or mutating ops
if func in FunctionalTensor.maybe_aliasing_or_mutating_ops:
from torch._decomp import _should_decompose_because_unsafe_op
if _should_decompose_because_unsafe_op(func):
return True
# (1) we unconditionally decompose maybe-aliasing or maybe-mutating ops,

View file

@ -1794,7 +1794,6 @@ def _non_strict_export(
)
# TODO (tmanlaibaatar) We need to preserve aten.to here somehow
@_log_export_wrapper
@_disable_prexisiting_fake_mode
def _export_for_training(