diff --git a/docs/source/control_flow_cond.rst b/docs/source/cond.rst similarity index 99% rename from docs/source/control_flow_cond.rst rename to docs/source/cond.rst index 44031598d20..8f11cb239f0 100644 --- a/docs/source/control_flow_cond.rst +++ b/docs/source/cond.rst @@ -1,4 +1,4 @@ -.. _control_flow_cond: +.. _cond: Control Flow - Cond ==================== diff --git a/docs/source/export.rst b/docs/source/export.rst index ead18491bb6..723f4a6ca05 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -501,7 +501,7 @@ Graph breaks can also be encountered on data-dependent control flow (``if x.shape[0] > 2``) when shapes are not being specialized, as a tracing compiler cannot possibly deal with without generating code for a combinatorially exploding number of paths. In such cases, users will need to rewrite their code using -special control flow operators. Currently, we support :ref:`torch.cond ` +special control flow operators. Currently, we support :ref:`torch.cond ` to express if-else like control flow (more coming soon!). Data-Dependent Accesses @@ -540,7 +540,7 @@ Read More torch.compiler_transformations torch.compiler_ir generated/exportdb/index - control_flow_cond + cond .. toctree:: :caption: Deep Dive for PyTorch Developers diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 4df084433df..f632b2e7e2b 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -718,6 +718,18 @@ Export Path export generated/exportdb/index +Control Flow +------------ + +.. warning:: + This feature is a prototype and may have compatibility breaking changes in the future. + +.. autosummary:: + :toctree: generated + :nosignatures: + + cond + Optimizations ------------- .. autosummary:: diff --git a/functorch/experimental/control_flow.py b/functorch/experimental/control_flow.py index ddfdd69a767..cb6ff2e4724 100644 --- a/functorch/experimental/control_flow.py +++ b/functorch/experimental/control_flow.py @@ -1,6 +1,4 @@ -from torch._higher_order_ops.cond import ( # noqa: F401 - cond, - UnsupportedAliasMutationException, -) +from torch import cond # noqa: F401 +from torch._higher_order_ops.cond import UnsupportedAliasMutationException # noqa: F401 from ._map import map # noqa: F401 diff --git a/torch/__init__.py b/torch/__init__.py index 74662edccef..e13fa1c2139 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -56,7 +56,7 @@ __all__ = [ 'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat', 'SymBool', 'sym_not', 'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap', - 'export', 'autocast', + 'export', 'autocast', 'cond', ] ################################################################################ @@ -986,7 +986,7 @@ def is_warn_always_enabled() -> builtins.bool: # These error checking functions must be kept consistent with their C++ # equivalents. Their C++ equivalents are mentioned where applicable. -def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]): +def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]): # noqa: F811 if not isinstance(cond, (builtins.bool, torch.SymBool)): raise TypeError(f'cond must be a bool, but got {type(cond)}') @@ -1010,7 +1010,7 @@ def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callab raise error_type(message_evaluated) -def _check(cond, message=None): +def _check(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition is False. @@ -1041,7 +1041,7 @@ def _check_is_size(i, message=None): _check(i >= 0, message) torch.fx.experimental.symbolic_shapes._advise_is_size(i) -def _check_index(cond, message=None): +def _check_index(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition is False. @@ -1058,7 +1058,7 @@ def _check_index(cond, message=None): """ _check_with(IndexError, cond, message) -def _check_value(cond, message=None): +def _check_value(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition is False. @@ -1075,7 +1075,7 @@ def _check_value(cond, message=None): """ _check_with(ValueError, cond, message) -def _check_type(cond, message=None): +def _check_type(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition is False. @@ -1092,7 +1092,7 @@ def _check_type(cond, message=None): """ _check_with(TypeError, cond, message) -def _check_not_implemented(cond, message=None): +def _check_not_implemented(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition is False. @@ -1109,7 +1109,7 @@ def _check_not_implemented(cond, message=None): """ _check_with(NotImplementedError, cond, message) -def _check_tensor_all_with(error_type, cond, message=None): +def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811 if not torch.is_tensor(cond): raise TypeError(f'cond must be a tensor, but got {type(cond)}') @@ -1120,7 +1120,7 @@ def _check_tensor_all_with(error_type, cond, message=None): _check_with(error_type, cond._is_all_true().item(), message) # C++ equivalent: `TORCH_CHECK_TENSOR_ALL` -def _check_tensor_all(cond, message=None): +def _check_tensor_all(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition is False. @@ -1761,6 +1761,7 @@ def compile(model: Optional[Callable] = None, *, from torch import export as export +from torch._higher_order_ops import cond def _register_device_module(device_type, module): r"""Register an external runtime module of the specific :attr:`device_type` diff --git a/torch/_dynamo/allowed_functions.py b/torch/_dynamo/allowed_functions.py index 8beca2b4502..0c1c529bcd2 100644 --- a/torch/_dynamo/allowed_functions.py +++ b/torch/_dynamo/allowed_functions.py @@ -215,6 +215,7 @@ def _allowed_function_ids(): torch.func.vmap, deprecated_func.vmap, torch.nn.functional.triplet_margin_with_distance_loss, + torch.cond, ): continue diff --git a/torch/_higher_order_ops/__init__.py b/torch/_higher_order_ops/__init__.py index e69de29bb2d..2ac132d9db5 100644 --- a/torch/_higher_order_ops/__init__.py +++ b/torch/_higher_order_ops/__init__.py @@ -0,0 +1 @@ +from .cond import cond diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 0a35839500e..c984e3d649a 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -8,8 +8,7 @@ import torch.fx.traceback as fx_traceback import torch.utils._pytree as pytree from torch._C import DispatchKey -from torch._dynamo.exc import CondOpArgsMismatchError -from torch._dynamo.utils import disable_cache_limit +from torch._functorch.utils import exposed_in from torch._higher_order_ops.utils import autograd_not_implemented from torch._ops import HigherOrderOperator @@ -42,6 +41,7 @@ class UnsupportedAliasMutationException(RuntimeError): reason: str +@exposed_in("torch") def cond(pred, true_fn, false_fn, operands): r""" Conditionally applies `true_fn` or `false_fn`. @@ -142,7 +142,7 @@ def cond(pred, true_fn, false_fn, operands): raise RuntimeError("torch.cond requires dynamo support.") with _set_compilation_env(): - with disable_cache_limit(): + with torch._dynamo.utils.disable_cache_limit(): return torch.compile(cond_op, backend="eager", fullgraph=True)( pred, true_fn, false_fn, operands ) @@ -198,7 +198,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): flat_true_outs, _ = pytree.tree_flatten(true_outs) flat_false_outs, _ = pytree.tree_flatten(false_outs) if len(flat_true_outs) != len(flat_false_outs): - raise CondOpArgsMismatchError( + raise torch._dynamo.exc.CondOpArgsMismatchError( f"Expected to return same number of outputs but got:" f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)" f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)" @@ -208,7 +208,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): true_out = flat_true_outs[i] false_out = flat_false_outs[i] if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]: - raise CondOpArgsMismatchError( + raise torch._dynamo.exc.CondOpArgsMismatchError( f"Expected each tensor to have same metadata but got:" f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}" f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}" @@ -291,7 +291,7 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands): true_meta = _extract_tensor_metadata(true_out) false_meta = _extract_tensor_metadata(false_out) if true_meta != false_meta: - raise CondOpArgsMismatchError( + raise torch._dynamo.exc.CondOpArgsMismatchError( f"Expected each tensor to have same metadata but got:" f"\n {true_fn.__name__} returns {true_meta}" f"\n {false_fn.__name__} returns {false_meta}" diff --git a/torch/overrides.py b/torch/overrides.py index a771d6021ff..4793793769d 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -297,6 +297,7 @@ def get_ignored_functions() -> Set[Callable]: torch.set_vital, torch.read_vitals, torch.vmap, + torch.cond, torch.frombuffer, torch.asarray, torch._functional_sym_constrain_range,