Revert D26246231: [FX] Edits after comprehensive pass over docs

Test Plan: revert-hammer

Differential Revision:
D26246231 (c22bc4821d)

Original commit changeset: 8d6278a9fe1d

fbshipit-source-id: fdc83289f8fe7986bc02181eec55e4e72be2d812
This commit is contained in:
Alban Desmaison 2021-02-04 09:22:27 -08:00 committed by Facebook GitHub Bot
parent 4d85e30133
commit 6c80fd005f
2 changed files with 147 additions and 423 deletions


@ -17,387 +17,173 @@ What is an FX transform? Essentially, it's a function that looks like this.
::
import torch
import torch.fx

def transform(m: torch.nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # Step 1: Acquire a Graph representing the code in `m`

    # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
    # fx.Tracer.trace and constructing a GraphModule. We'll
    # split that out in our transform to allow the caller to
    # customize tracing behavior.
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: Modify this Graph or create a new one
    graph = ...

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)
Your transform will take in an :class:`torch.nn.Module`, acquire a :class:`Graph`
from it, do some modifications, and return a new
:class:`torch.nn.Module`. You should think of the :class:`torch.nn.Module` that your FX
transform returns as identical to a regular :class:`torch.nn.Module` -- you can pass it to another
FX transform, you can pass it to TorchScript, or you can
run it. Ensuring that the inputs and outputs of your FX transform are a
:class:`torch.nn.Module` will allow for composability.
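Because each pass both consumes and produces an ``nn.Module``, passes compose by ordinary
function application. A small sketch of this (``transform_a``, ``transform_b``, and ``MyModel``
below are hypothetical stand-ins, not APIs defined in this document):

::

# `transform_a` and `transform_b` are FX passes with the signature shown
# above; `MyModel` is any traceable nn.Module.
model = MyModel()
model = transform_a(model)
model = transform_b(model)
# The result is still a regular nn.Module, so it can be scripted or re-traced.
scripted = torch.jit.script(model)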
.. note::
It is also possible to modify an existing :class:`GraphModule` instead of
creating a new one, like so::
import torch
import torch.fx
def transform(m : torch.nn.Module) -> torch.nn.Module:
    gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

    # Modify gm.graph
    # <...>

    # Recompile the forward() method of `gm` from its Graph
    gm.recompile()

    return gm
Note that you MUST call :meth:`GraphModule.recompile` to bring the generated
``forward()`` method on the ``GraphModule`` in sync with the modified :class:`Graph`.
Given that you've passed in a :class:`torch.nn.Module` that has been traced into a
:class:`Graph`, there are now two primary approaches you can take to building a new
:class:`Graph`.
A Quick Primer on Graphs
^^^^^^^^^^^^^^^^^^^^^^^^
Full treatment of the semantics of graphs can be found in the :class:`Graph`
documentation, but we are going to cover the basics here. A :class:`Graph` is
a data structure that represents a method on a :class:`GraphModule`. The
information that this requires is:
- What are the inputs to the method?
- What are the operations that run inside the method?
- What is the output (i.e. return) value from the method?
All three of these concepts are represented with :class:`Node` instances.
Let's see what we mean by that with a short example:
::
import torch
import torch.fx
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)
m = MyModule()
gm = torch.fx.symbolic_trace(m)
gm.graph.print_tabular()
Here we define a module ``MyModule`` for demonstration purposes, instantiate it,
symbolically trace it, then call the :meth:`Graph.print_tabular` method to print
out a table showing the nodes of this :class:`Graph`:
+---------------+---------------+----------------------------+--------------------+-------------+
| opcode | name | target | args | kwargs |
+===============+===============+============================+====================+=============+
| placeholder | x | x | () | {} |
+---------------+---------------+----------------------------+--------------------+-------------+
| get_attr | linear_weight | linear.weight | () | {} |
+---------------+---------------+----------------------------+--------------------+-------------+
| call_function | add_1 | <built-in function add> | (x, linear_weight) | {} |
+---------------+---------------+----------------------------+--------------------+-------------+
| call_module | linear_1 | linear | (add_1,) | {} |
+---------------+---------------+----------------------------+--------------------+-------------+
| call_method | relu_1 | relu | (linear_1,) | {} |
+---------------+---------------+----------------------------+--------------------+-------------+
| call_function | sum_1 | <built-in method sum ...> | (relu_1,) | {'dim': -1} |
+---------------+---------------+----------------------------+--------------------+-------------+
| call_function | topk_1 | <built-in method topk ...> | (sum_1, 3) | {} |
+---------------+---------------+----------------------------+--------------------+-------------+
| output | output | output | (topk_1,) | {} |
+---------------+---------------+----------------------------+--------------------+-------------+
We can use this information to answer the questions we posed above.
- What are the inputs to the method? In FX, method inputs are specified
via special ``placeholder`` nodes. In this case, we have a single
``placeholder`` node with a ``target`` of ``x``, meaning we have
a single (non-self) argument named x.
- What are the operations within the method? The ``get_attr``,
``call_function``, ``call_module``, and ``call_method`` nodes
represent the operations in the method. A full treatment of
the semantics of all of these can be found in the :class:`Node`
documentation.
- What is the return value of the method? The return value in a
:class:`Graph` is specified by a special ``output`` node.
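The same three questions can also be answered programmatically by inspecting the
:class:`Node`\s directly. A short sketch, reusing the ``gm`` we traced above:

::

for node in gm.graph.nodes:
    # Each Node carries its opcode, name, target, and inputs, mirroring the
    # columns of the table printed by `print_tabular()` above.
    print(node.op, node.name, node.target, node.args, node.kwargs)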
Given that we now know the basics of how code is represented in
FX, we can now explore how we would edit a :class:`Graph`.
Graph Manipulation
^^^^^^^^^^^^^^^^^^
Direct Graph Manipulation
~~~~~~~~~~~~~~~~~~~~~~~~~
One approach to building this new :class:`Graph` is to directly manipulate your old
one. To aid in this, we can simply take the :class:`Graph` we obtain from symbolic
tracing and modify it. For example, let's say we want to replace
:func:`torch.add` calls with :func:`torch.mul` calls.
::
import torch
import torch.fx

# Sample module
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    graph : torch.fx.Graph = tracer_class().trace(m)
    # FX represents its Graph as an ordered list of
    # nodes, so we can iterate through them.
    for node in graph.nodes:
        # Checks if we're calling a function (i.e:
        # torch.add)
        if node.op == 'call_function':
            # The target attribute is the function
            # that call_function calls.
            if node.target == torch.add:
                node.target = torch.mul

    # Does some checks to make sure the Graph is well-formed.
    graph.lint()

    return torch.fx.GraphModule(m, graph)
We can also do more involved :class:`Graph` rewrites, such as
deleting or appending nodes. To aid in these transformations,
FX has utility functions for transforming the graph that can
be found in the :class:`Graph` documentation. An
example of using these APIs to append a :func:`torch.relu` call
can be found below.
::
# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
    # Insert a new `call_function` node calling `torch.relu`
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))

    # We want all places that used the value of `node` to
    # now use that value after the `relu` call we've added.
    # We use the `replace_all_uses_with` API to do this.
    node.replace_all_uses_with(new_node)
This approach is also a good fit for graph optimizations such as
`conv/batch norm fusion <https://github.com/pytorch/pytorch/blob/ec86cec20a8a2312a2295d7bc8be6e88256a2de4/torch/fx/experimental/fuser.py>`__!
For simple transformations that only consist of substitutions, you can also
make use of the `subgraph rewriter <https://github.com/pytorch/pytorch/blob/master/torch/fx/subgraph_rewriter.py>`__,
described next.
Subgraph Rewriting With replace_pattern()
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
FX also provides another level of automation on top of direct graph manipulation.
The :func:`replace_pattern` API is essentially a "find/replace" tool for editing
:class:`Graph`\s. It allows you to specify a ``pattern`` and ``replacement`` function
and it will trace through those functions, find instances of the group of operations
in the ``pattern`` graph, and replace those instances with copies of the ``replacement``
graph. This can help to greatly automate tedious graph manipulation code, which can
get unwieldy as the transformations get more complex.
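As a rough sketch of the API (the module ``M`` below is purely illustrative):

::

import torch
from torch.fx import symbolic_trace, replace_pattern

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y).relu()

# Every occurrence of the `pattern` graph in the traced module is replaced
# with a copy of the `replacement` graph.
def pattern(a, b):
    return torch.add(a, b)

def replacement(a, b):
    return torch.mul(a, b)

traced = symbolic_trace(M())
replace_pattern(traced, pattern, replacement)
# `traced` now computes torch.mul(x, y).relu()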
In general, writing your transformation through graph manipulation is a good
fit if you need to make a few small changes or if you need to match multiple
nodes at once. However, if you need to entirely rewrite your graph, you may
want to look at constructing your graph with Proxies (i.e. retracing).
Graph Manipulation Examples
~~~~~~~~~~~~~~~~~~~~~~~~~~~
- `Replace one op <https://github.com/pytorch/examples/blob/master/fx/replace_op.py>`__
- `Conv/Batch Norm fusion <https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py>`__
- `replace_pattern: Basic usage <https://github.com/pytorch/examples/blob/master/fx/subgraph_rewriter_basic_use.py>`__
- `Quantization <https://pytorch.org/docs/master/quantization.html#prototype-fx-graph-mode-quantization>`__
- `Invert Transformation <https://github.com/pytorch/examples/blob/master/fx/invert.py>`__
Proxy/Retracing
^^^^^^^^^^^^^^^
Another way of manipulating :class:`Graph`\s is by reusing the :class:`Proxy`
machinery used in symbolic tracing. For example, let's
imagine that we wanted to write a transformation that decomposed
PyTorch functions into smaller operations. It would transform every
``F.relu(x)`` call into ``(x > 0) * x``. One possibility would be to
perform the requisite graph rewriting to insert the comparison and
multiplication after the ``F.relu``, and then clean up the original
``F.relu``. However, we can automate this process by using :class:`Proxy`
objects to automatically record operations into the :class:`Graph`.

To use this method, we write the operations that we want inserted as regular
PyTorch code and invoke that code with :class:`Proxy` objects as arguments.
These :class:`Proxy` objects will capture the operations that are performed
on them and append them to the :class:`Graph`.
::
import torch
import torch.nn.functional as F
from torch import fx

# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
    return (x > 0) * x

decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition

def decompose(model: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    """
    Decompose `model` into smaller constituent operations.
    Currently, this only supports decomposing ReLU into its
    mathematical definition: (x > 0) * x
    """
    graph : fx.Graph = tracer_class().trace(model)
    new_graph = fx.Graph()
    env = {}
    for node in graph.nodes:
        if node.op == 'call_function' and node.target in decomposition_rules:
            # By wrapping the arguments with proxies,
            # we can dispatch to the appropriate
            # decomposition rule and implicitly add it
            # to the Graph by symbolically tracing it.
            proxy_args = [
                fx.Proxy(env[x.name]) if isinstance(x, fx.Node) else x for x in node.args]
            output_proxy = decomposition_rules[node.target](*proxy_args)

            # Operations on `Proxy` always yield new `Proxy`s, and the
            # return value of our decomposition rule is no exception.
            # We need to extract the underlying `Node` from the `Proxy`
            # to use it in subsequent iterations of this transform.
            new_node = output_proxy.node
            env[node.name] = new_node
        else:
            # Default case: we don't have a decomposition rule for this
            # node, so just copy the node over into the new graph.
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
    return fx.GraphModule(model, new_graph)
In addition to avoiding explicit graph manipulation, using :class:`Proxy`\s
also allows you to specify your rewrite rules as native Python code.
For transformations that require a large amount of rewrite rules
(such as vmap or grad), this can often improve readability and
maintainability of the rules.
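For instance, supporting another operation is just a matter of registering one more
Python function. The sigmoid rule below is an illustrative addition of ours, not part
of the original example:

::

# Decompose sigmoid into primitive ops; registered exactly like the ReLU
# rule above and picked up by the same `decompose` transform.
def sigmoid_decomposition(x):
    return 1 / (1 + torch.exp(-x))

decomposition_rules[torch.sigmoid] = sigmoid_decomposition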
A worked example of using :class:`Proxy`\s for :class:`Graph` manipulation
can be found
`here <https://github.com/pytorch/examples/blob/master/fx/proxy_based_graph_creation.py>`__.
The Interpreter Pattern
^^^^^^^^^^^^^^^^^^^^^^^
A useful code organizational pattern in FX is to loop over all the :class:`Node`\s
in a :class:`Graph` and execute them. This can be used for several things including
runtime analysis of values flowing through the graph or transformation of the code
via retracing with :class:`Proxy`\s. For example, suppose we want to run a
:class:`GraphModule` and record the :class:`torch.Tensor` shape and dtype
properties on the nodes as we see them at runtime. That might look like:
::
import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict

class ShapeProp:
    """
    Shape propagation. This class takes a `GraphModule`.
    Then, its `propagate` method executes the `GraphModule`
    node-by-node with the given arguments. As each operation
    executes, the ShapeProp class stores away the shape and
    element type for the output values of each operation on
    the `shape` and `dtype` attributes of the operation's
    `Node`.
    """
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env : Dict[str, Node] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            # This is the only code specific to shape propagation.
            # You can delete this `if` branch and this becomes
            # a generic GraphModule interpreter.
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype

            env[node.name] = result

        return load_arg(self.graph.result)
As you can see, a full interpreter for FX is not that complicated
but it can be very useful. To ease using this pattern, we provide
the :class:`Interpreter` class, which encompasses the above logic
in a way that certain aspects of the interpreter's execution can
be overridden via method overrides.
In addition to executing operations, we can also generate a new
`Graph` by feeding :class:`Proxy` values through an interpreter.
Similarly, we provide the :class:`Transformer` class to encompass
this pattern. :class:`Transformer` behaves similarly to
:class:`Interpreter`, but instead of calling the ``run`` method to
get a concrete output value from the Module, you would call the
:meth:`Transformer.transform` method to return a new
:class:`GraphModule` which was subject to any transformation rules
you installed as overridden methods.
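As a rough sketch of the :class:`Transformer` flavor of this pattern (the sigmoid-to-neg
swap below is a toy rule chosen only for illustration):

::

import torch
import torch.fx

class SigmoidToNeg(torch.fx.Transformer):
    # Rewrite every `call_function` node targeting torch.sigmoid to call
    # torch.neg instead; everything else is copied through unchanged.
    def call_function(self, target, args, kwargs):
        if target == torch.sigmoid:
            return super().call_function(torch.neg, args, kwargs)
        return super().call_function(target, args, kwargs)

def fn(x):
    return torch.sigmoid(x) + 1.0

gm = torch.fx.symbolic_trace(fn)
transformed = SigmoidToNeg(gm).transform()
print(transformed.code)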
Examples of the Interpreter Pattern
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- `Shape Propagation <https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/shape_prop.py>`__
- `Performance Profiler <https://github.com/pytorch/tutorials/pull/1319>`__
- `Roofline Analyzer <https://github.com/pytorch/pytorch/blob/a9f88511b8155ba9620730fb175dee8c54e346d5/torch/fx/experimental/cost_model.py>`__
Debugging
@ -406,72 +192,21 @@ Debugging
Introduction
^^^^^^^^^^^^^^^^
Often in the course of authoring transformations, our code will not be quite right.
In this case, we may need to do some debugging. The key is to work
backwards: first, check the results of invoking the generated module to prove or
disprove correctness. Then, inspect and debug the generated code. Then, debug the
process of transformations that led to the generated code.
If you're not familiar with debuggers, please see the auxiliary section
:ref:`Available Debuggers`.
Checking Correctness of Modules
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Because the output of most deep learning modules consists of floating
point :class:`torch.Tensor` instances, checking for equivalence between
the results of two :class:`torch.nn.Module` is not as straightforward
as doing a simple equality check. To motivate this, let's use an
example:
::
import torch
import torch.fx
import torchvision.models as models

def transform(m : torch.nn.Module) -> torch.nn.Module:
    gm = torch.fx.symbolic_trace(m)

    # Imagine we're doing some transforms here
    # <...>

    gm.recompile()

    return gm

resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)

input_image = torch.randn(5, 3, 224, 224)

assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""
Here, we've tried to check equality of the values of two deep learning
models with the ``==`` equality operator. However, this is not well-
defined, both because that operator returns a tensor rather than a bool,
and because comparison of floating point values should use a margin of
error (or epsilon) to account for the non-associativity of floating
point operations (see
`here <https://floating-point-gui.de/errors/comparison/>`__ for more
details). We can use :func:`torch.allclose` instead, which will give
us an approximate comparison taking into account a relative and
absolute tolerance threshold:
::
assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))
This is the first tool in our toolbox to check if transformed modules are
behaving as we expect compared to a reference implementation.
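If the default thresholds are too strict or too loose for a given transform,
:func:`torch.allclose` also accepts explicit tolerances. The values below are simply
its defaults, shown for illustration:

::

assert torch.allclose(resnet18(input_image),
                      transformed_resnet18(input_image),
                      rtol=1e-05, atol=1e-08)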
Debugging the Generated Code
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Because FX generates the ``forward()`` function on :class:`GraphModule`\s, using
traditional debugging techniques like ``print`` statements or ``pdb`` is
not as straightforward. Luckily, we have several techniques we can use
for debugging the generated code.
@ -479,34 +214,22 @@ for debugging the generated code.
Use ``pdb``
~~~~~~~~~~~~~
Invoke ``pdb`` to step into the running program. Although the code that
represents the :class:`Graph` is not in any source file, we can still step
into it manually using ``pdb`` when the forward pass is invoked.
::
import torch
import torch.fx
import torchvision.models as models

def my_pass(inp: torch.nn.Module,
            tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    graph = tracer_class().trace(inp)
    # Transformation logic here
    # <...>

    # Return new Module
    return torch.fx.GraphModule(inp, graph)

my_module = models.resnet18()
my_module_transformed = my_pass(my_module)

input_value = torch.randn(5, 3, 224, 224)

# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()

my_module_transformed(input_value)
.. _Print the Generated Code:
@ -522,8 +245,6 @@ your code and examine it from there.
# Assume that `traced` is a GraphModule that has undergone some
# number of transforms
# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):
@ -531,6 +252,8 @@ your code and examine it from there.
add_1 = x + y; x = y = None
return add_1
"""
# Subclass the original Module
class SubclassM(M):
@ -538,7 +261,7 @@ your code and examine it from there.
super().__init__()
# Paste the generated `forward` function (the one we printed and
# copied above) here
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
@ -574,7 +297,7 @@ Debugging the Transformation
Now that we've identified that a transformation is creating incorrect
code, it's time to debug the transformation itself. First, we'll check
the :ref:`Limitations of Symbolic Tracing` section in the documentation.
Once we verify that tracing is working as expected, the goal
becomes figuring out what went wrong during our ``GraphModule``
transformation. There may be a quick answer in
:ref:`Writing Transformations`, but, if not, there are several ways to
@ -596,27 +319,24 @@ examine our traced module:
# sake of brevity.
traced = symbolic_trace(m)

# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
    add_1 = x + y; x = y = None
    return add_1
"""

# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph(x, y):
    %add_1 : [#users=1] = call_function[target=<built-in function add>](args = (%x, %y), kwargs = {})
    return add_1
"""

# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode name target args kwargs
------------- ------ ----------------------- -------- --------
@ -624,9 +344,10 @@ examine our traced module:
placeholder y y () {}
call_function add_1 <built-in function add> (x, y) {}
"""
Using the utility functions above, we can compare our traced Module
before and after we've applied our transformations. Sometimes, a
simple visual comparison is enough to trace down a bug. If it's still
not clear what's going wrong, a debugger like ``pdb`` can be a good
next step.
@ -636,22 +357,26 @@ Going off of the example above, consider the following code:
::
# Sample user-defined function
def transform_graph(module: torch.nn.Module,
                    tracer_class : type = fx.Tracer) -> torch.nn.Module:
    # Get the Graph from our traced Module
    g = tracer_class().trace(module)

    """
    Transformations on `g` go here
    """

    return fx.GraphModule(module, g)

# Transform the Graph
transformed = transform_graph(traced)

# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)
Using the above example, let's say that the call to ``print(transformed)``
showed us that there was an error in our transforms. We want to find
@ -704,10 +429,10 @@ Limitations of Symbolic Tracing
FX uses a system of **symbolic tracing** (a.k.a `symbolic
execution <https://en.wikipedia.org/wiki/Symbolic_execution>`__)
to capture the semantics of programs in a transformable/analyzable form.
The system is **tracing** in that it executes the program (really a
:class:`torch.nn.Module` or function) to record operations. It is
**symbolic** in that the data flowing through the program during this
execution is not real data, but rather symbols (:class:`Proxy` in FX parlance).
Although symbolic tracing works for most neural net code, it has some
limitations.
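For example, a function whose control flow depends on input *values* cannot be traced,
because a :class:`Proxy` carries no concrete value to branch on. A minimal sketch of the
failure mode (the exact exception type may vary between versions):

::

import torch
import torch.fx

def f(x):
    # The branch condition depends on the runtime value of `x`, which a
    # Proxy cannot provide during symbolic tracing.
    if x.sum() > 0:
        return torch.relu(x)
    return torch.neg(x)

try:
    torch.fx.symbolic_trace(f)
except Exception as e:
    print(e)  # traced variables cannot be used as inputs to control flow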
@ -817,10 +542,10 @@ symbolic tracing:
fx.symbolic_trace(f) # Fails!
def wrapper(flag):
    return lambda x: f(x, flag)

new_f = wrapper(flag=True)
fx.symbolic_trace(new_f)
In the case of truly dynamic control flow, the sections of the program
@ -841,8 +566,6 @@ them in symbolic tracing. For example:
::
import torch
import torch.fx
from math import sqrt
def normalize(x):
@ -1013,3 +736,4 @@ API Reference
:members:
.. autofunction:: torch.fx.replace_pattern


@ -48,24 +48,28 @@ demonstration of these components in action:
return clamp_1
"""
The **symbolic tracer** performs symbolic execution of the Python
code. It feeds fake values, called Proxies, through the code. Operations
on these Proxies are recorded. More information about symbolic tracing
can be found in the :func:`symbolic_trace` and :class:`Tracer`
documentation.
The **intermediate representation** is the container for the operations
that were recorded during symbolic tracing. It consists of a list of
Nodes that represent function inputs, callsites (to functions, methods,
or :class:`torch.nn.Module` instances), and return values. More information
about the IR can be found in the documentation for :class:`Graph`. The
IR is the format on which transformations are applied.
**Python code generation** is what makes FX a Python-to-Python (or
Module-to-Module) transformation toolkit. For each Graph IR, we can
create valid Python code matching the Graph's semantics. This
functionality is wrapped up in :class:`GraphModule`, which is a
:class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a
``forward`` method generated from the Graph.
Taken together, this pipeline of components (symbolic tracing
@ -76,10 +80,6 @@ symbolic tracing can be used in isolation to capture a form of
the code for analysis (and not transformation) purposes. Code
generation can be used for programmatically generating models, for
example from a config file. There are many uses for FX!
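As a small illustrative sketch of that last point (not part of the pipeline demo above),
a Graph can be assembled by hand and turned into a runnable Module through code
generation:

::

import torch
import torch.fx

graph = torch.fx.Graph()
x = graph.placeholder('x')
relu = graph.call_function(torch.relu, (x,))
graph.output(relu)

# GraphModule generates a forward() method from the Graph's semantics.
gm = torch.fx.GraphModule(torch.nn.Module(), graph)
print(gm.code)
print(gm(torch.randn(3)))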
Several example transformations can be found at the
`examples <https://github.com/pytorch/examples/tree/master/fx>`__
repository.
'''
from .graph_module import GraphModule