pytorch/torch/func/__init__.py


from torch._functorch.apis import grad, grad_and_value, vmap
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from torch._functorch.eager_transforms import (
    debug_unwrap,
    functionalize,
    hessian,
    jacfwd,
    jacrev,
    jvp,
    linearize,
    vjp,
)
from torch._functorch.functional_call import functional_call, stack_module_state

__all__ = [
"grad",
"grad_and_value",
"vmap",
"replace_all_batch_norm_modules_",
"functionalize",
"hessian",
"jacfwd",
"jacrev",
"jvp",
"linearize",
"vjp",
"functional_call",
"stack_module_state",
"debug_unwrap",
]
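Since this module only re-exports function transforms from `torch._functorch`, a minimal usage sketch of two of them, `grad` and `vmap`, may help; it assumes a PyTorch build recent enough to ship `torch.func` (2.0 or later):

```python
import torch
from torch.func import grad, vmap

# grad turns a scalar-valued function into one returning its gradient.
df = grad(lambda x: x ** 2)
print(df(torch.tensor(3.0)))  # tensor(6.)

# vmap vectorizes the per-example gradient over a batch dimension,
# so df is applied elementwise without an explicit Python loop.
batched_df = vmap(df)
print(batched_df(torch.arange(4.0)))  # tensor([0., 2., 4., 6.])
```

The transforms compose: `vmap(grad(f))` above is the standard per-sample-gradient pattern, and the same composition style applies to `jacrev`, `jacfwd`, and `hessian`.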