From d84bcb9c8ce2a25b33556dbae96821678eb6f8d9 Mon Sep 17 00:00:00 2001
From: ydwu4
Date: Fri, 6 Oct 2023 14:40:21 -0700
Subject: [PATCH] [HigherOrderOp] expose torch.cond (#110293)

This PR exposes torch._higher_order_ops.cond as torch.cond.

1. Add `# noqa: F811` to the _check* definitions in torch/__init__.py to
   silence a confusing linter error, "Redefinition of unused 'cond'". Only
   one cond is imported, and the flagged lines do not define cond; they
   merely take it as an argument name.
2. Add cond to the list of functions dynamo traces through, so that dynamo
   triggers the CondHigherOrder logic instead of creating a TorchVariable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110293
Approved by: https://github.com/zou3519
---
 .../{control_flow_cond.rst => cond.rst}  |  2 +-
 docs/source/export.rst                   |  4 ++--
 docs/source/torch.rst                    | 12 ++++++++++++
 functorch/experimental/control_flow.py   |  6 ++----
 torch/__init__.py                        | 19 ++++++++++---------
 torch/_dynamo/allowed_functions.py       |  1 +
 torch/_higher_order_ops/__init__.py      |  1 +
 torch/_higher_order_ops/cond.py          | 12 ++++++------
 torch/overrides.py                       |  1 +
 9 files changed, 36 insertions(+), 22 deletions(-)
 rename docs/source/{control_flow_cond.rst => cond.rst} (99%)

diff --git a/docs/source/control_flow_cond.rst b/docs/source/cond.rst
similarity index 99%
rename from docs/source/control_flow_cond.rst
rename to docs/source/cond.rst
index 44031598d20..8f11cb239f0 100644
--- a/docs/source/control_flow_cond.rst
+++ b/docs/source/cond.rst
@@ -1,4 +1,4 @@
-.. _control_flow_cond:
+.. _cond:
 
 Control Flow - Cond
 ====================
diff --git a/docs/source/export.rst b/docs/source/export.rst
index ead18491bb6..723f4a6ca05 100644
--- a/docs/source/export.rst
+++ b/docs/source/export.rst
@@ -501,7 +501,7 @@ Graph breaks can also be encountered on data-dependent control flow (``if
 x.shape[0] > 2``) when shapes are not being specialized, as a tracing compiler cannot
 possibly deal with without generating code for a combinatorially exploding
 number of paths. In such cases, users will need to rewrite their code using
-special control flow operators. Currently, we support :ref:`torch.cond <control_flow_cond>`
+special control flow operators. Currently, we support :ref:`torch.cond <cond>`
 to express if-else like control flow (more coming soon!).
 
 Data-Dependent Accesses
@@ -540,7 +540,7 @@ Read More
    torch.compiler_transformations
    torch.compiler_ir
    generated/exportdb/index
-   control_flow_cond
+   cond
 
 .. toctree::
    :caption: Deep Dive for PyTorch Developers
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 4df084433df..f632b2e7e2b 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -718,6 +718,18 @@ Export Path
     export
     generated/exportdb/index
 
+Control Flow
+------------
+
+.. warning::
+    This feature is a prototype and may have compatibility breaking changes in the future.
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    cond
+
 Optimizations
 -------------
 .. autosummary::
diff --git a/functorch/experimental/control_flow.py b/functorch/experimental/control_flow.py
index ddfdd69a767..cb6ff2e4724 100644
--- a/functorch/experimental/control_flow.py
+++ b/functorch/experimental/control_flow.py
@@ -1,6 +1,4 @@
-from torch._higher_order_ops.cond import (  # noqa: F401
-    cond,
-    UnsupportedAliasMutationException,
-)
+from torch import cond  # noqa: F401
+from torch._higher_order_ops.cond import UnsupportedAliasMutationException  # noqa: F401
 
 from ._map import map  # noqa: F401
diff --git a/torch/__init__.py b/torch/__init__.py
index 74662edccef..e13fa1c2139 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -56,7 +56,7 @@ __all__ = [
     'set_warn_always', 'is_warn_always_enabled',
     'SymInt', 'SymFloat', 'SymBool', 'sym_not',
     'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
-    'export', 'autocast',
+    'export', 'autocast', 'cond',
 ]
 
 ################################################################################
@@ -986,7 +986,7 @@ def is_warn_always_enabled() -> builtins.bool:
 # These error checking functions must be kept consistent with their C++
 # equivalents. Their C++ equivalents are mentioned where applicable.
 
-def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]):
+def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]):  # noqa: F811
     if not isinstance(cond, (builtins.bool, torch.SymBool)):
         raise TypeError(f'cond must be a bool, but got {type(cond)}')
 
@@ -1010,7 +1010,7 @@ def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callab
 
     raise error_type(message_evaluated)
 
-def _check(cond, message=None):
+def _check(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
 
@@ -1041,7 +1041,7 @@ def _check_is_size(i, message=None):
     _check(i >= 0, message)
     torch.fx.experimental.symbolic_shapes._advise_is_size(i)
 
-def _check_index(cond, message=None):
+def _check_index(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
 
@@ -1058,7 +1058,7 @@ def _check_index(cond, message=None):
     """
     _check_with(IndexError, cond, message)
 
-def _check_value(cond, message=None):
+def _check_value(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
 
@@ -1075,7 +1075,7 @@ def _check_value(cond, message=None):
     """
     _check_with(ValueError, cond, message)
 
-def _check_type(cond, message=None):
+def _check_type(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
 
@@ -1092,7 +1092,7 @@ def _check_type(cond, message=None):
     """
     _check_with(TypeError, cond, message)
 
-def _check_not_implemented(cond, message=None):
+def _check_not_implemented(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
@@ -1109,7 +1109,7 @@ def _check_not_implemented(cond, message=None):
     """
     _check_with(NotImplementedError, cond, message)
 
-def _check_tensor_all_with(error_type, cond, message=None):
+def _check_tensor_all_with(error_type, cond, message=None):  # noqa: F811
     if not torch.is_tensor(cond):
         raise TypeError(f'cond must be a tensor, but got {type(cond)}')
 
@@ -1120,7 +1120,7 @@ def _check_tensor_all_with(error_type, cond, message=None):
     _check_with(error_type, cond._is_all_true().item(), message)
 
 # C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
-def _check_tensor_all(cond, message=None):
+def _check_tensor_all(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition
     is False.
 
@@ -1761,6 +1761,7 @@ def compile(model: Optional[Callable] = None, *,
 
 from torch import export as export
 
+from torch._higher_order_ops import cond
 
 def _register_device_module(device_type, module):
     r"""Register an external runtime module of the specific :attr:`device_type`
diff --git a/torch/_dynamo/allowed_functions.py b/torch/_dynamo/allowed_functions.py
index 8beca2b4502..0c1c529bcd2 100644
--- a/torch/_dynamo/allowed_functions.py
+++ b/torch/_dynamo/allowed_functions.py
@@ -215,6 +215,7 @@ def _allowed_function_ids():
             torch.func.vmap,
             deprecated_func.vmap,
             torch.nn.functional.triplet_margin_with_distance_loss,
+            torch.cond,
         ):
             continue
diff --git a/torch/_higher_order_ops/__init__.py b/torch/_higher_order_ops/__init__.py
index e69de29bb2d..2ac132d9db5 100644
--- a/torch/_higher_order_ops/__init__.py
+++ b/torch/_higher_order_ops/__init__.py
@@ -0,0 +1 @@
+from .cond import cond
diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py
index 0a35839500e..c984e3d649a 100644
--- a/torch/_higher_order_ops/cond.py
+++ b/torch/_higher_order_ops/cond.py
@@ -8,8 +8,7 @@ import torch.fx.traceback as fx_traceback
 import torch.utils._pytree as pytree
 
 from torch._C import DispatchKey
-from torch._dynamo.exc import CondOpArgsMismatchError
-from torch._dynamo.utils import disable_cache_limit
+from torch._functorch.utils import exposed_in
 from torch._higher_order_ops.utils import autograd_not_implemented
 from torch._ops import HigherOrderOperator
 
@@ -42,6 +41,7 @@ class UnsupportedAliasMutationException(RuntimeError):
     reason: str
 
 
+@exposed_in("torch")
 def cond(pred, true_fn, false_fn, operands):
     r"""
     Conditionally applies `true_fn` or `false_fn`.
@@ -142,7 +142,7 @@ def cond(pred, true_fn, false_fn, operands):
         raise RuntimeError("torch.cond requires dynamo support.")
 
     with _set_compilation_env():
-        with disable_cache_limit():
+        with torch._dynamo.utils.disable_cache_limit():
             return torch.compile(cond_op, backend="eager", fullgraph=True)(
                 pred, true_fn, false_fn, operands
             )
@@ -198,7 +198,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
     flat_true_outs, _ = pytree.tree_flatten(true_outs)
     flat_false_outs, _ = pytree.tree_flatten(false_outs)
     if len(flat_true_outs) != len(flat_false_outs):
-        raise CondOpArgsMismatchError(
+        raise torch._dynamo.exc.CondOpArgsMismatchError(
             f"Expected to return same number of outputs but got:"
             f"\n  {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
             f"\n  {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
@@ -208,7 +208,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
         true_out = flat_true_outs[i]
         false_out = flat_false_outs[i]
         if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
-            raise CondOpArgsMismatchError(
+            raise torch._dynamo.exc.CondOpArgsMismatchError(
                 f"Expected each tensor to have same metadata but got:"
                 f"\n  {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
                 f"\n  {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
@@ -291,7 +291,7 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
         true_meta = _extract_tensor_metadata(true_out)
         false_meta = _extract_tensor_metadata(false_out)
         if true_meta != false_meta:
-            raise CondOpArgsMismatchError(
+            raise torch._dynamo.exc.CondOpArgsMismatchError(
                 f"Expected each tensor to have same metadata but got:"
                 f"\n  {true_fn.__name__} returns {true_meta}"
                 f"\n  {false_fn.__name__} returns {false_meta}"
diff --git a/torch/overrides.py b/torch/overrides.py
index a771d6021ff..4793793769d 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -297,6 +297,7 @@ def get_ignored_functions() -> Set[Callable]:
         torch.set_vital,
         torch.read_vitals,
         torch.vmap,
+        torch.cond,
         torch.frombuffer,
         torch.asarray,
         torch._functional_sym_constrain_range,
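For context, a minimal sketch of how the newly exposed torch.cond can be
called once this patch lands. The signature cond(pred, true_fn, false_fn,
operands) is taken from torch/_higher_order_ops/cond.py above; the branch
bodies and the input shape are illustrative assumptions, not part of the
patch.

    import torch

    def true_fn(x: torch.Tensor) -> torch.Tensor:
        # Branch taken when the predicate is True.
        return x.cos()

    def false_fn(x: torch.Tensor) -> torch.Tensor:
        # Branch taken when the predicate is False.
        return x.sin()

    x = torch.randn(4)  # illustrative input
    # pred may be a boolean tensor (as here) or a plain bool; operands is a
    # tuple of tensors forwarded to whichever branch is taken.
    out = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

Note that, per the trace_cond and cond_fake_tensor_mode hunks above, both
branches must return the same number of tensors with matching metadata
(shape, dtype); branches that disagree raise CondOpArgsMismatchError at
trace time.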