mirror of https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add doc for torch.cond (#108691)
We add a doc for torch.cond. This PR is a replacement of https://github.com/pytorch/pytorch/pull/107977.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108691
Approved by: https://github.com/zou3519
This commit is contained in:
parent 901aa85b58
commit 6db3853eeb
3 changed files with 200 additions and 25 deletions
176 docs/source/control_flow_cond.rst Normal file

@@ -0,0 +1,176 @@
.. _control_flow_cond:

Control Flow - Cond
====================

`torch.cond` is a structured control flow operator. It can be used to specify if-else like control flow
and can logically be seen as implemented as follows:

.. code-block:: python

    def cond(
        pred: Union[bool, torch.Tensor],
        true_fn: Callable,
        false_fn: Callable,
        operands: Tuple[torch.Tensor]
    ):
        if pred:
            return true_fn(*operands)
        else:
            return false_fn(*operands)

Its unique power lies in its ability to express **data-dependent control flow**: it lowers to a conditional
operator (`torch.ops.higher_order.cond`), which preserves the predicate, the true function, and the false function.
This unlocks great flexibility in writing and deploying models whose architecture changes based on
the **value** or **shape** of inputs or intermediate outputs of tensor operations.

.. warning::
    `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
    doesn't currently support training. Please look forward to a more stable implementation in a future version of PyTorch.
    Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

Examples
~~~~~~~~

Below is an example that uses cond to branch based on input shape:

.. code-block:: python

    import torch

    def true_fn(x: torch.Tensor):
        return x.cos() + x.sin()

    def false_fn(x: torch.Tensor):
        return x.sin()

    class DynamicShapeCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on a dynamic shape predicate.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Branch on the (possibly symbolic) batch size. The module-level
            # true_fn/false_fn are used so that the eager checks and the
            # exported program below stay consistent.
            return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

    dyn_shape_mod = DynamicShapeCondPredicate()

We can eagerly run the model and expect the results to vary based on input shape:

.. code-block:: python

    inp = torch.randn(3)
    inp2 = torch.randn(5)
    assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
    assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))
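
Since the branches are written to be capturable, the same module should also run under `torch.compile`. A minimal sketch (the `compiled_mod` name is ours, not from the doc):

.. code-block:: python

    # Sketch: torch.cond is captured as a higher-order op by the compiler
    # rather than falling back to a Python-level branch.
    compiled_mod = torch.compile(dyn_shape_mod)
    assert torch.equal(compiled_mod(inp), false_fn(inp))
    assert torch.equal(compiled_mod(inp2), true_fn(inp2))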

We can export the model for further transformations and deployment:

.. code-block:: python

    inp = torch.randn(4, 3)
    dim_batch = torch.export.Dim("batch", min=2)
    ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
    print(ep)

This gives us an exported program as shown below:

.. code-block::

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
            gt: Sym(s0 > 4) = sym_size > 4; sym_size = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
                return sin

Notice that `torch.cond` is lowered to `torch.ops.higher_order.cond`, its predicate becomes a symbolic expression over the shape of the input,
and the branch functions become two sub-graph attributes of the top-level graph module.
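
Because each branch is an ordinary sub-graph, it can be inspected on its own. A hedged sketch, assuming the `true_graph_0` attribute name shown in the printout above:

.. code-block:: python

    # Sketch: print only the true branch of the exported program.
    print(ep.graph_module.true_graph_0.code)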

Here is another example that showcases how to express data-dependent control flow:

.. code-block:: python

    class DataDependentCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on a data-dependent predicate.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # The predicate depends on the *value* of x, not just its shape.
            return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))

The exported program we get after export:

.. code-block::

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
            gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0); sum_1 = None

            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
                return sin

Invariants of torch.ops.higher_order.cond
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

There are several useful invariants for `torch.ops.higher_order.cond`:

- For the predicate:
    - Dynamism of the predicate is preserved (e.g. the `gt` shown in the above example).
    - If the predicate in the user program is constant (e.g. a Python bool constant), the `pred` of the operator will be a constant.

- For the branches:
    - The input and output signatures are flattened tuples.
    - They are `torch.fx.GraphModule` instances.
    - Closures in the original functions become explicit inputs; no closures remain.
    - No mutations on inputs or globals are allowed.

- For the operands:
    - They are also a flat tuple.

- Nesting of `torch.cond` in the user program becomes nested graph modules, as sketched below.
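
To illustrate the nesting invariant, here is a hedged sketch (the module and function names are ours, not from the doc) in which one branch itself calls `torch.cond`; after export, the inner cond shows up as a nested pair of sub-graphs inside the outer true branch:

.. code-block:: python

    import torch

    def inner_true(x: torch.Tensor):
        return x * 2.0

    def inner_false(x: torch.Tensor):
        return x / 2.0

    def outer_true(x: torch.Tensor):
        # This nested cond becomes a nested graph module after export.
        return torch.cond(x.sum() > 0, inner_true, inner_false, (x,))

    def outer_false(x: torch.Tensor):
        return x - 1.0

    class NestedCond(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.cond(x.shape[0] > 4, outer_true, outer_false, (x,))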

API Reference
-------------

.. autofunction:: torch._higher_order_ops.cond.cond

@@ -501,7 +501,8 @@ Graph breaks can also be encountered on data-dependent control flow (``if
 x.shape[0] > 2``) when shapes are not being specialized, which a tracing compiler cannot
 possibly deal with without generating code for a combinatorially exploding
 number of paths. In such cases, users will need to rewrite their code using
-special control flow operators (coming soon!).
+special control flow operators. Currently, we support :ref:`torch.cond <control_flow_cond>`
+to express if-else like control flow (more coming soon!).
 
 Data-Dependent Accesses
 ^^^^^^^^^^^^^^^^^^^^^^^
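
To make the rewrite mentioned above concrete, here is a hedged sketch (the function names are ours, not from the doc) of replacing a data-dependent Python if-statement with :ref:`torch.cond <control_flow_cond>`:

.. code-block:: python

    import torch

    def branch_true(x: torch.Tensor):
        return x.cos()

    def branch_false(x: torch.Tensor):
        return x.sin()

    # Instead of `if x.sum() > 0: ...`, which would graph-break under a
    # tracing compiler, express the branch as a structured operator:
    def f(x: torch.Tensor):
        return torch.cond(x.sum() > 0, branch_true, branch_false, (x,))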

@@ -538,6 +539,7 @@ Read More
    torch.compiler_transformations
    torch.compiler_ir
    generated/exportdb/index
+   control_flow_cond
 
 .. toctree::
    :caption: Deep Dive for PyTorch Developers

@@ -44,13 +44,18 @@ class UnsupportedAliasMutationException(RuntimeError):
 
 def cond(pred, true_fn, false_fn, operands):
     r"""
-    Conditionally applies ``true_fn`` or ``false_fn``.
+    Conditionally applies `true_fn` or `false_fn`.
 
-    ``cond`` is structured control flow operator. That is, it is like a Python if-statement,
-    but has limitations on ``true_fn``, ``false_fn``, and ``operands`` that enable it to be
+    .. warning::
+        `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
+        doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
+        Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
+
+    `cond` is a structured control flow operator. That is, it is like a Python if-statement,
+    but has restrictions on `true_fn`, `false_fn`, and `operands` that enable it to be
     capturable using torch.compile and torch.export.
 
-    Assuming the constraints on ``cond``'s arguments are met, ``cond`` is equivalent to the following::
+    Assuming the constraints on `cond`'s arguments are met, `cond` is equivalent to the following::
 
         def cond(pred, true_branch, false_branch, operands):
             if pred:

@@ -58,29 +63,21 @@ def cond(pred, true_fn, false_fn, operands):
             else:
                 return false_branch(*operands)
 
-    .. warning::
-        cond is a prototype feature in PyTorch, included as a part of the torch.export release. The main limitations are that
-        it may not work in eager-mode PyTorch and you may encounter various failure modes while using it.
-        Please look forward to a more stable implementation in a future version of PyTorch.
-
-        Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
-
     Args:
-        - `pred (Union[bool, torch.Tensor])`: A boolean expression or a tensor with one element,
+        pred (Union[bool, torch.Tensor]): A boolean expression or a tensor with one element,
           indicating which branch function to apply.
 
-        - `true_fn (Callable)`: A callable function (a -> b) that is within the
+        true_fn (Callable): A callable function (a -> b) that is within the
          scope that is being traced.
 
-        - `false_fn (Callable)`: A callable function (a -> b) that is within the
-          scope that is being traced. The true branch and false branch must have
-          consistent input and outputs, meaning the inputs have to be the same, and
-          the outputs have to be the same type and shape.
+        false_fn (Callable): A callable function (a -> b) that is within the
+          scope that is being traced. The true branch and false branch must
+          have consistent input and outputs, meaning the inputs have to be
+          the same, and the outputs have to be the same type and shape.
 
-        - `operands (Tuple[torch.Tensor])`: A tuple of inputs to the true/false
-          branches.
+        operands (Tuple[torch.Tensor]): A tuple of inputs to the true/false functions.
 
-    Example:
+    Example::
 
         def true_fn(x: torch.Tensor):
            return x.cos()

@@ -102,12 +99,12 @@ def cond(pred, true_fn, false_fn, operands):
         - The function must return a tensor with the same metadata, e.g. shape,
           dtype, etc.
 
-        - The function cannot have in-place mutations on inputs or global variables. (Note: in-place tensor
-          operations such as `add_` for intermediate results are allowed in a branch)
+        - The function cannot have in-place mutations on inputs or global variables.
+          (Note: in-place tensor operations such as `add_` for intermediate results
+          are allowed in a branch)
 
     .. warning::
 
-    Temporal Limitations:
+        Temporal Limitations:
 
         - `cond` only supports **inference** right now. Autograd will be supported in the future.
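
A hedged illustration of the in-place-mutation rule above (the helper names `ok_branch`/`bad_branch` are ours, not from the docstring):

.. code-block:: python

    import torch

    def ok_branch(x: torch.Tensor):
        y = x.clone()  # create an intermediate result
        y.add_(1)      # in-place on an intermediate: allowed within a branch
        return y

    def bad_branch(x: torch.Tensor):
        x.add_(1)      # in-place mutation of a branch input: not allowed
        return x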