[FX] Add wrap() docstring to docs and add decorator example (#50555)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50555

Test Plan: Imported from OSS

Reviewed By: Chillee

Differential Revision: D25917564

Pulled By: jamesr66a

fbshipit-source-id: 20c7c8b1192fa80c6a0bb9e18910791bd7167232
This commit is contained in:
James Reed 2021-01-14 21:30:12 -08:00 committed by Facebook GitHub Bot
parent adc65e7c8d
commit 6882f9cc1c
2 changed files with 17 additions and 4 deletions

View file

@ -13,6 +13,8 @@ API Reference
.. autofunction:: torch.fx.symbolic_trace
.. autofunction:: torch.fx.wrap
.. autoclass:: torch.fx.GraphModule
:members:

View file

@ -391,9 +391,9 @@ def _unpatch_wrapped_functions(orig_fns : List[PatchedFn]):
def wrap(fn_or_name : Union[str, Callable]):
"""
This function can be called at global scope in a module to cause
references to the global function secified by `fn_name` to use
them in FX.
This function can be called at module-level scope to register fn_or_name as a "leaf function".
A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being
traced through::
# foo/bar/baz.py
def my_custom_function(x, y):
@ -406,9 +406,20 @@ def wrap(fn_or_name : Union[str, Callable]):
# the graph rather than tracing it.
return my_custom_function(x, y)
This function can also equivalently be used as a decorator::
# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):
return x * x + y * y
A wrapped function can be thought of a "leaf function", analogous to the concept of
"leaf modules", that is, they are functions that are left as calls in the FX trace
rather than traced through.
Args:
fn_name (Union[str, Callable]): The function or name of the global function to insert into the
fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the
graph when it's called
"""
if callable(fn_or_name):