mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[FX] Add wrap() docstring to docs and add decorator example (#50555)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50555 Test Plan: Imported from OSS Reviewed By: Chillee Differential Revision: D25917564 Pulled By: jamesr66a fbshipit-source-id: 20c7c8b1192fa80c6a0bb9e18910791bd7167232
This commit is contained in:
parent
adc65e7c8d
commit
6882f9cc1c
2 changed files with 17 additions and 4 deletions
|
|
@ -13,6 +13,8 @@ API Reference
|
|||
|
||||
.. autofunction:: torch.fx.symbolic_trace
|
||||
|
||||
.. autofunction:: torch.fx.wrap
|
||||
|
||||
.. autoclass:: torch.fx.GraphModule
|
||||
:members:
|
||||
|
||||
|
|
|
|||
|
|
@ -391,9 +391,9 @@ def _unpatch_wrapped_functions(orig_fns : List[PatchedFn]):
|
|||
|
||||
def wrap(fn_or_name : Union[str, Callable]):
|
||||
"""
|
||||
This function can be called at global scope in a module to cause
|
||||
references to the global function secified by `fn_name` to use
|
||||
them in FX.
|
||||
This function can be called at module-level scope to register fn_or_name as a "leaf function".
|
||||
A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being
|
||||
traced through::
|
||||
|
||||
# foo/bar/baz.py
|
||||
def my_custom_function(x, y):
|
||||
|
|
@ -406,9 +406,20 @@ def wrap(fn_or_name : Union[str, Callable]):
|
|||
# the graph rather than tracing it.
|
||||
return my_custom_function(x, y)
|
||||
|
||||
This function can also equivalently be used as a decorator::
|
||||
|
||||
# foo/bar/baz.py
|
||||
@torch.fx.wrap
|
||||
def my_custom_function(x, y):
|
||||
return x * x + y * y
|
||||
|
||||
A wrapped function can be thought of a "leaf function", analogous to the concept of
|
||||
"leaf modules", that is, they are functions that are left as calls in the FX trace
|
||||
rather than traced through.
|
||||
|
||||
Args:
|
||||
|
||||
fn_name (Union[str, Callable]): The function or name of the global function to insert into the
|
||||
fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the
|
||||
graph when it's called
|
||||
"""
|
||||
if callable(fn_or_name):
|
||||
|
|
|
|||
Loading…
Reference in a new issue