Add docs for torch.compile(numpy) (#109710)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109710 Approved by: https://github.com/ev-br, https://github.com/gchanan, https://github.com/peterbell10
This commit is contained in: parent 7a04ae6fba, commit 13bd4ed933
1 changed file with 156 additions and 2 deletions
@@ -317,8 +317,8 @@ them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py``
2. CUDA graphs with Triton are enabled by default in inductor but removing
   them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``.

-``torch.func`` works with ``torch.compile`` (for `grad` and `vmap` transforms)
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Does ``torch.func`` work with ``torch.compile`` (for `grad` and `vmap` transforms)?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Applying a ``torch.func`` transform to a function that uses ``torch.compile``
does not work:

@@ -528,6 +528,160 @@ invokes an ``nn.Module``. This is because the outputs now depend on the
parameters of the ``nn.Module``. To get this to work, use
``torch.func.functional_call`` to extract the module state.

Does NumPy work with ``torch.compile``?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Starting in 2.1, ``torch.compile`` understands native NumPy programs that
work on NumPy arrays, and mixed PyTorch-NumPy programs that convert from PyTorch
to NumPy and back via ``x.numpy()``, ``torch.from_numpy``, and related functions.
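
As an illustration, here is a minimal sketch of such a mixed PyTorch-NumPy program (``mixed_fn`` is an illustrative name, not a PyTorch API):

.. code-block:: python

   import numpy as np
   import torch

   @torch.compile
   def mixed_fn(t: torch.Tensor) -> torch.Tensor:
       # PyTorch -> NumPy -> PyTorch round trip, traced as a single graph
       x = t.numpy()
       y = np.sin(x) ** 2
       return torch.from_numpy(y)

   print(mixed_fn(torch.arange(4.0)))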

.. _nonsupported-numpy-feats:

Which NumPy features does ``torch.compile`` support?
----------------------------------------------------

NumPy within ``torch.compile`` follows NumPy 2.0 pre-release.

Generally, ``torch.compile`` is able to trace through most NumPy constructions,
and when it cannot, it falls back to eager and lets NumPy execute that piece of
code. Even then, there are a few features where ``torch.compile`` semantics
slightly deviate from those of NumPy:

- NumPy scalars: We model them as 0-D arrays. That is, ``np.float32(3)`` returns
  a 0-D array under ``torch.compile``. To avoid a graph break, it is best to use this 0-D
  array. If this breaks your code, you can work around it by casting the NumPy scalar
  to the relevant Python scalar type ``bool/int/float``, as sketched after this list.

- Negative strides: ``np.flip`` and slicing with a negative step return a copy.

- Type promotion: NumPy's type promotion will change in NumPy 2.0. The new rules
  are described in `NEP 50 <https://numpy.org/neps/nep-0050-scalar-promotion.html>`__.
  ``torch.compile`` implements NEP 50 rather than the current, soon-to-be-deprecated rules.

- ``{tril,triu}_indices_from/{tril,triu}_indices`` return arrays rather than a tuple of arrays.
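
For instance, a minimal sketch of the scalar workaround mentioned above (``scalar_fn`` is an illustrative name):

.. code-block:: python

   import numpy as np
   import torch

   @torch.compile
   def scalar_fn(x: np.ndarray) -> np.ndarray:
       # Under torch.compile, np.float32(3.0) behaves as a 0-D array.
       # Casting to a Python float recovers plain-scalar semantics.
       scale = float(np.float32(3.0))
       return x * scale

   print(scalar_fn(np.ones(4, dtype=np.float32)))  # [3. 3. 3. 3.]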

There are other features for which we do not support tracing, and we gracefully
fall back to NumPy for their execution:

- Non-numeric dtypes like datetimes, strings, chars, void, structured dtypes and recarrays.

- Long dtypes ``np.float128/np.complex256`` and some unsigned dtypes ``np.uint16/np.uint32/np.uint64``.

- ``ndarray`` subclasses.

- Masked arrays.

- Esoteric ufunc machinery like ``axes=[(n,k),(k,m)->(n,m)]`` and ufunc methods (e.g., ``np.add.reduce``).

- Sorting / ordering ``complex64/complex128`` arrays.

- NumPy ``np.poly1d`` and ``np.polynomial``.

- Positional ``out1, out2`` args in functions with 2 or more returns (``out=tuple`` does work).

- ``__array_function__``, ``__array_interface__`` and ``__array_wrap__``.

- ``ndarray.ctypes`` attribute.

Can I execute NumPy code on CUDA via ``torch.compile``?
-------------------------------------------------------

Yes, you can! To do so, you may simply execute your code within a ``torch.device("cuda")``
context. Consider the example:

.. code-block:: python

   import torch
   import numpy as np

   @torch.compile
   def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
       return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

   X = np.random.randn(1024, 64)
   Y = np.random.randn(1024, 64)
   with torch.device("cuda"):
       Z = numpy_fn(X, Y)

In this example, ``numpy_fn`` will be executed on CUDA. For this to be
possible, ``torch.compile`` automatically moves ``X`` and ``Y`` from CPU
to CUDA, and then it moves the result ``Z`` from CUDA back to CPU. If we are
executing this function several times in the same program run, we may want
to avoid all these rather expensive memory copies. To do so, we just need
to tweak our ``numpy_fn`` so that it accepts CUDA tensors and returns tensors:

.. code-block:: python

   @torch.compile
   def numpy_fn(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
       X, Y = X.numpy(), Y.numpy()
       Z = np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
       return torch.from_numpy(Z)

   X = torch.randn(1024, 64, device="cuda")
   Y = torch.randn(1024, 64, device="cuda")
   with torch.device("cuda"):
       Z = numpy_fn(X, Y)

By doing this, we explicitly create the tensors in CUDA memory, and we keep
them there. In this case ``X.numpy()`` and ``from_numpy()`` are hints to the compiler,
but no real data movement happens. Note that the original program would no longer run
in eager mode. If you want to run it in eager mode, you would need to call
``.numpy(force=True)`` and then move the result back with ``Z = Z.cuda()`` before returning
``Z``. Of course, doing this would execute the program in eager mode NumPy, and
on CPU.
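
As a rough sketch, the eager-compatible variant described above could look as follows (``numpy_fn_eager`` is an illustrative name; this assumes the inputs live on CUDA):

.. code-block:: python

   def numpy_fn_eager(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
       # force=True copies CUDA tensors to CPU so .numpy() also works in eager mode
       X, Y = X.numpy(force=True), Y.numpy(force=True)
       Z = np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
       # from_numpy yields a CPU tensor; move it back to CUDA explicitly
       return torch.from_numpy(Z).cuda()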

How do I debug NumPy code under ``torch.compile``?
--------------------------------------------------

Debugging JIT compiled code is challenging, given the complexity of modern
compilers and the daunting errors that they raise.
`The tutorial on how to diagnose runtime errors within torch.compile <https://pytorch.org/docs/main/torch.compiler_troubleshooting.html#diagnosing-runtime-errors>`__
contains a few tips and tricks on how to tackle this task.

If the above is not enough to pinpoint the origin of the issue, there are still
a few other NumPy-specific tools we can use. We can discern whether the bug
is entirely in the PyTorch code by disabling tracing through NumPy functions:

.. code-block:: python

   from torch._dynamo import config
   config.trace_numpy = False

If the bug lies in the traced NumPy code, we can execute the NumPy code eagerly (without ``torch.compile``)
using PyTorch as a backend via ``import torch._numpy as np``.
This should just be used for **debugging purposes** and is in no way a
replacement for the PyTorch API, as it is **much less performant** and, as a
private API, **may change without notice**. At any rate, ``torch._numpy`` is a
Python implementation of NumPy in terms of PyTorch, and it is used internally by ``torch.compile`` to
transform NumPy code into PyTorch code. It is rather easy to read and modify,
so if you find any bug in it, feel free to submit a PR fixing it or simply open
an issue.
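
For example, a debugging session along these lines might look as follows (a sketch; keep in mind that ``torch._numpy`` is private and may change without notice):

.. code-block:: python

   import torch._numpy as np  # debugging only: NumPy implemented on top of PyTorch

   def fn(x):
       return np.sum(x ** 2)

   # Runs eagerly through torch._numpy, with no torch.compile involved.
   # If the bug disappears here, it likely lives in the tracing machinery.
   print(fn(np.arange(10)))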

If the program does work when importing ``torch._numpy as np``, chances are
that the bug is in TorchDynamo. If this is the case, please open an issue
with a `minimal reproducer <https://pytorch.org/docs/2.1/torch.compiler_troubleshooting.html>`__.

I ``torch.compile`` some NumPy code and I did not see any speed-up.
-------------------------------------------------------------------

The best place to start is the
`tutorial with general advice for how to debug these sorts of torch.compile issues <https://pytorch.org/docs/main/torch.compiler_faq.html#why-am-i-not-seeing-speedups>`__.
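
One quick check for graph breaks is ``torch._dynamo.explain``. Note that this is a private API, so treat the sketch below (which reuses ``numpy_fn``, ``X`` and ``Y`` from the CUDA example above) as an assumption that may change across versions:

.. code-block:: python

   import torch

   # Summarizes how many graphs were captured and why they broke
   explanation = torch._dynamo.explain(numpy_fn)(X, Y)
   print(explanation.graph_break_count)
   print(explanation.break_reasons)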

Some graph breaks may happen because of the use of unsupported features. See
:ref:`nonsupported-numpy-feats`. More generally, it is useful to keep in mind
that some widely used NumPy features do not play well with compilers. For
example, in-place modifications make reasoning difficult within the compiler and
often yield worse performance than their out-of-place counterparts. As such, it is
best to avoid them. The same goes for the use of the ``out=`` parameter. Instead, prefer
out-of-place ops and let ``torch.compile`` optimize the memory use. The same goes
for data-dependent ops like masked indexing through boolean masks, or
data-dependent control flow like ``if`` or ``while`` constructions.
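
For instance, a compiler-friendlier rewrite of an in-place, masked update might look like this (a sketch; the function names are illustrative):

.. code-block:: python

   import numpy as np
   import torch

   @torch.compile
   def clamp_in_place(x: np.ndarray) -> np.ndarray:
       x[x < 0] = 0.0  # in-place, data-dependent write: hard for the compiler
       return x

   @torch.compile
   def clamp_out_of_place(x: np.ndarray) -> np.ndarray:
       return np.where(x < 0, 0.0, x)  # out-of-place: easier to optimize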

Which API to use for fine grain tracing?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~