From 6db3853eebb3b2d2e4c68adcb0d04d3259da5910 Mon Sep 17 00:00:00 2001
From: ydwu4
Date: Tue, 3 Oct 2023 09:31:26 -0700
Subject: [PATCH] Add doc for torch.cond (#108691)

We add a doc for torch.cond. This PR is a replacement of
https://github.com/pytorch/pytorch/pull/107977.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108691
Approved by: https://github.com/zou3519
---
 docs/source/control_flow_cond.rst | 176 ++++++++++++++++++++++++++++++
 docs/source/export.rst            |   4 +-
 torch/_higher_order_ops/cond.py   |  45 ++++----
 3 files changed, 200 insertions(+), 25 deletions(-)
 create mode 100644 docs/source/control_flow_cond.rst

diff --git a/docs/source/control_flow_cond.rst b/docs/source/control_flow_cond.rst
new file mode 100644
index 00000000000..44031598d20
--- /dev/null
+++ b/docs/source/control_flow_cond.rst
@@ -0,0 +1,176 @@
.. _control_flow_cond:

Control Flow - Cond
====================

`torch.cond` is a structured control flow operator. It can be used to specify if-else-like control flow,
and can logically be seen as being implemented as follows:

.. code-block:: python

    def cond(
        pred: Union[bool, torch.Tensor],
        true_fn: Callable,
        false_fn: Callable,
        operands: Tuple[torch.Tensor]
    ):
        if pred:
            return true_fn(*operands)
        else:
            return false_fn(*operands)

Its unique power lies in its ability to express **data-dependent control flow**: it lowers to a conditional
operator (`torch.ops.higher_order.cond`), which preserves the predicate, the true function, and the false function.
This unlocks great flexibility in writing and deploying models whose architecture changes based on
the **value** or **shape** of inputs or of intermediate outputs of tensor operations.

.. warning::
    `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
    does not currently support training. Please look forward to a more stable implementation in a future version of PyTorch.
    Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

Examples
~~~~~~~~

Below is an example that uses `cond` to branch based on input shape:

.. code-block:: python

    import torch

    def true_fn(x: torch.Tensor):
        return x.cos() + x.sin()

    def false_fn(x: torch.Tensor):
        return x.sin()

    class DynamicShapeCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on a dynamic shape predicate.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

    dyn_shape_mod = DynamicShapeCondPredicate()

We can run the model eagerly and expect the results to vary based on input shape:

.. code-block:: python

    inp = torch.randn(3)
    inp2 = torch.randn(5)
    assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
    assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))

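Since `cond` is designed to be capturable by `torch.compile`, the same module can also be compiled.
Below is a minimal sketch of this usage (it is not part of the original example, and `compiled_mod`
is an illustrative name); `torch.allclose` is used instead of `torch.equal` because compiled kernels
may differ from eager mode by floating-point rounding:

.. code-block:: python

    # fullgraph=True asks the compiler to capture the model as a single graph;
    # torch.cond is what makes the shape-dependent branch capturable.
    compiled_mod = torch.compile(dyn_shape_mod, fullgraph=True)

    # Reusing inp and inp2 from above; results should match eager mode up to
    # floating-point rounding.
    assert torch.allclose(compiled_mod(inp), false_fn(inp))
    assert torch.allclose(compiled_mod(inp2), true_fn(inp2))
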
We can export the model for further transformations and deployment:

.. code-block:: python

    inp = torch.randn(4, 3)
    dim_batch = torch.export.Dim("batch", min=2)
    ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
    print(ep)

This gives us an exported program as shown below:

.. code-block::

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
            gt: Sym(s0 > 4) = sym_size > 4; sym_size = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
                return sin

Notice that `torch.cond` is lowered to `torch.ops.higher_order.cond`, its predicate becomes a symbolic
expression over the shape of the input, and the branch functions become two sub-graph attributes of the
top-level graph module.

Here is another example that showcases how to express data-dependent control flow:

.. code-block:: python

    class DataDependentCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on a data-dependent predicate.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))

After export, we get the following exported program:

.. code-block::

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
            gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0); sum_1 = None

            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
                return sin


Invariants of torch.ops.higher_order.cond
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

There are several useful invariants for `torch.ops.higher_order.cond`:

- For the predicate:
    - Dynamism of the predicate is preserved (e.g. `gt` shown in the example above).
    - If the predicate in the user program is constant (e.g. a Python bool constant), the `pred` of the operator will be a constant.

- For the branches:
    - The input and output signatures are flattened tuples.
    - They are `torch.fx.GraphModule` instances.
    - Closures in the original functions become explicit inputs, so the branch functions themselves have no closures.
    - No mutations on inputs or globals are allowed.

- For the operands:
    - They are also a flat tuple of tensors.

- Nesting of `torch.cond` in the user program becomes nested graph modules, as sketched below.

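To make the nesting invariant concrete, here is a minimal sketch (the `NestedCond` module and its helper
functions are illustrative names, not part of the original doc). The inner `torch.cond` in the true branch
is expected to show up as a `torch.ops.higher_order.cond` inside that branch's sub-graph after export:

.. code-block:: python

    def inner_true(x: torch.Tensor):
        return x.cos()

    def inner_false(x: torch.Tensor):
        return x.sin()

    def outer_true(x: torch.Tensor):
        # A nested data-dependent cond: after export this becomes a nested
        # graph module holding its own torch.ops.higher_order.cond.
        return torch.cond(x.sum() > 0, inner_true, inner_false, (x,))

    def outer_false(x: torch.Tensor):
        return x * 2.0

    class NestedCond(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.cond(x.shape[0] > 4, outer_true, outer_false, (x,))

    dim_batch = torch.export.Dim("batch", min=2)
    ep = torch.export.export(NestedCond(), (torch.randn(5, 3),), dynamic_shapes={"x": {0: dim_batch}})
    print(ep)  # the true branch's sub-graph contains its own higher_order.cond
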
API Reference
-------------

.. autofunction:: torch._higher_order_ops.cond.cond

diff --git a/docs/source/export.rst b/docs/source/export.rst
index 9fce22a5c09..a75a9f512a5 100644
--- a/docs/source/export.rst
+++ b/docs/source/export.rst
@@ -501,7 +501,8 @@
 Graph breaks can also be encountered on data-dependent control flow
 (``if x.shape[0] > 2``) when shapes are not being specialized, as a tracing
 compiler cannot possibly deal with this without generating code for a
 combinatorially exploding number of paths. In such cases, users will need to
 rewrite their code using
-special control flow operators (coming soon!).
+special control flow operators. Currently, we support :ref:`torch.cond <control_flow_cond>`
+to express if-else-like control flow (more coming soon!).
 
 Data-Dependent Accesses
 ^^^^^^^^^^^^^^^^^^^^^^^
@@ -538,6 +539,7 @@ Read More
    torch.compiler_transformations
    torch.compiler_ir
    generated/exportdb/index
+   control_flow_cond
 
 .. toctree::
    :caption: Deep Dive for PyTorch Developers

diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py
index 20313dac191..0e7a5c43ced 100644
--- a/torch/_higher_order_ops/cond.py
+++ b/torch/_higher_order_ops/cond.py
@@ -44,13 +44,18 @@ class UnsupportedAliasMutationException(RuntimeError):
 def cond(pred, true_fn, false_fn, operands):
     r"""
-    Conditionally applies ``true_fn`` or ``false_fn``.
+    Conditionally applies `true_fn` or `false_fn`.
 
-    ``cond`` is structured control flow operator. That is, it is like a Python if-statement,
-    but has limitations on ``true_fn``, ``false_fn``, and ``operands`` that enable it to be
+    .. warning::
+        `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
+        does not currently support training. Please look forward to a more stable implementation in a future version of PyTorch.
+        Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
+
+    `cond` is a structured control flow operator. That is, it is like a Python if-statement,
+    but has restrictions on `true_fn`, `false_fn`, and `operands` that enable it to be
     capturable using torch.compile and torch.export.
 
-    Assuming the constraints on ``cond``'s arguments are met, ``cond`` is equivalent to the following::
+    Assuming the constraints on `cond`'s arguments are met, `cond` is equivalent to the following::
 
         def cond(pred, true_branch, false_branch, operands):
             if pred:
@@ -58,29 +63,21 @@ def cond(pred, true_fn, false_fn, operands):
             else:
                 return false_branch(*operands)
 
-    .. warning::
-        cond is a prototype feature in PyTorch, included as a part of the torch.export release. The main limitations are that
-        it may not work in eager-mode PyTorch and you may encounter various failure modes while using it.
-        Please look forward to a more stable implementation in a future version of PyTorch.
-
-        Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
-
     Args:
-        - `pred (Union[bool, torch.Tensor])`: A boolean expression or a tensor with one element,
+        pred (Union[bool, torch.Tensor]): A boolean expression or a tensor with one element,
           indicating which branch function to apply.
 
-        - `true_fn (Callable)`: A callable function (a -> b) that is within the
+        true_fn (Callable): A callable function (a -> b) that is within the
          scope that is being traced.
 
-        - `false_fn (Callable)`: A callable function (a -> b) that is within the
-          scope that is being traced. The true branch and false branch must have
-          consistent input and outputs, meaning the inputs have to be the same, and
-          the outputs have to be the same type and shape.
+        false_fn (Callable): A callable function (a -> b) that is within the
+          scope that is being traced. The true branch and false branch must
+          have consistent inputs and outputs, meaning the inputs have to be
+          the same, and the outputs have to be of the same type and shape.
 
-        - `operands (Tuple[torch.Tensor])`: A tuple of inputs to the true/false
-          branches.
+        operands (Tuple[torch.Tensor]): A tuple of inputs to the true/false functions.
 
-    Example:
+    Example::
 
         def true_fn(x: torch.Tensor):
             return x.cos()
@@ -102,12 +99,12 @@ def cond(pred, true_fn, false_fn, operands):
         - The function must return a tensor with the same metadata, e.g. shape,
           dtype, etc.
 
-        - The function cannot have in-place mutations on inputs or global variables. (Note: in-place tensor
-          operations such as `add_` for intermediate results are allowed in a branch)
+        - The function cannot have in-place mutations on inputs or global variables.
+          (Note: in-place tensor operations such as `add_` for intermediate results
+          are allowed in a branch)
 
     .. warning::
-
-        Temporal Limitations:
+        Temporal Limitations:
 
         - `cond` only supports **inference** right now. Autograd will be supported in the future.