fix formatting in programming model doc (#143587)
Test Plan: Some of the formatting in https://docs-preview.pytorch.org/pytorch/pytorch/143546/export.programming_model.html is broken.

Differential Revision: D67458972
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143587
Approved by: https://github.com/yushangdi
parent fe0f20615c
commit 29b586bbad
1 changed file with 43 additions and 43 deletions
@@ -26,13 +26,13 @@ Strict vs. Non-Strict Tracing
 
 In *non-strict mode*, we trace through the program using the normal Python
 interpreter. Your code executes exactly as it would in eager mode; the only
-difference is that **all Tensors are replaced by
+difference is that all Tensors are replaced by
 `fake Tensors <https://pytorch.org/docs/main/torch.compiler_fake_tensor.html>`__,
-which have shapes and other forms of metadata but no data**, wrapped in
+**which have shapes and other forms of metadata but no data**, wrapped in
 `Proxy objects <https://pytorch.org/docs/main/fx.html>`__ that record all
 operations on them into a graph. We also capture
-**`conditions on Tensor shapes <https://pytorch.org/docs/main/torch.compiler_dynamic_shapes.html#the-guard-model>`__
-that guard the correctness of the generated code**.
+`conditions on Tensor shapes <https://pytorch.org/docs/main/torch.compiler_dynamic_shapes.html#the-guard-model>`__
+**that guard the correctness of the generated code**.
 
 In *strict mode*, we first trace through the program using
 :ref:`TorchDynamo <torch.compiler_dynamo_deepdive>`, a Python bytecode
 
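A minimal sketch of selecting the non-strict tracing described above, assuming the current :func:`torch.export.export` API with its ``strict`` flag (the module is illustrative):

.. code-block:: python

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return x.sin() + 1

    # Non-strict: forward runs under the normal Python interpreter, with
    # inputs replaced by fake Tensors wrapped in graph-recording Proxies.
    ep = torch.export.export(M(), (torch.randn(3),), strict=False)
    print(ep.graph)
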
@@ -67,7 +67,7 @@ A *static* value is a value that is **fixed at export time and cannot change
 between executions of the exported program**. When the value is encountered
 during tracing, we treat it as a constant and hard-code it into the graph.
 
-When an operation is performed (e.g. `x + y`) and all inputs are static,
+When an operation is performed (e.g. ``x + y``) and all inputs are static,
 the output of the operation is directly hard-coded into the graph and the
 operation does not show up (i.e. it gets "constant-folded").
 
@@ -93,8 +93,8 @@ been *specialized* to that value. For example:
 
 """
 
-Here, we provide `3` as the traced value for `y`; it is treated as a static
-value and added to `7`, burning in the static value `10` in the graph.
+Here, we provide ``3`` as the traced value for ``y``; it is treated as a static
+value and added to ``7``, burning in the static value ``10`` in the graph.
 
 Dynamic Values
 ^^^^^^^^^^^^^^
 
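A small sketch of the specialization described above; the module is illustrative, with only the ``3``/``7``/``10`` arithmetic taken from the text:

.. code-block:: python

    import torch

    class M(torch.nn.Module):
        def forward(self, x, y: int):
            return x + (y + 7)

    # y is a Python int, hence static: y + 7 evaluates to 10 during tracing,
    # and the graph hard-codes 10 rather than computing it from y.
    ep = torch.export.export(M(), (torch.randn(2), 3))
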
@@ -122,17 +122,17 @@ Whether a value is static or dynamic depends on its type:
 - Tensors that are part of module state, i.e., parameters and buffers,
   always have static shapes.
 
-- Other forms of Tensor *metadata* (e.g. `device`, `dtype`) are static.
+- Other forms of Tensor *metadata* (e.g. ``device``, ``dtype``) are static.
 
-- Python *primitives* (`int`, `float`, `bool`, `str`, `None`) are static.
+- Python *primitives* (``int``, ``float``, ``bool``, ``str``, ``None``) are static.
 
-  - There are dynamic variants for some primitive types (`SymInt`,
-    `SymFloat`, `SymBool`). Typically users do not have to deal with them.
+  - There are dynamic variants for some primitive types (``SymInt``,
+    ``SymFloat``, ``SymBool``). Typically users do not have to deal with them.
 
-- For Python *standard containers* (`list`, `tuple`, `dict`, `namedtuple`):
+- For Python *standard containers* (``list``, ``tuple``, ``dict``, ``namedtuple``):
 
-  - The structure (i.e., length for `list` and `tuple` values, and key
-    sequence for `dict` and `namedtuple` values) is static.
+  - The structure (i.e., length for ``list`` and ``tuple`` values, and key
+    sequence for ``dict`` and ``namedtuple`` values) is static.
 
   - The contained elements have these rules applied to them recursively
     (basically the
 
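For contrast with the static cases listed above, a sketch of marking an input Tensor dimension dynamic via :func:`torch.export.Dim` (the module is illustrative):

.. code-block:: python

    import torch
    from torch.export import Dim, export

    class M(torch.nn.Module):
        def forward(self, x):
            return x * 2

    # The first dimension of x becomes dynamic; everything else stays static.
    ep = export(M(), (torch.randn(4, 8),), dynamic_shapes={"x": {0: Dim("batch")}})
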
@@ -160,9 +160,9 @@ By default, the types of inputs you can use for your program are:
 
 - Tensor
 
-- Python primitives (`int`, `float`, `bool`, `str`, `None`)
+- Python primitives (``int``, ``float``, ``bool``, ``str``, ``None``)
 
-- Python standard containers (`list`, `tuple`, `dict`, `namedtuple`)
+- Python standard containers (``list``, ``tuple``, ``dict``, ``namedtuple``)
 
 Custom Input Types
 ^^^^^^^^^^^^^^^^^^
 
@@ -180,7 +180,7 @@ an input type.
     f: torch.Tensor
     p: torch.Tensor
 
-torch._export.utils.register_dataclass_as_pytree_node(Input)
+torch.export.register_dataclass(Input)
 
 class M(torch.nn.Module):
     def forward(self, x: Input):
 
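A self-contained sketch around the snippet in this hunk, using the new :func:`torch.export.register_dataclass` call; the ``forward`` body is illustrative:

.. code-block:: python

    from dataclasses import dataclass

    import torch

    @dataclass
    class Input:
        f: torch.Tensor
        p: torch.Tensor

    # Register the dataclass as a pytree node so export can flatten its fields.
    torch.export.register_dataclass(Input)

    class M(torch.nn.Module):
        def forward(self, x: Input):
            return x.f + x.p

    ep = torch.export.export(M(), (Input(f=torch.ones(3), p=torch.zeros(3)),))
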
@@ -245,8 +245,8 @@ is also covered by this case.)
 As mentioned above, we "burn in" static values, so the exported graph will
 never see any control flow over static values.
 
-In the case of an `if` statement, we will continue tracing the branch taken
-at export time. In the case of a `for` or `while` statement, we will continue
+In the case of an ``if`` statement, we will continue tracing the branch taken
+at export time. In the case of a ``for`` or ``while`` statement, we will continue
 tracing by unrolling the loop.
 
 Dynamic Control Flow: Shape-Dependent vs. Data-Dependent
 
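A sketch of the loop unrolling described above (illustrative module; ``n`` is a static ``int``):

.. code-block:: python

    import torch

    class M(torch.nn.Module):
        def forward(self, x, n: int):
            for _ in range(n):  # static bound: unrolled at export time
                x = x + 1
            return x

    # Exporting with n=3 yields a graph with three additions and no loop.
    ep = torch.export.export(M(), (torch.randn(2), 3))
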
@@ -262,17 +262,17 @@ Dynamic Shape-Dependent Control Flow
 
 When the value involved in a control flow is a
 `dynamic shape <https://pytorch.org/docs/main/torch.compiler_dynamic_shapes.html>`__,
-**in most cases we will also know the concrete value of the dynamic shape
+in most cases **we will also know the concrete value of the dynamic shape
 during tracing**: see the following section for more details on how the
 compiler tracks this information.
 
 In these cases we say that the control flow is shape-dependent. **We use the
 concrete value of the dynamic shape to evaluate the condition** to either
-`True` or `False` and continue tracing (as discussed above), additionally
+``True`` or ``False`` and continue tracing (as discussed above), additionally
 emitting a guard corresponding to the condition just evaluated.
 
 Otherwise the control flow is considered data-dependent. We cannot evaluate
-the condition to either True or False, so cannot continue tracing and have to
+the condition to either ``True`` or ``False``, so cannot continue tracing and have to
 raise an error at export time. See next section.
 
 Dynamic Data-Dependent Control Flow
 
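A sketch of shape-dependent control flow as described here; the ``min=3`` bound is an assumption that keeps the emitted guard consistent with the declared range of the dynamic dimension:

.. code-block:: python

    import torch
    from torch.export import Dim, export

    class M(torch.nn.Module):
        def forward(self, x):
            if x.shape[0] > 2:  # shape-dependent: evaluated with the traced size
                return x.sin()
            return x.cos()

    # Traced with size 4, the True branch is taken and a guard is emitted.
    ep = export(M(), (torch.randn(4),), dynamic_shapes={"x": {0: Dim("batch", min=3)}})
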
@@ -288,8 +288,8 @@ We provide **operators to express general conditionals and loops over dynamic
 values**, e.g., `torch.cond`, `torch.map`. Note that you only need to use these
 if you truly want *data-dependent control flow*.
 
-Here's an example of an `if` statement on a data-dependent condition,
-`x.sum() > 0`, where `x` is an input Tensor, rewritten using `torch.cond`.
+Here's an example of an ``if`` statement on a data-dependent condition,
+``x.sum() > 0``, where ``x`` is an input Tensor, rewritten using `torch.cond`.
 Instead of having to decide which branch to trace, now both branches are
 traced.
 
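The doc's ``torch.cond`` code block is elided from this diff; a sketch in its spirit:

.. code-block:: python

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            def true_fn(x):
                return x.sin()

            def false_fn(x):
                return x.cos()

            # Both branches are traced; the choice is deferred to runtime.
            return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

    ep = torch.export.export(M(), (torch.randn(3),))
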
@@ -312,20 +312,20 @@ traced.
 )
 
 A special case of data-dependent control flow is where it involves a
-*`data-dependent dynamic shape <https://pytorch.org/docs/main/torch.compiler_dynamic_shapes.html#unbacked-symints>`__*:
+`data-dependent dynamic shape <https://pytorch.org/docs/main/torch.compiler_dynamic_shapes.html#unbacked-symints>`__:
 typically, the shape of some intermediate Tensor that depends on input data
 rather than on input shapes (thus not shape-dependent). Instead of using a
 control flow operator, in this case you can provide an assertion that decides
-whether the condition is `True` or `False`. Given such an assertion, we can
+whether the condition is ``True`` or ``False``. Given such an assertion, we can
 continue tracing, emitting a guard as above.
 
 We provide **operators to express assertions on dynamic shapes**, e.g.,
 `torch._check`. Note that you only need to use this when there is control
 flow on data-dependent dynamic shapes.
 
-Here's an example of an `if` statement on a condition involving a
-data-dependent dynamic shape, `nz.shape[0] > 0`, where `nz` is the result of
-calling `torch.nonzero`, an operator whose output shape depends on input
+Here's an example of an ``if`` statement on a condition involving a
+data-dependent dynamic shape, ``nz.shape[0] > 0``, where ``nz`` is the result of
+calling :func:`torch.nonzero`, an operator whose output shape depends on input
 data. Instead of rewriting it, you can add an assertion using `torch._check`
 to effectively decide which branch to trace.
 
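A sketch of the ``torch._check`` pattern this hunk describes (the module is illustrative):

.. code-block:: python

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            nz = torch.nonzero(x)          # nz.shape[0] is data-dependent
            torch._check(nz.shape[0] > 0)  # assertion decides the branch
            if nz.shape[0] > 0:
                return x.cos()
            return x.sin()

    ep = torch.export.export(M(), (torch.ones(4),))
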
@@ -354,7 +354,7 @@ Basics of Symbolic Shapes
 
 During tracing, dynamic Tensor shapes and conditions over them are encoded as
 "symbolic expressions." (In contrast, static Tensor shapes and conditions
-over them are simply `int`s and `bools`.)
+over them are simply ``int`` and ``bool`` values.)
 
 A *symbol* is like a variable; it describes a dynamic Tensor shape.
 
@@ -362,13 +362,13 @@ As tracing proceeds, shapes of intermediate Tensors may be described by more
 general expressions, typically involving integer arithmetic operators. This
 is because **for most PyTorch operators, shapes of output Tensors can be
 described as functions of shapes of input Tensors**. For example, the shape of
-the output of `torch.concat` is the sum of the shapes of its inputs.
+the output of :func:`torch.cat` is the sum of the shapes of its inputs.
 
 Moreover, as we encounter control flow in the program, we create boolean
 expressions, typically involving relational operators, describing conditions
 along the traced path. These **expressions are evaluated to decide which path
 to trace through the program**, and recorded in a
-*`shape environment <https://pytorch.org/docs/main/torch.compiler_dynamic_shapes.html#overall-architecture>`__*
+`shape environment <https://pytorch.org/docs/main/torch.compiler_dynamic_shapes.html#overall-architecture>`__
 to guard the correctness of the traced path and to evaluate subsequently
 created expressions.
 
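A sketch of the :func:`torch.cat` shape arithmetic mentioned above: with both input sizes dynamic, the output size is their symbolic sum (the dimension names are illustrative):

.. code-block:: python

    import torch
    from torch.export import Dim, export

    class M(torch.nn.Module):
        def forward(self, a, b):
            return torch.cat([a, b])  # output size along dim 0 is s0 + s1

    ep = export(
        M(),
        (torch.randn(3), torch.randn(5)),
        dynamic_shapes={"a": {0: Dim("s0")}, "b": {0: Dim("s1")}},
    )
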
@@ -385,7 +385,7 @@ additional fake (a.k.a. "meta") implementation, which inputs and outputs fake
 Tensors, that matches the behavior of the actual implementation in terms of
 shapes and other forms of metadata carried by fake Tensors.
 
-For example, note how the fake implementation of `torch.index_select`
+For example, note how the fake implementation of :func:`torch.index_select`
 computes the shape of the output using the shape of the input (while ignoring
 input data and returning empty output data).
 
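The shape computation referred to here, as an illustrative sketch; this mirrors the logic of a fake ``index_select`` but is not the actual source:

.. code-block:: python

    import torch

    def fake_index_select(x: torch.Tensor, dim: int, index: torch.Tensor) -> torch.Tensor:
        # Output shape copies the input's, except along `dim`, which takes
        # the length of `index`; no data is read, and the output is empty.
        shape = list(x.shape)
        shape[dim] = index.shape[0]
        return x.new_empty(shape)
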
@@ -429,13 +429,13 @@ Control Flow: Guards and Assertions
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 When a condition on shapes is encountered, it either involves only static
-shapes, in which case it is a `bool`, or it involves dynamic shapes, in which
+shapes, in which case it is a ``bool``, or it involves dynamic shapes, in which
 case it is a symbolic boolean expression. For the latter:
 
 - When the condition involves only backed dynamic shapes, we can use the
-  concrete values of those dynamic shapes to evaluate the condition to `True`
-  or `False`. We can then add a guard to the shape environment that states
-  that the corresponding symbolic boolean expression is `True` or `False`,
+  concrete values of those dynamic shapes to evaluate the condition to ``True``
+  or ``False``. We can then add a guard to the shape environment that states
+  that the corresponding symbolic boolean expression is ``True`` or ``False``,
   and continue tracing.
 
 - Otherwise the condition involves unbacked dynamic shapes. In general we
 
@@ -444,7 +444,7 @@ case it is a symbolic boolean expression. For the latter:
   user is expected to use an explicit PyTorch operator for tracing to
   continue. This information is added as a guard in the shape environment,
   and can also possibly help evaluate other subsequently encountered
-  conditions to `True` or `False`.
+  conditions to ``True`` or ``False``.
 
 Once the model is exported, **any guards on backed dynamic shapes can be
 understood as conditions on input dynamic shapes**. These are verified against
 
@@ -478,7 +478,7 @@ In addition, you can define and use
 Defining a custom operator includes defining a fake implementation for it,
 just like any other PyTorch operator (see previous section).
 
-Here's an example of a custom `sin`` operator that wraps NumPy, and its
+Here's an example of a custom ``sin`` operator that wraps NumPy, and its
 registered (trivial) fake implementation.
 
 .. code-block:: python
 
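The ``sin`` example itself sits in the elided code block; a sketch of what it plausibly looks like with the :func:`torch.library.custom_op` API (the ``mylib`` namespace is an assumption):

.. code-block:: python

    import numpy as np
    import torch

    @torch.library.custom_op("mylib::sin", mutates_args=())
    def sin(x: torch.Tensor) -> torch.Tensor:
        # Real implementation wraps NumPy.
        return torch.from_numpy(np.sin(x.numpy()))

    @sin.register_fake
    def _(x: torch.Tensor) -> torch.Tensor:
        # Trivial fake implementation: same shape and dtype, no data.
        return torch.empty_like(x)
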
@@ -495,7 +495,7 @@ registered (trivial) fake implementation.
 
 **Sometimes your custom operator's fake implementation will involve
 data-dependent shapes**. Here's how a fake implementation for a custom
-`nonzero` might look like.
+``nonzero`` might look like.
 
 .. code-block:: python
 
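A sketch of such a data-dependent fake implementation (the op name and namespace are assumptions; ``new_dynamic_size`` allocates an unbacked symbolic size):

.. code-block:: python

    import numpy as np
    import torch

    @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
    def custom_nonzero(x: torch.Tensor) -> torch.Tensor:
        res = np.stack(np.nonzero(x.numpy()), axis=1)
        return torch.tensor(res, device=x.device)

    @custom_nonzero.register_fake
    def _(x: torch.Tensor) -> torch.Tensor:
        # The number of nonzero rows depends on data: model it with a
        # fresh unbacked symbolic size.
        nnz = torch.library.get_ctx().new_dynamic_size()
        return x.new_empty([nnz, x.dim()], dtype=torch.long)
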
@@ -518,7 +518,7 @@ Module states include parameters, buffers, and regular attributes.
 - On the other hand, parameters and buffers are always Tensors.
 
 Module states can be dynamic or static, based on their types as outlined
-above. For example, `self.training` is a `bool`, which means it is static; on
+above. For example, ``self.training`` is a ``bool``, which means it is static; on
 the other hand, any parameter or buffer is dynamic.
 
 The *shapes* of any Tensors contained in module states cannot be dynamic, i.e.,
 
@@ -545,11 +545,11 @@ Updating module states is possible, but must follow the rules below:
   To do so, it must be registered as a buffer during module initialization.
 
 - **A buffer can be updated**, where the updating can be in-place (e.g.,
-  `self.buffer[:] = ...`) or not (e.g., `self.buffer = ...`).
+  ``self.buffer[:] = ...``) or not (e.g., ``self.buffer = ...``).
 
 - **A parameter cannot be updated**. Typically parameters are updated only
   during training, not during inference. We recommend exporting with
-  `no_grad()` to avoid parameter updates at export time.
+  :func:`torch.no_grad` to avoid parameter updates at export time.
 
 Effects of functionalization
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
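A sketch of a buffer update that follows these rules, exported under :func:`torch.no_grad` as recommended (the module is illustrative):

.. code-block:: python

    import torch

    class Counter(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Registered as a buffer at init time, so it may be updated later.
            self.register_buffer("count", torch.zeros(1))

        def forward(self, x):
            self.count += 1  # in-place buffer update, captured by export
            return x + self.count

    with torch.no_grad():
        ep = torch.export.export(Counter(), (torch.randn(2),))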