Add doc for torch.cond (#108691)

We add a doc for torch.cond. This PR is a replacement of https://github.com/pytorch/pytorch/pull/107977.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108691
Approved by: https://github.com/zou3519
ydwu4 2023-10-03 09:31:26 -07:00 committed by PyTorch MergeBot
parent 901aa85b58
commit 6db3853eeb
3 changed files with 200 additions and 25 deletions


@@ -0,0 +1,176 @@
.. _control_flow_cond:

Control Flow - Cond
====================
`torch.cond` is a structured control flow operator. It can be used to specify if-else like control flow
and can logically be seen as implemented as follows.

.. code-block:: python

    def cond(
        pred: Union[bool, torch.Tensor],
        true_fn: Callable,
        false_fn: Callable,
        operands: Tuple[torch.Tensor],
    ):
        if pred:
            return true_fn(*operands)
        else:
            return false_fn(*operands)
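
For instance, here is a minimal sketch of calling `torch.cond` eagerly (the branch functions below are illustrative):

.. code-block:: python

    import torch

    def double_fn(x: torch.Tensor):
        return x * 2

    def inc_fn(x: torch.Tensor):
        return x + 1

    x = torch.ones(3)
    # The predicate here is a one-element boolean tensor; cond picks a branch eagerly.
    out = torch.cond(x.sum() > 2, double_fn, inc_fn, (x,))
    print(out)  # tensor([2., 2., 2.]) since x.sum() == 3 > 2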

Its unique power lies in its ability to express **data-dependent control flow**: it lowers to a conditional
operator (`torch.ops.higher_order.cond`), which preserves the predicate, the true function, and the false function.
This unlocks great flexibility in writing and deploying models whose architecture can change based on
the **value** or **shape** of inputs or intermediate outputs of tensor operations.

.. warning::
    `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
    doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
    Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

Examples
~~~~~~~~

Below is an example that uses `cond` to branch based on input shape:

.. code-block:: python

    import torch

    def true_fn(x: torch.Tensor):
        return x.cos() + x.sin()

    def false_fn(x: torch.Tensor):
        return x.sin()

    class DynamicShapeCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on a dynamic shape predicate.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Branch on the batch size; the top-level true_fn/false_fn are used
            # so the eager checks below compare against the same functions.
            return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

    dyn_shape_mod = DynamicShapeCondPredicate()

We can run the model eagerly and expect the results to vary based on input shape:

.. code-block:: python

    inp = torch.randn(3)
    inp2 = torch.randn(5)
    assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
    assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))
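
`cond` is also designed to be capturable by `torch.compile`; a sketch, assuming prototype support for `cond` in your build:

.. code-block:: python

    compiled_mod = torch.compile(dyn_shape_mod, fullgraph=True, dynamic=True)
    # With dynamic shapes enabled, the shape predicate stays symbolic, so the
    # branch can be captured via torch.ops.higher_order.cond instead of graph-breaking.
    print(compiled_mod(torch.randn(5)))
    print(compiled_mod(torch.randn(3)))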

We can export the model for further transformations and deployment:

.. code-block:: python

    inp = torch.randn(4, 3)
    dim_batch = torch.export.Dim("batch", min=2)
    ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
    print(ep)

This gives us an exported program as shown below:

.. code-block::

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
            gt: Sym(s0 > 4) = sym_size > 4;  sym_size = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin

Notice that `torch.cond` is lowered to `torch.ops.higher_order.cond`, its predicate becomes a symbolic expression over the shape of the input,
and the branch functions become two sub-graph attributes of the top-level graph module.
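
Since the branch functions are attributes of the top-level graph module, they can be inspected individually; a sketch (the attribute names `true_graph_0` and `false_graph_0` come from the printed program above):

.. code-block:: python

    gm = ep.graph_module
    # Each branch is a torch.fx.GraphModule stored as an attribute.
    print(gm.true_graph_0.graph)
    print(gm.false_graph_0.graph)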

Here is another example that showcases how to express data-dependent control flow:

.. code-block:: python

    class DataDependentCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on a data-dependent predicate.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))
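
Eagerly, this module behaves like a plain Python if-statement over the reduced value; a quick sketch:

.. code-block:: python

    mod = DataDependentCondPredicate()
    x = torch.randn(4, 3)
    # The branch taken now depends on the *values* in x rather than its shape.
    expected = true_fn(x) if x.sum() > 4.0 else false_fn(x)
    assert torch.equal(mod(x), expected)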

After export, we get the following program:

.. code-block::

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
            gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0);  sum_1 = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin

Invariants of torch.ops.higher_order.cond
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

There are several useful invariants for `torch.ops.higher_order.cond`:

- For the predicate:
    - The dynamism of the predicate is preserved (e.g. the `gt` shown in the example above).
    - If the predicate in the user program is constant (e.g. a Python bool constant), the `pred` of the operator will be a constant.
- For the branches:
    - The input and output signatures are flattened tuples.
    - They are `torch.fx.GraphModule` instances.
    - Closures in the original functions become explicit inputs; the branch functions themselves have no closures.
    - No mutations on inputs or globals are allowed.
- For the operands:
    - They are also a flat tuple of tensors.
- Nesting of `torch.cond` in the user program becomes nested graph modules (see the sketch after this list).
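
A minimal sketch of the nesting invariant (the function names here are illustrative):

.. code-block:: python

    import torch

    def inner_true(x):
        return x.cos()

    def inner_false(x):
        return x.sin()

    def outer_true(x):
        # This inner cond becomes a nested sub-graph module after export.
        return torch.cond(x.sum() > 0, inner_true, inner_false, (x,))

    def outer_false(x):
        return x * 2

    def f(x):
        return torch.cond(x.shape[0] > 4, outer_true, outer_false, (x,))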

API Reference
-------------

.. autofunction:: torch._higher_order_ops.cond.cond


@@ -501,7 +501,8 @@ Graph breaks can also be encountered on data-dependent control flow (``if
x.shape[0] > 2``) when shapes are not being specialized, as a tracing compiler cannot
possibly deal with it without generating code for a combinatorially exploding
number of paths. In such cases, users will need to rewrite their code using
special control flow operators. Currently, we support :ref:`torch.cond <control_flow_cond>`
to express if-else like control flow (more coming soon!).
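
For example, a value-dependent branch that would otherwise cause a graph break
can be rewritten with ``torch.cond``; a sketch (the functions below are illustrative):

.. code-block:: python

    import torch

    def true_fn(x):
        return x.cos()

    def false_fn(x):
        return x.sin()

    @torch.compile(fullgraph=True)
    def f(x):
        # `if x.sum() > 0:` would graph-break under fullgraph=True;
        # torch.cond expresses the same branch in a capturable way.
        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
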
Data-Dependent Accesses
^^^^^^^^^^^^^^^^^^^^^^^
@@ -538,6 +539,7 @@ Read More
   torch.compiler_transformations
   torch.compiler_ir
   generated/exportdb/index
   control_flow_cond
.. toctree::
   :caption: Deep Dive for PyTorch Developers


@@ -44,13 +44,18 @@ class UnsupportedAliasMutationException(RuntimeError):
def cond(pred, true_fn, false_fn, operands):
r"""
Conditionally applies `true_fn` or `false_fn`.

.. warning::
    `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
    doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
    Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

`cond` is a structured control flow operator. That is, it is like a Python if-statement,
but has restrictions on `true_fn`, `false_fn`, and `operands` that enable it to be
capturable using torch.compile and torch.export.

Assuming the constraints on `cond`'s arguments are met, `cond` is equivalent to the following::

    def cond(pred, true_branch, false_branch, operands):
        if pred:
@@ -58,29 +63,21 @@ def cond(pred, true_fn, false_fn, operands):
            return true_branch(*operands)
        else:
            return false_branch(*operands)

Args:
    pred (Union[bool, torch.Tensor]): A boolean expression or a tensor with one element,
        indicating which branch function to apply.
    true_fn (Callable): A callable function (a -> b) that is within the
        scope that is being traced.
    false_fn (Callable): A callable function (a -> b) that is within the
        scope that is being traced. The true branch and false branch must
        have consistent input and outputs, meaning the inputs have to be
        the same, and the outputs have to be the same type and shape.
    operands (Tuple[torch.Tensor]): A tuple of inputs to the true/false functions.

Example::

    def true_fn(x: torch.Tensor):
        return x.cos()
@@ -102,12 +99,12 @@ def cond(pred, true_fn, false_fn, operands):
- The function must return a tensor with the same metadata, e.g. shape,
dtype, etc.
- The function cannot have in-place mutations on inputs or global variables.
  (Note: in-place tensor operations such as `add_` for intermediate results
  are allowed in a branch)
.. warning::
    Temporal Limitations:
    - `cond` only supports **inference** right now. Autograd will be supported in the future.