mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[BE] Make maybe_aliasing_or_mutating proper tag (#131990)
For better tracking, we need to make maybe aliasing/mutating ops with proper tag. We need to special case native_batch_norm because it is not a CIA but has a wrong schema. I guess native_batch_norm will be removed at some point, so until then we just keep it around. D60347117 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131990 Approved by: https://github.com/bdhirsh
This commit is contained in:
parent
c513f01516
commit
11c786dcb5
7 changed files with 59 additions and 32 deletions
|
|
@ -312,25 +312,25 @@
|
|||
- func: _shape_as_tensor(Tensor self) -> Tensor
|
||||
|
||||
- func: dropout(Tensor input, float p, bool train) -> Tensor
|
||||
tags: nondeterministic_seeded
|
||||
tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]
|
||||
|
||||
- func: dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
|
||||
tags: nondeterministic_seeded
|
||||
|
||||
- func: feature_dropout(Tensor input, float p, bool train) -> Tensor
|
||||
tags: nondeterministic_seeded
|
||||
tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]
|
||||
|
||||
- func: feature_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
|
||||
tags: nondeterministic_seeded
|
||||
|
||||
- func: alpha_dropout(Tensor input, float p, bool train) -> Tensor
|
||||
tags: nondeterministic_seeded
|
||||
tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]
|
||||
|
||||
- func: alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
|
||||
tags: nondeterministic_seeded
|
||||
|
||||
- func: feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor
|
||||
tags: nondeterministic_seeded
|
||||
tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]
|
||||
|
||||
- func: feature_alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
|
||||
tags: nondeterministic_seeded
|
||||
|
|
@ -480,7 +480,7 @@
|
|||
|
||||
- func: conj_physical(Tensor self) -> Tensor
|
||||
variants: function, method
|
||||
tags: pointwise
|
||||
tags: [pointwise, maybe_aliasing_or_mutating]
|
||||
|
||||
- func: conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
|
|
@ -1035,17 +1035,20 @@
|
|||
|
||||
- func: atleast_1d(Tensor self) -> Tensor
|
||||
variants: function
|
||||
tags: maybe_aliasing_or_mutating
|
||||
|
||||
- func: atleast_1d.Sequence(Tensor[] tensors) -> Tensor[]
|
||||
|
||||
- func: atleast_2d(Tensor self) -> Tensor
|
||||
variants: function
|
||||
tags: maybe_aliasing_or_mutating
|
||||
|
||||
- func: atleast_2d.Sequence(Tensor[] tensors) -> Tensor[]
|
||||
variants: function
|
||||
|
||||
- func: atleast_3d(Tensor self) -> Tensor
|
||||
variants: function
|
||||
tags: maybe_aliasing_or_mutating
|
||||
|
||||
- func: atleast_3d.Sequence(Tensor[] tensors) -> Tensor[]
|
||||
variants: function
|
||||
|
|
@ -1079,6 +1082,7 @@
|
|||
autogen: bartlett_window.periodic_out
|
||||
|
||||
- func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
|
||||
tags: maybe_aliasing_or_mutating
|
||||
|
||||
- func: quantized_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor
|
||||
dispatch:
|
||||
|
|
@ -1086,6 +1090,7 @@
|
|||
autogen: quantized_batch_norm.out
|
||||
|
||||
- func: _batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)
|
||||
tags: maybe_aliasing_or_mutating
|
||||
|
||||
- func: _batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)
|
||||
|
||||
|
|
@ -1468,6 +1473,7 @@
|
|||
variants: function, method
|
||||
device_check: NoCheck
|
||||
device_guard: False
|
||||
tags: maybe_aliasing_or_mutating
|
||||
|
||||
- func: chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]
|
||||
variants: function, method
|
||||
|
|
@ -7758,6 +7764,7 @@
|
|||
|
||||
- func: cartesian_prod(Tensor[] tensors) -> Tensor
|
||||
variants: function
|
||||
tags: maybe_aliasing_or_mutating
|
||||
|
||||
- func: combinations(Tensor self, int r=2, bool with_replacement=False) -> Tensor
|
||||
variants: function
|
||||
|
|
|
|||
|
|
@ -72,3 +72,9 @@
|
|||
Pointwise operators are operators where each element of the output is computed only by accessing
|
||||
the corresponding element of all the broadcasted inputs. The output shape will be the broadcasted
|
||||
shape of the inputs.
|
||||
- tag: maybe_aliasing_or_mutating
|
||||
desc: |
|
||||
For some ops, we can't statically determine whether the op is functional or not. Note that this is only
|
||||
relevant to CIA ops that decompose before functionalization/autograd. It is useful to
|
||||
know this information for export as we would want to decompose these ops as they are unsafe to be
|
||||
preserved.
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ __all__ = [
|
|||
"register_decomposition",
|
||||
"get_decompositions",
|
||||
"core_aten_decompositions",
|
||||
"_special_op_to_preserve_cia",
|
||||
"_should_decompose_because_unsafe_op",
|
||||
]
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
|
@ -48,6 +48,24 @@ pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
|
|||
meta_table = global_decomposition_table["meta"]
|
||||
|
||||
|
||||
def _should_decompose_because_unsafe_op(op: torch._ops.OperatorBase) -> bool:
|
||||
"""
|
||||
Returns True if the op must always decompose in export/compile tracing system
|
||||
|
||||
In export, we always decompose certain CIA ops that are tagged with
|
||||
maybe_aliasing_or_mutating because we statically need to know if the op is
|
||||
mutating or not. But these CIA ops could have different behaviour in runtime.
|
||||
|
||||
native_batch_norm is a prim op which has a wrong schema and it needs to be replaced
|
||||
with correct schema. But until then, we will force decompose it via this tag.
|
||||
"""
|
||||
if not isinstance(op, torch._ops.OpOverload):
|
||||
return False
|
||||
if torch.Tag.maybe_aliasing_or_mutating in op.tags:
|
||||
return True
|
||||
return op == torch.ops.aten.native_batch_norm.default
|
||||
|
||||
|
||||
def _add_op_to_registry(registry, op, fn):
|
||||
"""
|
||||
This is an internal API for adding an op to the decomposition table.
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ from torch.fx.experimental.proxy_tensor import make_fx
|
|||
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
|
||||
|
||||
from .wrappers import _wrap_submodules
|
||||
from .utils import _materialize_cpp_cia_ops
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -169,9 +170,23 @@ def capture_pre_autograd_graph(
|
|||
# Do not decompose dropout for exported models, because in eval mode the dropout
|
||||
# op disappears from the graph, which makes it difficult to switch to train mode.
|
||||
# See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
|
||||
|
||||
# We force create native_batch_norm because the below materialization logic
|
||||
# only applies to CIA ops.
|
||||
maybe_aliasing_or_mutating_ops = [torch.ops.aten.native_batch_norm.default]
|
||||
|
||||
_materialize_cpp_cia_ops()
|
||||
|
||||
for op in torch.ops.aten:
|
||||
op_obj = getattr(torch.ops.aten, op)
|
||||
for overload in op_obj.overloads():
|
||||
op_overload = getattr(op_obj, overload)
|
||||
if torch.Tag.maybe_aliasing_or_mutating in op_overload.tags:
|
||||
maybe_aliasing_or_mutating_ops.append(op_overload)
|
||||
|
||||
decomp_table = {
|
||||
op: op.decompose
|
||||
for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
|
||||
for op in maybe_aliasing_or_mutating_ops
|
||||
if op != torch.ops.aten.dropout.default
|
||||
}
|
||||
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
|
||||
|
|
|
|||
|
|
@ -1036,10 +1036,10 @@ def _special_op_to_preserve_cia(*args, **kwargs):
|
|||
# 1. The op should be known statically that it is functional
|
||||
# 2. If it is maybe aliasing, we decompose because we must know if an op
|
||||
# is mutating or aliasing.
|
||||
# TODO (tmanlaibaatar) make this utility function and share it with functional_tensor
|
||||
# decomp part. (https://github.com/pytorch/pytorch/issues/129431)
|
||||
def _check_valid_to_preserve(op_overload: "OperatorBase"):
|
||||
if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops:
|
||||
from torch._decomp import _should_decompose_because_unsafe_op
|
||||
|
||||
if _should_decompose_because_unsafe_op(op_overload):
|
||||
return False
|
||||
if op_overload in FunctionalTensor.metadata_fns:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -91,26 +91,6 @@ class FunctionalTensor(torch.Tensor):
|
|||
torch.ops.prim.device.default, # type: ignore[has-type]
|
||||
]
|
||||
|
||||
# These are ops that claim to be functional, but actually are maybe-mutating/maybe-aliasing
|
||||
# TODO (tmanlaibaatar) make it a tag
|
||||
maybe_aliasing_or_mutating_ops = [
|
||||
torch.ops.aten.dropout.default, # type: ignore[has-type]
|
||||
torch.ops.aten.batch_norm.default, # type: ignore[has-type]
|
||||
torch.ops.aten.native_batch_norm.default, # type: ignore[has-type]
|
||||
torch.ops.aten._batch_norm_impl_index.default, # type: ignore[has-type]
|
||||
torch.ops.aten.cudnn_batch_norm.default, # type: ignore[has-type]
|
||||
torch.ops.aten.miopen_batch_norm.default, # type: ignore[has-type]
|
||||
torch.ops.aten.atleast_1d.default, # type: ignore[has-type]
|
||||
torch.ops.aten.atleast_2d.default, # type: ignore[has-type]
|
||||
torch.ops.aten.atleast_3d.default, # type: ignore[has-type]
|
||||
torch.ops.aten.cartesian_prod.default, # type: ignore[has-type]
|
||||
torch.ops.aten.conj_physical.default, # type: ignore[has-type]
|
||||
torch.ops.aten.alpha_dropout.default, # type: ignore[has-type]
|
||||
torch.ops.aten.feature_dropout.default, # type: ignore[has-type]
|
||||
torch.ops.aten.feature_alpha_dropout.default, # type: ignore[has-type]
|
||||
torch.ops.aten.unsafe_chunk.default, # type: ignore[has-type]
|
||||
]
|
||||
|
||||
# Used by auto_functionalize to determine base of tensors during inference mode.
|
||||
_inference_mode_base: Optional["FunctionalTensor"] = None
|
||||
|
||||
|
|
@ -410,7 +390,9 @@ class FunctionalTensorMode(TorchDispatchMode):
|
|||
return False
|
||||
|
||||
# We unconditionally decompose ops that are maybe aliasing or mutating ops
|
||||
if func in FunctionalTensor.maybe_aliasing_or_mutating_ops:
|
||||
from torch._decomp import _should_decompose_because_unsafe_op
|
||||
|
||||
if _should_decompose_because_unsafe_op(func):
|
||||
return True
|
||||
|
||||
# (1) we unconditionally decompose maybe-aliasing or maybe-mutating ops,
|
||||
|
|
|
|||
|
|
@ -1794,7 +1794,6 @@ def _non_strict_export(
|
|||
)
|
||||
|
||||
|
||||
# TODO (tmanlaibaatar) We need to preserve aten.to here somehow
|
||||
@_log_export_wrapper
|
||||
@_disable_prexisiting_fake_mode
|
||||
def _export_for_training(
|
||||
|
|
|
|||
Loading…
Reference in a new issue