Revert D26246231: [FX] Edits after comprehensive pass over docs
Test Plan: revert-hammer
Differential Revision: D26246231 (c22bc4821d)
Original commit changeset: 8d6278a9fe1d
fbshipit-source-id: fdc83289f8fe7986bc02181eec55e4e72be2d812
This commit is contained in: parent 4d85e30133, commit 6c80fd005f
2 changed files with 147 additions and 423 deletions

@@ -17,387 +17,173 @@ What is an FX transform? Essentially, it's a function that looks like this.

::

    import torch
    import torch.fx

    def transform(m: torch.nn.Module,
                  tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
        # Step 1: Acquire a Graph representing the code in `m`

        # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
        # fx.Tracer.trace and constructing a GraphModule. We'll
        # split that out in our transform to allow the caller to
        # customize tracing behavior.
        graph : torch.fx.Graph = tracer_class().trace(m)

        # Step 2: Modify this Graph or create a new one
        graph = ...

        # Step 3: Construct a Module to return
        return torch.fx.GraphModule(m, graph)

Your transform will take in a :class:`torch.nn.Module`, acquire a :class:`Graph`
from it, do some modifications, and return a new :class:`torch.nn.Module`. You
should think of the :class:`torch.nn.Module` that your FX transform returns as
identical to a regular :class:`torch.nn.Module` -- you can pass it to another
FX transform, you can pass it to TorchScript, or you can run it. Ensuring that
the inputs and outputs of your FX transform are a :class:`torch.nn.Module`
will allow for composability.

.. note::

    It is also possible to modify an existing :class:`GraphModule` instead of
    creating a new one, like so::

        import torch
        import torch.fx

        def transform(m : torch.nn.Module) -> torch.nn.Module:
            gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

            # Modify gm.graph
            # <...>

            # Recompile the forward() method of `gm` from its Graph
            gm.recompile()

            return gm

    Note that you MUST call :meth:`GraphModule.recompile` to bring the generated
    ``forward()`` method on the ``GraphModule`` in sync with the modified :class:`Graph`.

Given that you’ve passed in a :class:`torch.nn.Module` that has been traced into a
:class:`Graph`, there are now two primary approaches you can take to building a new
:class:`Graph`.

A Quick Primer on Graphs
^^^^^^^^^^^^^^^^^^^^^^^^

Full treatment of the semantics of graphs can be found in the :class:`Graph`
documentation, but we are going to cover the basics here. A :class:`Graph` is
a data structure that represents a method on a :class:`GraphModule`. The
information that this requires is:

- What are the inputs to the method?
- What are the operations that run inside the method?
- What is the output (i.e. return) value from the method?

All three of these concepts are represented with :class:`Node` instances.
Let's see what we mean by that with a short example:

::

    import torch
    import torch.fx

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return torch.topk(torch.sum(
                self.linear(x + self.linear.weight).relu(), dim=-1), 3)

    m = MyModule()
    gm = torch.fx.symbolic_trace(m)

    gm.graph.print_tabular()

Here we define a module ``MyModule`` for demonstration purposes, instantiate it,
symbolically trace it, then call the :meth:`Graph.print_tabular` method to print
out a table showing the nodes of this :class:`Graph`:

+---------------+---------------+----------------------------+--------------------+-------------+
| opcode        | name          | target                     | args               | kwargs      |
+===============+===============+============================+====================+=============+
| placeholder   | x             | x                          | ()                 | {}          |
+---------------+---------------+----------------------------+--------------------+-------------+
| get_attr      | linear_weight | linear.weight              | ()                 | {}          |
+---------------+---------------+----------------------------+--------------------+-------------+
| call_function | add_1         | <built-in function add>    | (x, linear_weight) | {}          |
+---------------+---------------+----------------------------+--------------------+-------------+
| call_module   | linear_1      | linear                     | (add_1,)           | {}          |
+---------------+---------------+----------------------------+--------------------+-------------+
| call_method   | relu_1        | relu                       | (linear_1,)        | {}          |
+---------------+---------------+----------------------------+--------------------+-------------+
| call_function | sum_1         | <built-in method sum ...>  | (relu_1,)          | {'dim': -1} |
+---------------+---------------+----------------------------+--------------------+-------------+
| call_function | topk_1        | <built-in method topk ...> | (sum_1, 3)         | {}          |
+---------------+---------------+----------------------------+--------------------+-------------+
| output        | output        | output                     | (topk_1,)          | {}          |
+---------------+---------------+----------------------------+--------------------+-------------+

We can use this information to answer the questions we posed above.

- What are the inputs to the method? In FX, method inputs are specified
  via special ``placeholder`` nodes. In this case, we have a single
  ``placeholder`` node with a ``target`` of ``x``, meaning we have
  a single (non-self) argument named x.
- What are the operations within the method? The ``get_attr``,
  ``call_function``, ``call_module``, and ``call_method`` nodes
  represent the operations in the method. A full treatment of
  the semantics of all of these can be found in the :class:`Node`
  documentation.
- What is the return value of the method? The return value in a
  :class:`Graph` is specified by a special ``output`` node.

Given that we now know the basics of how code is represented in
FX, we can now explore how we would edit a :class:`Graph`.

Graph Manipulation
^^^^^^^^^^^^^^^^^^

Direct Graph Manipulation
~~~~~~~~~~~~~~~~~~~~~~~~~

One approach to building this new :class:`Graph` is to directly manipulate your old
one. To aid in this, we can simply take the :class:`Graph` we obtain from symbolic
tracing and modify it. For example, let’s say we desire to replace
:func:`torch.add` calls with :func:`torch.mul` calls.

::

    import torch
    import torch.fx

    # Sample module
    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.add(x, y)

    def transform(m: torch.nn.Module,
                  tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
        graph : torch.fx.Graph = tracer_class().trace(m)
        # FX represents its Graph as an ordered list of
        # nodes, so we can iterate through them.
        for node in graph.nodes:
            # Checks if we're calling a function (i.e:
            # torch.add)
            if node.op == 'call_function':
                # The target attribute is the function
                # that call_function calls.
                if node.target == torch.add:
                    node.target = torch.mul

        graph.lint() # Does some checks to make sure the
                     # Graph is well-formed.

        return torch.fx.GraphModule(m, graph)
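
As a quick sanity check, here is a short usage sketch (it assumes the ``M`` and
``transform`` definitions above; the check itself is illustrative, not part of
the original example)::

    m = M()
    transformed = transform(m)

    x, y = torch.randn(3), torch.randn(3)
    # `torch.add` has been rewritten to `torch.mul`, so the transformed
    # module now multiplies its inputs
    assert torch.allclose(transformed(x, y), x * y)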

We can also do more involved :class:`Graph` rewrites, such as
deleting or appending nodes. To aid in these transformations,
FX has utility functions for transforming the graph that can
be found in the :class:`Graph` documentation. An
example of using these APIs to append a :func:`torch.relu` call
can be found below.

::

    # Specifies the insertion point. Any nodes added to the
    # Graph within this scope will be inserted after `node`
    with traced.graph.inserting_after(node):
        # Insert a new `call_function` node calling `torch.relu`
        new_node = traced.graph.call_function(
            torch.relu, args=(node,))

        # We want all places that used the value of `node` to
        # now use that value after the `relu` call we've added.
        # We use the `replace_all_uses_with` API to do this.
        node.replace_all_uses_with(new_node)

This approach is also a good fit for graph optimizations such as
`conv/batch norm fusion <https://github.com/pytorch/pytorch/blob/ec86cec20a8a2312a2295d7bc8be6e88256a2de4/torch/fx/experimental/fuser.py>`__.

For simple transformations that only consist of substitutions, you can also
make use of the `subgraph rewriter <https://github.com/pytorch/pytorch/blob/master/torch/fx/subgraph_rewriter.py>`__.

Subgraph Rewriting With replace_pattern()
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

FX also provides another level of automation on top of direct graph manipulation.
The :func:`replace_pattern` API is essentially a "find/replace" tool for editing
:class:`Graph`\s. It allows you to specify a ``pattern`` and ``replacement`` function
and it will trace through those functions, find instances of the group of operations
in the ``pattern`` graph, and replace those instances with copies of the ``replacement``
graph. This can help to greatly automate tedious graph manipulation code, which can
get unwieldy as the transformations get more complex.
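
As an illustration, here is a minimal, hedged sketch of :func:`replace_pattern`
applied to the ``torch.add``-to-``torch.mul`` substitution from earlier (the toy
module ``M`` and the specific pattern are illustrative)::

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.add(x, y)

    # `pattern` and `replacement` are themselves symbolically traced, so they
    # are written as ordinary functions over the values they match.
    def pattern(x, y):
        return torch.add(x, y)

    def replacement(x, y):
        return torch.mul(x, y)

    traced = torch.fx.symbolic_trace(M())
    # Finds every occurrence of `pattern` in `traced` and swaps in `replacement`
    torch.fx.replace_pattern(traced, pattern, replacement)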

In general, writing your transformation through graph manipulation is a good
fit if you need to make a few small changes or if you need to match multiple
nodes at once. However, if you need to entirely rewrite your graph, you may
want to look at constructing your graph with Proxies (i.e. retracing).

Graph Manipulation Examples
~~~~~~~~~~~~~~~~~~~~~~~~~~~

- `Replace one
  op <https://github.com/pytorch/examples/blob/master/fx/replace_op.py>`__
- `Conv/Batch Norm
  fusion <https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py>`__
- `replace_pattern: Basic usage <https://github.com/pytorch/examples/blob/master/fx/subgraph_rewriter_basic_use.py>`__
- `Quantization <https://pytorch.org/docs/master/quantization.html#prototype-fx-graph-mode-quantization>`__
- `Invert Transformation <https://github.com/pytorch/examples/blob/master/fx/invert.py>`__

Proxy/Retracing
^^^^^^^^^^^^^^^

Another way of manipulating :class:`Graph`\s is by reusing the :class:`Proxy`
machinery used in symbolic tracing. For example, let’s
imagine that we wanted to write a transformation that decomposed
PyTorch functions into smaller operations. It would transform every
``F.relu(x)`` call into ``(x > 0) * x``. One possibility would be to
perform the requisite graph rewriting to insert the comparison and
multiplication after the ``F.relu``, and then clean up the original
``F.relu``. However, we can automate this process by using :class:`Proxy`
objects to automatically record operations into the :class:`Graph`.

To use this method, we write the operations that we want inserted as regular
PyTorch code and invoke that code with :class:`Proxy` objects as arguments.
These :class:`Proxy` objects will capture the operations that are performed
on them and append them to the :class:`Graph`.

::

    import torch
    import torch.fx as fx
    import torch.nn.functional as F

    # Note that this decomposition rule can be read as regular Python
    def relu_decomposition(x):
        return (x > 0) * x

    decomposition_rules = {}
    decomposition_rules[F.relu] = relu_decomposition

    def decompose(model: torch.nn.Module,
                  tracer_class : type = fx.Tracer) -> torch.nn.Module:
        """
        Decompose `model` into smaller constituent operations.
        Currently, this only supports decomposing ReLU into its
        mathematical definition: (x > 0) * x
        """
        graph : fx.Graph = tracer_class().trace(model)
        new_graph = fx.Graph()
        env = {}
        for node in graph.nodes:
            if node.op == 'call_function' and node.target in decomposition_rules:
                # By wrapping the arguments with proxies,
                # we can dispatch to the appropriate
                # decomposition rule and implicitly add it
                # to the Graph by symbolically tracing it.
                proxy_args = [
                    fx.Proxy(env[x.name]) if isinstance(x, fx.Node) else x for x in node.args]
                output_proxy = decomposition_rules[node.target](*proxy_args)

                # Operations on `Proxy` always yield new `Proxy`s, and the
                # return value of our decomposition rule is no exception.
                # We need to extract the underlying `Node` from the `Proxy`
                # to use it in subsequent iterations of this transform.
                new_node = output_proxy.node
                env[node.name] = new_node
            else:
                # Default case: we don't have a decomposition rule for this
                # node, so just copy the node over into the new graph.
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
                env[node.name] = new_node
        return fx.GraphModule(model, new_graph)
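
A short usage sketch under the same assumptions (the module ``M`` below is a
made-up example, not part of the original text)::

    class M(torch.nn.Module):
        def forward(self, x):
            return F.relu(x) + 1.0

    decomposed = decompose(M())
    # The printed code no longer calls F.relu; it contains the decomposed
    # (x > 0) * x operations instead
    print(decomposed.code)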

In addition to avoiding explicit graph manipulation, using :class:`Proxy`\s
also allows you to specify your rewrite rules as native Python code.
For transformations that require a large amount of rewrite rules
(such as vmap or grad), this can often improve readability and
maintainability of the rules.

A worked example of using :class:`Proxy`\s for :class:`Graph` manipulation
can be found
`here <https://github.com/pytorch/examples/blob/master/fx/proxy_based_graph_creation.py>`__.

The Interpreter Pattern
^^^^^^^^^^^^^^^^^^^^^^^

A useful code organizational pattern in FX is to loop over all the :class:`Node`\s
in a :class:`Graph` and execute them. This can be used for several things including
runtime analysis of values flowing through the graph or transformation of the code
via retracing with :class:`Proxy`\s. For example, suppose we want to run a
:class:`GraphModule` and record the :class:`torch.Tensor` shape and dtype
properties on the nodes as we see them at runtime. That might look like:

::

    import torch
    import torch.fx
    from torch.fx.node import Node

    from typing import Dict

    class ShapeProp:
        """
        Shape propagation. This class takes a `GraphModule`.
        Then, its `propagate` method executes the `GraphModule`
        node-by-node with the given arguments. As each operation
        executes, the ShapeProp class stores away the shape and
        element type for the output values of each operation on
        the `shape` and `dtype` attributes of the operation's
        `Node`.
        """
        def __init__(self, mod):
            self.mod = mod
            self.graph = mod.graph
            self.modules = dict(self.mod.named_modules())

        def propagate(self, *args):
            args_iter = iter(args)
            env : Dict[str, Node] = {}

            def load_arg(a):
                return torch.fx.graph.map_arg(a, lambda n: env[n.name])

            def fetch_attr(target : str):
                target_atoms = target.split('.')
                attr_itr = self.mod
                for i, atom in enumerate(target_atoms):
                    if not hasattr(attr_itr, atom):
                        raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
                    attr_itr = getattr(attr_itr, atom)
                return attr_itr

            for node in self.graph.nodes:
                if node.op == 'placeholder':
                    result = next(args_iter)
                elif node.op == 'get_attr':
                    result = fetch_attr(node.target)
                elif node.op == 'call_function':
                    result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
                elif node.op == 'call_method':
                    self_obj, *args = load_arg(node.args)
                    kwargs = load_arg(node.kwargs)
                    result = getattr(self_obj, node.target)(*args, **kwargs)
                elif node.op == 'call_module':
                    result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

                # This is the only code specific to shape propagation.
                # You can delete this `if` branch and this becomes
                # a generic GraphModule interpreter.
                if isinstance(result, torch.Tensor):
                    node.shape = result.shape
                    node.dtype = result.dtype

                env[node.name] = result

            return load_arg(self.graph.result)

As you can see, a full interpreter for FX is not that complicated
but it can be very useful. To ease using this pattern, we provide
the :class:`Interpreter` class, which encompasses the above logic
in a way that certain aspects of the interpreter's execution can
be overridden via method overrides.

In addition to executing operations, we can also generate a new
:class:`Graph` by feeding :class:`Proxy` values through an interpreter.
Similarly, we provide the :class:`Transformer` class to encompass
this pattern. :class:`Transformer` behaves similarly to
:class:`Interpreter`, but instead of calling the ``run`` method to
get a concrete output value from the Module, you would call the
:meth:`Transformer.transform` method to return a new
:class:`GraphModule` which was subject to any transformation rules
you installed as overridden methods.
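
For concreteness, here is a hedged sketch of how :class:`Interpreter` might be
subclassed to override how individual operations execute (the ReLU-to-Sigmoid
swap and the function ``fn`` are purely illustrative, not part of the original
text)::

    import torch
    import torch.fx

    class ReluToSigmoidInterpreter(torch.fx.Interpreter):
        # Override execution of `call_function` nodes: run sigmoid wherever
        # the traced graph recorded a call to torch.relu
        def call_function(self, target, args, kwargs):
            if target == torch.relu:
                return torch.sigmoid(*args, **kwargs)
            return super().call_function(target, args, kwargs)

    def fn(x):
        return torch.relu(x) + 1.0

    gm = torch.fx.symbolic_trace(fn)
    # `run` executes the GraphModule node-by-node with the overridden behavior
    out = ReluToSigmoidInterpreter(gm).run(torch.randn(4))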

Examples of the Interpreter Pattern
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- `Shape
  Propagation <https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/shape_prop.py>`__
- `Performance Profiler <https://github.com/pytorch/tutorials/pull/1319>`__
- `Roofline
  Analyzer <https://github.com/pytorch/pytorch/blob/a9f88511b8155ba9620730fb175dee8c54e346d5/torch/fx/experimental/cost_model.py>`__

Debugging
---------

@@ -406,72 +192,21 @@ Debugging

Introduction
^^^^^^^^^^^^

Often in the course of authoring transformations, our code will not be quite right.
In this case, we may need to do some debugging. The key is to work
backwards: first, check the results of invoking the generated module to prove or
disprove correctness. Then, inspect and debug the generated code. Then, debug the
process of transformations that led to the generated code.

If you’re not familiar with debuggers, please see the auxiliary section
:ref:`Available Debuggers`.
Checking Correctness of Modules
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Because the output of most deep learning modules consists of floating
point :class:`torch.Tensor` instances, checking for equivalence between
the results of two :class:`torch.nn.Module` instances is not as straightforward
as doing a simple equality check. To motivate this, let's use an
example:

::

    import torch
    import torch.fx
    import torchvision.models as models

    def transform(m : torch.nn.Module) -> torch.nn.Module:
        gm = torch.fx.symbolic_trace(m)

        # Imagine we're doing some transforms here
        # <...>

        gm.recompile()

        return gm

    resnet18 = models.resnet18()
    transformed_resnet18 = transform(resnet18)

    input_image = torch.randn(5, 3, 224, 224)

    assert resnet18(input_image) == transformed_resnet18(input_image)
    """
    RuntimeError: Boolean value of Tensor with more than one value is ambiguous
    """

Here, we've tried to check equality of the values of two deep learning
models with the ``==`` equality operator. However, this is not well-defined,
both because that operator returns a tensor and not a bool, and because
comparison of floating point values should use a margin of error (or epsilon)
to account for the non-associativity of floating point operations (see
`here <https://floating-point-gui.de/errors/comparison/>`__ for more
details). We can use :func:`torch.allclose` instead, which will give
us an approximate comparison taking into account a relative and
absolute tolerance threshold:

::

    assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))

This is the first tool in our toolbox to check if transformed modules are
behaving as we expect compared to a reference implementation.

Debugging the Generated Code
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Because FX generates the ``forward()`` function on :class:`GraphModule`\s, using
traditional debugging techniques like ``print`` statements or ``pdb`` is
not as straightforward. Luckily, we have several techniques we can use
for debugging the generated code.

@@ -479,34 +214,22 @@ for debugging the generated code.

Use ``pdb``
~~~~~~~~~~~
Invoke ``pdb`` to step into the running program. Although the code that
represents the :class:`Graph` is not in any source file, we can still step
into it manually using ``pdb`` when the forward pass is invoked.

::

    import torch
    import torch.fx
    import torchvision.models as models

    def my_pass(inp: torch.nn.Module, tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
        graph = tracer_class().trace(inp)
        # Transformation logic here
        # <...>

        # Return new Module
        return torch.fx.GraphModule(inp, graph)

    my_module = models.resnet18()
    my_module_transformed = my_pass(my_module)

    input_value = torch.randn(5, 3, 224, 224)

    # When this line is executed at runtime, we will be dropped into an
    # interactive `pdb` prompt. We can use the `step` or `s` command to
    # step into the execution of the next line
    import pdb; pdb.set_trace()

    my_module_transformed(input_value)

.. _Print the Generated Code:

@@ -522,8 +245,6 @@ your code and examine it from there.

# Assume that `traced` is a GraphModule that has undergone some
# number of transforms

# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):

@@ -531,6 +252,8 @@ your code and examine it from there.

    add_1 = x + y; x = y = None
    return add_1
"""

# Subclass the original Module
class SubclassM(M):

@@ -538,7 +261,7 @@ your code and examine it from there.

        super().__init__()

    # Paste the generated `forward` function (the one we printed and
    # copied above) here
    def forward(self, y):
        x = self.x
        add_1 = x + y; x = y = None

@@ -574,7 +297,7 @@ Debugging the Transformation

Now that we've identified that a transformation is creating incorrect
code, it's time to debug the transformation itself. First, we'll check
the :ref:`Limitations of Symbolic Tracing` section in the documentation.
Once we verify that tracing is working as expected, the goal
becomes figuring out what went wrong during our ``GraphModule``
transformation. There may be a quick answer in
:ref:`Writing Transformations`, but, if not, there are several ways to

@@ -596,27 +319,24 @@ examine our traced module:

    # sake of brevity.
    traced = symbolic_trace(m)

    # Print the code produced by tracing the module.
    print(traced)
    # The generated `forward` function is:
    """
    def forward(self, x, y):
        add_1 = x + y; x = y = None
        return add_1
    """

    # Print the internal Graph.
    print(traced.graph)
    # This print-out returns:
    """
    graph(x, y):
        %add_1 : [#users=1] = call_function[target=<built-in function add>](args = (%x, %y), kwargs = {})
        return add_1
    """

    # Print a tabular representation of the internal Graph.
    traced.graph.print_tabular()
    # This gives us:
    """
    opcode        name   target                  args     kwargs
    ------------- ------ ----------------------- -------- --------

@@ -624,9 +344,10 @@ examine our traced module:

    placeholder   y      y                       ()       {}
    call_function add_1  <built-in function add> (x, y)   {}
    """

Using the utility functions above, we can compare our traced Module
before and after we've applied our transformations. Sometimes, a
simple visual comparison is enough to trace down a bug. If it's still
not clear what's going wrong, a debugger like ``pdb`` can be a good
next step.

@@ -636,22 +357,26 @@ Going off of the example above, consider the following code:

::

    # Sample user-defined function
    def transform_graph(module: torch.nn.Module, tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:

        # Get the Graph from our traced Module
        g = tracer_class().trace(module)

        """
        Transformations on `g` go here
        """

        return torch.fx.GraphModule(module, g)

    # Transform the Graph
    transformed = transform_graph(traced)

    # Print the new code after our transforms. Check to see if it was
    # what we expected
    print(transformed)

Using the above example, let’s say that the call to ``print(transformed)``
showed us that there was an error in our transforms. We want to find

@@ -704,10 +429,10 @@ Limitations of Symbolic Tracing

FX uses a system of **symbolic tracing** (a.k.a `symbolic
execution <https://en.wikipedia.org/wiki/Symbolic_execution>`__)
to capture the semantics of programs in a transformable/analyzable form.
The system is **tracing** in that it executes the program (really a
:class:`torch.nn.Module` or function) to record operations. It is
**symbolic** in that the data flowing through the program during this
execution is not real data, but rather symbols (:class:`Proxy` in FX parlance).
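
As a small, hedged illustration of what tracing records (the function ``f``
below is made up for this example)::

    import torch
    import torch.fx

    def f(x):
        # During symbolic tracing, `x` is a Proxy rather than a real tensor;
        # every operation performed on it is recorded as a Node.
        return torch.relu(x) + 1.0

    traced = torch.fx.symbolic_trace(f)
    print(traced.code)   # the Python generated from the recorded Graph
    print(traced.graph)  # the underlying IR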

Although symbolic tracing works for most neural net code, it has some
limitations.

@@ -817,10 +542,10 @@ symbolic tracing:

    fx.symbolic_trace(f) # Fails!

    def wrapper(flag):
        return lambda x: f(x, flag)

    new_f = wrapper(flag=True)
    fx.symbolic_trace(new_f)

In the case of truly dynamic control flow, the sections of the program

@@ -841,8 +566,6 @@ them in symbolic tracing. For example:

::

    import torch
    import torch.fx
    from math import sqrt

    def normalize(x):

@@ -1013,3 +736,4 @@ API Reference

    :members:

.. autofunction:: torch.fx.replace_pattern

@@ -48,24 +48,28 @@ demonstration of these components in action:

    return clamp_1
"""

The **symbolic tracer** performs “symbolic execution” of the Python
code. It feeds fake values, called Proxies, through the code. Operations
on these Proxies are recorded. More information about symbolic tracing
can be found in the :func:`symbolic_trace` and :class:`Tracer`
documentation.

The **intermediate representation** is the container for the operations
that were recorded during symbolic tracing. It consists of a list of
Nodes that represent function inputs, callsites (to functions, methods,
or :class:`torch.nn.Module` instances), and return values. More information
about the IR can be found in the documentation for :class:`Graph`. The
IR is the format on which transformations are applied.

**Python code generation** is what makes FX a Python-to-Python (or
Module-to-Module) transformation toolkit. For each Graph IR, we can
create valid Python code matching the Graph’s semantics. This
functionality is wrapped up in :class:`GraphModule`, which is a
:class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a
``forward`` method generated from the Graph.

Taken together, this pipeline of components (symbolic tracing →

@@ -76,10 +80,6 @@ symbolic tracing can be used in isolation to capture a form of

the code for analysis (and not transformation) purposes. Code
generation can be used for programmatically generating models, for
example from a config file. There are many uses for FX!

Several example transformations can be found at the
`examples <https://github.com/pytorch/examples/tree/master/fx>`__
repository.
'''

from .graph_module import GraphModule