From 13bd4ed9331d068e8fae236a7acdfafbd4843202 Mon Sep 17 00:00:00 2001
From: lezcano
Date: Wed, 20 Sep 2023 21:20:57 +0000
Subject: [PATCH] Add docs for torch.compile(numpy) (#109710)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109710
Approved by: https://github.com/ev-br, https://github.com/gchanan, https://github.com/peterbell10
---
 docs/source/torch.compiler_faq.rst | 158 ++++++++++++++++++++++++++++-
 1 file changed, 156 insertions(+), 2 deletions(-)

diff --git a/docs/source/torch.compiler_faq.rst b/docs/source/torch.compiler_faq.rst
index 1250d7677eb..5dac0f0a8ce 100644
--- a/docs/source/torch.compiler_faq.rst
+++ b/docs/source/torch.compiler_faq.rst
@@ -317,8 +317,8 @@ them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py``
 2. CUDA graphs with Triton are enabled by default in inductor but removing
 them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``.
 
-``torch.func`` works with ``torch.compile`` (for `grad` and `vmap` transforms)
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Does ``torch.func`` work with ``torch.compile`` (for `grad` and `vmap` transforms)?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 Applying a ``torch.func`` transform to a function that uses ``torch.compile``
 does not work:
@@ -528,6 +528,160 @@ invokes an ``nn.Module``. This is because the outputs now depend on the
 parameters of the ``nn.Module``. To get this to work, use
 ``torch.func.functional_call`` to extract the module state.
 
+Does NumPy work with ``torch.compile``?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Starting in PyTorch 2.1, ``torch.compile`` understands native NumPy programs
+that work on NumPy arrays, and mixed PyTorch-NumPy programs that convert from
+PyTorch to NumPy and back via ``x.numpy()``, ``torch.from_numpy``, and related
+functions.
+
+.. _nonsupported-numpy-feats:
+
+Which NumPy features does ``torch.compile`` support?
+----------------------------------------------------
+
+NumPy within ``torch.compile`` follows the NumPy 2.0 pre-release.
+
+Generally, ``torch.compile`` is able to trace through most NumPy constructions,
+and when it cannot, it falls back to eager and lets NumPy execute that piece of
+code. Even then, there are a few features where ``torch.compile`` semantics
+slightly deviate from those of NumPy:
+
+- NumPy scalars: We model them as 0-D arrays. That is, ``np.float32(3)`` returns
+  a 0-D array under ``torch.compile``. To avoid a graph break, it is best to use
+  this 0-D array. If this breaks your code, you can work around it by casting the
+  NumPy scalar to the relevant Python scalar type ``bool/int/float``, as in the
+  sketch after this list.
+
+- Negative strides: ``np.flip`` and slicing with a negative step return a copy.
+
+- Type promotion: NumPy's type promotion will change in NumPy 2.0. The new rules
+  are described in `NEP 50 <https://numpy.org/neps/nep-0050-scalar-promotion.html>`__.
+  ``torch.compile`` implements NEP 50 rather than the current soon-to-be-deprecated rules.
+
+- ``{tril,triu}_indices_from/{tril,triu}_indices`` return arrays rather than a tuple of arrays.
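+
+For instance, here is a minimal sketch of the scalar workaround above. The
+function ``step`` and its inputs are made up for illustration:
+
+.. code-block:: python
+
+    import numpy as np
+    import torch
+
+    @torch.compile
+    def step(x: np.ndarray) -> np.ndarray:
+        # Under torch.compile, np.float32(5.0) behaves as a 0-D array. Casting
+        # it to a Python float, as below, also traces cleanly and matches
+        # eager NumPy semantics exactly.
+        threshold = float(np.float32(5.0))
+        return np.where(x > threshold, x, 0.0)
+
+    print(step(np.arange(10.0)))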
+
+There are other features for which we do not support tracing and we gracefully
+fall back to NumPy for their execution:
+
+- Non-numeric dtypes like datetimes, strings, chars, void, structured dtypes
+  and recarrays.
+
+- Extended-precision dtypes ``np.float128/np.complex256`` and some unsigned
+  dtypes ``np.uint16/np.uint32/np.uint64``.
+
+- ``ndarray`` subclasses.
+
+- Masked arrays.
+
+- Esoteric ufunc machinery like ``axes=[(n,k),(k,m)->(n,m)]`` and ufunc
+  methods (e.g., ``np.add.reduce``).
+
+- Sorting / ordering ``complex64/complex128`` arrays.
+
+- NumPy ``np.poly1d`` and ``np.polynomial``.
+
+- Positional ``out1, out2`` args in functions with 2 or more returns
+  (``out=tuple`` does work).
+
+- ``__array_function__``, ``__array_interface__`` and ``__array_wrap__``.
+
+- ``ndarray.ctypes`` attribute.
+
+Can I execute NumPy code on CUDA via ``torch.compile``?
+-------------------------------------------------------
+
+Yes, you can! To do so, simply execute your code within a ``torch.device("cuda")``
+context. Consider the following example:
+
+.. code-block:: python
+
+    import torch
+    import numpy as np
+
+    @torch.compile
+    def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
+        return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
+
+    X = np.random.randn(1024, 64)
+    Y = np.random.randn(1024, 64)
+    with torch.device("cuda"):
+        Z = numpy_fn(X, Y)
+
+In this example, ``numpy_fn`` will be executed on CUDA. For this to be
+possible, ``torch.compile`` automatically moves ``X`` and ``Y`` from CPU
+to CUDA, and then it moves the result ``Z`` from CUDA back to CPU. If we are
+executing this function several times in the same program run, we may want
+to avoid all these rather expensive memory copies. To do so, we just need
+to tweak our ``numpy_fn`` so that it accepts and returns CUDA tensors:
+
+.. code-block:: python
+
+    @torch.compile
+    def numpy_fn(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
+        X, Y = X.numpy(), Y.numpy()
+        Z = np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
+        return torch.from_numpy(Z)
+
+    X = torch.randn(1024, 64, device="cuda")
+    Y = torch.randn(1024, 64, device="cuda")
+    with torch.device("cuda"):
+        Z = numpy_fn(X, Y)
+
+By doing this, we explicitly create the tensors in CUDA memory and keep them
+there. In this case, ``X.numpy()`` and ``from_numpy()`` are hints to the
+compiler, and no real data movement happens. Note that, written this way, the
+function no longer runs in eager mode, since ``.numpy()`` raises an error on
+CUDA tensors there. To run it in eager mode, you would need to call
+``.numpy(force=True)`` and do ``Z = Z.cuda()`` before returning ``Z``. Of
+course, doing this would execute the program in eager mode with NumPy on CPU.
+
+How do I debug NumPy code under ``torch.compile``?
+--------------------------------------------------
+
+Debugging JIT compiled code is challenging, given the complexity of modern
+compilers and the daunting errors that they raise.
+`The tutorial on how to diagnose runtime errors within torch.compile `__
+contains a few tips and tricks on how to tackle this task.
+
+If the above is not enough to pinpoint the origin of the issue, there are still
+a few other NumPy-specific tools we can use. We can discern whether the bug
+is entirely in the PyTorch code by disabling tracing through NumPy functions:
+
+.. code-block:: python
+
+    from torch._dynamo import config
+    config.trace_numpy = False
+
+If the bug lies in the traced NumPy code, we can execute the NumPy code eagerly
+(without ``torch.compile``) using PyTorch as a backend via
+``import torch._numpy as np``. This should only be used for **debugging
+purposes** and is in no way a replacement for the PyTorch API, as it is
+**much less performant** and, as a private API, **may change without notice**.
+At any rate, ``torch._numpy`` is a Python implementation of NumPy in terms of
+PyTorch, and it is used internally by ``torch.compile`` to transform NumPy code
+into PyTorch code. It is rather easy to read and modify, so if you find any bug
+in it, feel free to submit a PR fixing it or simply open an issue.
+
+If the program does work when importing ``torch._numpy as np``, chances are
+that the bug is in TorchDynamo. If this is the case, please feel free to open
+an issue with a `minimal reproducer `__.
+
+I ``torch.compile`` some NumPy code and I did not see any speed-up.
+-------------------------------------------------------------------
+
+The best place to start is the
+`tutorial with general advice for how to debug these sorts of torch.compile issues `__.
+
+Some graph breaks may happen because of the use of unsupported features. See
+:ref:`nonsupported-numpy-feats`. More generally, it is useful to keep in mind
+that some widely used NumPy features do not play well with compilers. For
+example, in-place modifications make reasoning difficult within the compiler
+and often yield worse performance than their out-of-place counterparts. As
+such, it is best to avoid them; the same goes for the use of the ``out=``
+parameter. Instead, prefer out-of-place ops and let ``torch.compile`` optimize
+the memory use. The same applies to data-dependent ops like masked indexing
+through boolean masks, or data-dependent control flow like ``if`` or ``while``
+constructions.
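+
+As a rough illustration of this advice (the function names and data here are
+made up), the second formulation below is generally friendlier to
+``torch.compile`` than the first:
+
+.. code-block:: python
+
+    import numpy as np
+    import torch
+
+    @torch.compile
+    def normalize_inplace(x: np.ndarray) -> np.ndarray:
+        # In-place updates and out= make aliasing hard for the compiler to
+        # reason about, and may force a fallback to eager NumPy.
+        x -= x.mean()
+        np.divide(x, x.std(), out=x)
+        return x
+
+    @torch.compile
+    def normalize(x: np.ndarray) -> np.ndarray:
+        # Out-of-place version: torch.compile is free to fuse these ops and
+        # manage the intermediate memory itself.
+        return (x - x.mean()) / x.std()
+
+Note that ``normalize_inplace`` also mutates its argument, a side effect that
+the out-of-place version avoids.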
+
+
 Which API to use for fine grain tracing?
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~