Fix links rendering when surrounding code in Dynamo deepdive (#123427)

I thought the RST was rendering correctly, but here we are.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123427
Approved by: https://github.com/peterbell10
This commit is contained in:
lezcano 2024-04-13 01:51:52 +00:00 committed by PyTorch MergeBot
parent 7e3f80f00f
commit 891736f115
4 changed files with 49 additions and 57 deletions

View file

@ -668,7 +668,8 @@ Read More
:caption: Deep Dive for PyTorch Developers
:maxdepth: 1
torch.compiler_deepdive
torch.compiler_dynamo_overview
torch.compiler_dynamo_deepdive
torch.compiler_dynamic_shapes
torch.compiler_fake_tensor

View file

@ -102,7 +102,7 @@ Read More
:caption: Deep Dive for PyTorch Developers
:maxdepth: 1
torch.compiler_deepdive
torch.compiler_dynamo_overview
torch.compiler_dynamo_deepdive
torch.compiler_dynamic_shapes
torch.compiler_nn_module

View file

@ -1,3 +1,5 @@
.. _torch.compiler_dynamo_deepdive:
Dynamo Deep-Dive
================
@ -14,7 +16,7 @@ ground up. We will discuss the functionality it provides, and how it is
implemented. By the end of this post, you will have a better
understanding of what went wrong when you ``torch.compiled`` a PyTorch
program and the compilation errored out, or succeeded but the speed-up
was not what you expected. [1]_
was not what you expected.
A Gentle Introduction to Dynamo
-------------------------------
@ -60,11 +62,11 @@ we see the output that Dynamo traced
We call this a **graph (or trace) of the function for the given
inputs**. This is represented via an `FX
graph <https://pytorch.org/docs/stable/fx.html>`__. We will simply think
graph <https://pytorch.org/docs/main/fx.html>`__. We will simply think
of an FX graph as a container that stores a list of function calls.
The first thing we should notice is that the graph is a linear sequence
of PyTorch operations. [2]_ Dynamo records all the PyTorch operations
of PyTorch operations. [1]_ Dynamo records all the PyTorch operations
and stores them sequentially. For example, it split ``z = (x - y) ** 2``
into its two constituting operations, ``sub = l_x_ - l_y_`` and
``z = sub ** 2``.
@ -215,10 +217,10 @@ variables and their names - The builtin functions like ``abs`` or
``print``
You can see all the fields
`here <https://github.com/pytorch/pytorch/blob/e891a3bba9f05697d72776f6e89347231a141f03/torch/csrc/dynamo/eval_frame.c#L50-L59>`__. [3]_
`here <https://github.com/pytorch/pytorch/blob/e891a3bba9f05697d72776f6e89347231a141f03/torch/csrc/dynamo/eval_frame.c#L50-L59>`__. [2]_
In summary, CPython provides the users interpreter with all the
information necessary to execute the function. [4]_
information necessary to execute the function. [3]_
With this API, we can implement a tracer by implementing an interpreter
that runs the code and records in a graph all the PyTorch operations
@ -242,10 +244,10 @@ Implementing CPython in Python
So, we are back in the Python world. We have the bytecode of a function,
and all the context necessary to execute it. In particular, we have
landed at
```_convert_frame_assert`` <https://github.com/pytorch/pytorch/blob/b6df8414601e1e086e830ca9e919e7fdc8874e71/torch/_dynamo/convert_frame.py#L272-L274>`__.
`_convert_frame_assert <https://github.com/pytorch/pytorch/blob/b6df8414601e1e086e830ca9e919e7fdc8874e71/torch/_dynamo/convert_frame.py#L272-L274>`__.
This is the function that the decorator ``torch.compile`` returns! We
get to this function from
```_dynamo.optimize`` <https://github.com/pytorch/pytorch/blob/b6df8414601e1e086e830ca9e919e7fdc8874e71/torch/_dynamo/eval_frame.py#L715-L727>`__.
`_dynamo.optimize <https://github.com/pytorch/pytorch/blob/b6df8414601e1e086e830ca9e919e7fdc8874e71/torch/_dynamo/eval_frame.py#L715-L727>`__.
The decorator ``torch.compile`` is just a nice API around
``_dynamo.optimize``.
@ -259,8 +261,7 @@ of Dynamo.
The parent class of the internal class structure is ``VariableTracker``
and represents the different objects that Dynamo understands. For
example, ``ListVariable``, represents a ``list`` object, and keeps
internally a `list of
``VariableTracker``\ s <https://github.com/pytorch/pytorch/blob/e38a3a6079a3861b4bc9f256120ec661f34e726d/torch/_dynamo/variables/lists.py#L48-L56>`__.
internally a `list of VariableTrackers <https://github.com/pytorch/pytorch/blob/e38a3a6079a3861b4bc9f256120ec661f34e726d/torch/_dynamo/variables/lists.py#L48-L56>`__.
Another example of ``VariableTracker`` is
`ConstantVariable <https://github.com/pytorch/pytorch/blob/83c0763dda1f93c6cf552ba88260a0dc7a3ecb70/torch/_dynamo/variables/constant.py#L30>`__.
ConstantVariable wraps all the `objects considered constant by
@ -269,12 +270,12 @@ We also have special subclasses for objects that require special
attention, like
`TensorVariable <https://github.com/pytorch/pytorch/blob/83c0763dda1f93c6cf552ba88260a0dc7a3ecb70/torch/_dynamo/variables/tensor.py#L68-L69>`__.
All these internal classes are defined in the
```torch/_dynamo/variables`` <https://github.com/pytorch/pytorch/tree/83c0763dda1f93c6cf552ba88260a0dc7a3ecb70/torch/_dynamo/variables>`__
`torch/_dynamo/variables <https://github.com/pytorch/pytorch/tree/83c0763dda1f93c6cf552ba88260a0dc7a3ecb70/torch/_dynamo/variables>`__
folder.
Python objects are wrapped into their corresponding ``VariableTracker``
class in
```VariableBuilder._wrap`` <https://github.com/pytorch/pytorch/blob/83c0763dda1f93c6cf552ba88260a0dc7a3ecb70/torch/_dynamo/variables/builder.py#L365>`__.
`VariableBuilder._wrap <https://github.com/pytorch/pytorch/blob/83c0763dda1f93c6cf552ba88260a0dc7a3ecb70/torch/_dynamo/variables/builder.py#L365>`__.
This function is just a very long chain of ``elif``\ s that tries to
recursively pattern-match the Python inputs into the appropriate type of
``VariableTracker``.
@ -304,9 +305,9 @@ traced into the right ``VariableTracker``.
Ok, so we have an IR for our tracer, now we *just* need to reimplement
CPythons stack machine. This is implemented by
```InstructorTranslatorBase`` <https://github.com/pytorch/pytorch/blob/69f112d5867f785a3a090a0c6d6644ae047033ac/torch/_dynamo/symbolic_convert.py#L576-L594>`__
`InstructorTranslatorBase <https://github.com/pytorch/pytorch/blob/69f112d5867f785a3a090a0c6d6644ae047033ac/torch/_dynamo/symbolic_convert.py#L576-L594>`__
in
```symbolic_convert.py`` <https://github.com/pytorch/pytorch/blob/69f112d5867f785a3a090a0c6d6644ae047033ac/torch/_dynamo/symbolic_convert.py>`__.
`symbolic_convert.py <https://github.com/pytorch/pytorch/blob/69f112d5867f785a3a090a0c6d6644ae047033ac/torch/_dynamo/symbolic_convert.py>`__.
``InstructionTranslatorBase`` has about 200 methods, implementing almost
all of Python bytecodes. As an example, we can see the implementation of
@ -330,10 +331,9 @@ Generating the Output Graph
With a way to symbolically execute Python code, we are set to extract
the PyTorch operations that happen during the symbolic execution of a
program given some inputs. This is implemented in Dynamo via the
```OutputGraph`` <https://github.com/pytorch/pytorch/blob/69f112d5867f785a3a090a0c6d6644ae047033ac/torch/_dynamo/output_graph.py#L221-L230>`__
`OutputGraph <https://github.com/pytorch/pytorch/blob/69f112d5867f785a3a090a0c6d6644ae047033ac/torch/_dynamo/output_graph.py#L221-L230>`__
object. The ``OutputGraph`` object is `bound to an
``InstructionTranslator``
object <https://github.com/pytorch/pytorch/blob/69f112d5867f785a3a090a0c6d6644ae047033ac/torch/_dynamo/symbolic_convert.py#L2060-L2071>`__
`InstructionTranslator object <https://github.com/pytorch/pytorch/blob/69f112d5867f785a3a090a0c6d6644ae047033ac/torch/_dynamo/symbolic_convert.py#L2060-L2071>`__
and it tracks all the data necessary to create the FX graph which will
be returned by Dynamo.
@ -342,9 +342,9 @@ All the inputs and intermediary elements of the FX graph are
``fx.Proxy``\ s. ``fx.Proxy``\ s are used to build the FX graph.
In particular, they record every PyTorch operation performed on them
into the graph. You can can create a new operation to be added to
the graph by calling ```create_proxy`` <https://github.com/pytorch/pytorch/blob/fb80f05ee2e1cba17892980701bfd5dbce58349f/torch/_dynamo/output_graph.py#L430-L431>`__.
the graph by calling `create_proxy <https://github.com/pytorch/pytorch/blob/fb80f05ee2e1cba17892980701bfd5dbce58349f/torch/_dynamo/output_graph.py#L430-L431>`__.
Then, we can add it to the graph through the function
```wrap_fx_proxy`` <https://github.com/pytorch/pytorch/blob/fb80f05ee2e1cba17892980701bfd5dbce58349f/torch/_dynamo/variables/builder.py#L1311>`__.
`wrap_fx_proxy <https://github.com/pytorch/pytorch/blob/fb80f05ee2e1cba17892980701bfd5dbce58349f/torch/_dynamo/variables/builder.py#L1311>`__.
A graph stores operations on tensors… and operations on symbolic
integers. We will discuss symbolic integers later on, but first we will
@ -358,7 +358,7 @@ Making Dynamo Sound: Guards
At this point, we have a way to trace programs completely disregarding control flow.
And for that, we have reimplemented all of CPython… If this sounds like a bit of an
overkill, that is because it is.
```torch.jit.trace`` <https://pytorch.org/docs/stable/generated/torch.jit.trace.html>`__
`torch.jit.trace <https://pytorch.org/docs/main/generated/torch.jit.trace.html>`__
already implements this without all this machinery, so what gives?
The issue with ``torch.jit.trace``, as it is warned in its docs, is that
@ -399,7 +399,7 @@ with ``TORCH_LOGS=guards`` prints (among other guards)
L['b'] == 'Hello'
This reads as “the local variable ``b`` should have a specific type
(``str`` in this case, represented by the constant `9433...`) and
(``str`` in this case, represented by the constant ``9433...``) and
its value should be ``'Hello'``”. If we then execute the function
again passing a different argument
@ -442,15 +442,15 @@ the objects they contain. In
return a * x
``x`` and ``y`` have
```LocalSource`` <https://github.com/pytorch/pytorch/blob/40dc0580a69565b06ec5263efe5d87cecc8200f7/torch/_dynamo/source.py#L80-L92>`__
`LocalSource <https://github.com/pytorch/pytorch/blob/40dc0580a69565b06ec5263efe5d87cecc8200f7/torch/_dynamo/source.py#L80-L92>`__
as their source, and ``y[0]`` has
```GetItemSource`` <https://github.com/pytorch/pytorch/blob/40dc0580a69565b06ec5263efe5d87cecc8200f7/torch/_dynamo/source.py#L302>`__,
`GetItemSource <https://github.com/pytorch/pytorch/blob/40dc0580a69565b06ec5263efe5d87cecc8200f7/torch/_dynamo/source.py#L302>`__,
which stores a ``LocalSource`` inside. On the other hand, ``a`` will not
have a source as it is an intermediate variable that only exists within
the fx graph.
All these are defined in
```torch/_dynamo/source.py`` <https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/source.py>`__.
`torch/_dynamo/source.py <https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/source.py>`__.
We can see the guard generated by ``GetItemSource`` in the following
example:
@ -496,9 +496,9 @@ Symbolic Shapes
Another point we discussed in the introduction is that Dynamo knows how
to trace integers. In order to implement this, we use a symbolic class
```torch.SymInt`` <https://github.com/pytorch/pytorch/blob/fb80f05ee2e1cba17892980701bfd5dbce58349f/torch/__init__.py#L244-L249>`__\ [5]_
`torch.SymInt <https://github.com/pytorch/pytorch/blob/fb80f05ee2e1cba17892980701bfd5dbce58349f/torch/__init__.py#L244-L249>`__
that acts like an ``int`` but it records all the operations performed on
it in the output FX graph. We already saw this class in the introduction
it in the output FX graph. [4]_ We already saw this class in the introduction
when introducing symbolic integer tracing.
Let us now discuss the three properties that define symbolic shape
@ -588,7 +588,7 @@ more general guards on this more generic kernel.
**Compilation performance tip**. If you know that a dimension will vary
in size, you can mark it as dynamic by calling
```torch._dynamo.mark_dynamic`` <https://github.com/pytorch/pytorch/blob/66a76516bfc341b2b55bb2056d2faa9c2de46d69/torch/_dynamo/decorators.py#L176>`__
`torch._dynamo.mark_dynamic <https://github.com/pytorch/pytorch/blob/66a76516bfc341b2b55bb2056d2faa9c2de46d69/torch/_dynamo/decorators.py#L176>`__
before calling ``torch.compile``. This will avoid the first compilation
with a static shape. There are other useful utility functions like
``maybe_mark_dynamic`` or ``mark_static``. You can also have all
@ -671,7 +671,7 @@ arbitrary Python code” is perhaps a bit too general. Dynamo implements a
good part of Python, but does it implement the more complex parts, like
coroutines or async? Does it implement the whole Python standard
library? NumPy also has a Python API. Does ``torch.compile`` also
understand NumPy? and Django? [6]_
understand NumPy? and Django? [5]_
Pythons ecosystem is massive, and a good part of it is written in other
more performant languages like C++ or Rust, and it just exposes Python
@ -683,15 +683,15 @@ The usual way machine learning tracers handle this issue is by informing
the user that the operation they choked on and giving up tracing
altogether. This would pose a real usability issue in the case of
PyTorch, where its users are used to the flexibility it gives them. As a
real-world example the ```doctr_det_predictor`` model uses NumPy and the
``cv2`` library to postprocess the models
real-world example the ``doctr_det_predictor`` model uses NumPy and the
``cv2`` library to `postprocess the models
result <https://github.com/mindee/doctr/blob/f2114758d529ed8d3d0030581638f0520b6b98d8/doctr/models/detection/core.py#L86>`__.
Here is another place where having access to CPython is interesting.
Rather than erroring out, Dynamo can let CPython run that problematic
code! To do this, Dynamo generates at trace time one graph with all the
operations before the problematic code, and one with all the operations
after. [7]_ Then, at runtime, it will delegate to CPython to execute the
after. [6]_ Then, at runtime, it will delegate to CPython to execute the
first graph, then the problematic code, and then the second graph. This
process of stopping the tracing and generating multiple graphs is called
a **graph break**.
@ -811,10 +811,9 @@ implementing the strategy that we described before
The code generation of the stack in Dynamo is delegated to
``VariableTracker`` subclasses. Every ``VariableTracker`` object in
Dynamo has a ```reconstruct``
method <https://github.com/pytorch/pytorch/blob/e891a3bba9f05697d72776f6e89347231a141f03/torch/_dynamo/variables/lists.py#L307-L309>`__
that generates the necessary bytecode to create the python object it
represents on the stack.
Dynamo has a `reconstruct <https://github.com/pytorch/pytorch/blob/e891a3bba9f05697d72776f6e89347231a141f03/torch/_dynamo/variables/lists.py#L307-L309>`__
method that generates the necessary bytecode to create the python object
it represents on the stack.
**Debugging tip**. Graph breaks hamper performance, and as such, it is
best to avoid them. Running a program with ``TORCH_LOGS=graph_breaks``
@ -843,34 +842,24 @@ github <https://github.com/pytorch/pytorch/issues?q=is%3Aissue+is%3Aopen+label%3
Many of them require very minor changes in the code, once you find where
you need to make those changes.
.. [1]
In the same way that Dynamo takes its name from
[Dynamorio].(https://dynamorio.org/), this blog posts name is a
small nod to `You Could Have Invented Spectral
Sequences <https://www.ams.org/notices/200601/fea-chow.pdf>`__.
Footnotes
---------
.. [2]
In the literature, this is called a Directed Acyclical Graph (DAG).
.. [1] In the literature, this is called a Directed Acyclical Graph (DAG).
.. [3]
All this binding code lives in ``torch/csrc/dynamo/eval_frame.c``.
.. [2] All this binding code lives in ``torch/csrc/dynamo/eval_frame.c``.
.. [4]
In CPython lingo, the set of all these objects are called `a
.. [3] In CPython lingo, the set of all these objects are called `a
frame <https://github.com/python/cpython/blob/f26bfe4b25f7e5a4f68fcac26207b7175abad208/Include/internal/pycore_frame.h#L57-L71>`__.
.. [5]
There are also ``SymBool`` and ``SymFloat`` classes. The latter one
.. [4] There are also ``SymBool`` and ``SymFloat`` classes. The latter one
is not used all that much at the time of this writing.
.. [6]
Interestingly enough, it does understand NumPy code! Have a look at
.. [5] Interestingly enough, it does understand NumPy code! Have a look at
`this blogpost <https://pytorch.org/blog/compiling-numpy-code/>`__
and `the
docs <https://pytorch.org/docs/stable/torch.compiler_faq.html#does-numpy-work-with-torch-compile>`__.
and `the docs <https://pytorch.org/docs/main/torch.compiler_faq.html#does-numpy-work-with-torch-compile>`__.
Now, this is just possible because we reimplemented NumPy using
PyTorch. Good luck implementing Django in PyTorch though…
.. [7]
Assuming there is just one piece of problematic code. If there are
.. [6] Assuming there is just one piece of problematic code. If there are
more, Dynamo can split the code into as many graphs as it needs.

View file

@ -1,5 +1,5 @@
TorchDynamo Deep Dive
=====================
TorchDynamo Overview
====================
Before you read this section, read :ref:`torch.compiler_overview`.
@ -346,3 +346,5 @@ To summarize, the compiled code is conceptually equivalent to the code below:
The following diagram demonstrates how ``torch.compile`` transforms and optimizes user-written code: it first extracts computation graphs from the user-written function, and compiles these graphs into optimized functions, then assembles them into a new function, which is functionally equivalent to the user-written code but optimized to have a good computation speed.
.. image:: _static/img/dynamo/flowchart.jpg
To learn more about how all this is implemented internally, see :ref:`torch.compiler_dynamo_deepdive`.