torch.func API Reference
========================

.. currentmodule:: torch.func

.. automodule:: torch.func

Function Transforms
-------------------

.. autosummary::
    :toctree: generated
    :nosignatures:

    vmap
    grad
    grad_and_value
    vjp
    jvp
    jacrev
    jacfwd
    hessian
    functionalize
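
As a quick illustration of how these transforms are used (a minimal sketch
using only the :func:`grad` and :func:`vmap` APIs listed above; it is not part
of the original reference text):

.. code-block:: python

    import torch
    from torch.func import grad, vmap

    # grad returns a new function that computes the gradient of a
    # scalar-valued function.
    x = torch.tensor(1.0)
    assert torch.allclose(grad(torch.sin)(x), torch.cos(x))

    # Transforms compose: vmap(grad(f)) computes per-sample gradients
    # without an explicit loop.
    xs = torch.randn(5)
    assert torch.allclose(vmap(grad(torch.sin))(xs), torch.cos(xs))
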

Utilities for working with torch.nn.Modules
-------------------------------------------

In general, you can transform over a function that calls a ``torch.nn.Module``.
For example, the following computes the Jacobian of a function that takes three
values and returns three values:

.. code-block:: python

    import torch
    from torch.func import jacrev

    model = torch.nn.Linear(3, 3)

    def f(x):
        return model(x)

    x = torch.randn(3)
    jacobian = jacrev(f)(x)
    assert jacobian.shape == (3, 3)

However, if you want to do something like compute a Jacobian over the
parameters of the model, then there needs to be a way to construct a function
where the parameters are the inputs to the function.
That's what :func:`functional_call` is for:
it accepts an ``nn.Module``, the transformed ``parameters``, and the inputs to
the Module's forward pass. It returns the value of running the Module's forward
pass with the replaced parameters.

Here's how we would compute the Jacobian over the parameters:

.. code-block:: python

    import torch
    from torch.func import jacrev

    model = torch.nn.Linear(3, 3)

    def f(params, x):
        return torch.func.functional_call(model, params, x)

    x = torch.randn(3)
    jacobian = jacrev(f)(dict(model.named_parameters()), x)
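
The same function-over-parameters pattern works with the other transforms. As
an illustrative sketch (not part of the original text), here is how gradients
of a loss with respect to the parameters could be computed by combining
:func:`grad` with :func:`functional_call`:

.. code-block:: python

    import torch
    from torch.func import functional_call, grad

    model = torch.nn.Linear(3, 3)
    x = torch.randn(3)
    targets = torch.randn(3)

    def loss_fn(params, x, targets):
        out = functional_call(model, params, (x,))
        return torch.mean((out - targets) ** 2)

    grads = grad(loss_fn)(dict(model.named_parameters()), x, targets)
    # grads is a dict keyed by parameter name, mirroring the input params
    assert set(grads) == {"weight", "bias"}
    assert grads["weight"].shape == (3, 3)
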

.. autosummary::
    :toctree: generated
    :nosignatures:

    functional_call
    stack_module_state
    replace_all_batch_norm_modules_
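
:func:`stack_module_state` is commonly paired with :func:`vmap` and
:func:`functional_call` to evaluate an ensemble of identically-structured
models in one call. The following is a sketch of that pattern (illustrative,
not part of the original text):

.. code-block:: python

    import copy

    import torch
    from torch.func import functional_call, stack_module_state, vmap

    models = [torch.nn.Linear(3, 3) for _ in range(4)]
    # Stack each parameter/buffer of the ensemble along a new leading dim.
    params, buffers = stack_module_state(models)

    # A stateless "meta" copy serves as the base module for functional_call.
    base = copy.deepcopy(models[0]).to("meta")

    def call_one(params, buffers, x):
        return functional_call(base, (params, buffers), (x,))

    x = torch.randn(3)
    # vmap over the stacked params/buffers, broadcasting the shared input x.
    out = vmap(call_one, in_dims=(0, 0, None))(params, buffers, x)
    assert out.shape == (4, 3)
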

If you're looking for information on fixing Batch Norm modules, please follow
the guidance here:

.. toctree::
    :maxdepth: 1

    func.batch_norm