From 6c80fd005f23a55b3e4e655e867e0eed493ee416 Mon Sep 17 00:00:00 2001
From: Alban Desmaison
Date: Thu, 4 Feb 2021 09:22:27 -0800
Subject: [PATCH] Revert D26246231: [FX] Edits after comprehensive pass over docs

Test Plan: revert-hammer

Differential Revision:
D26246231 (https://github.com/pytorch/pytorch/commit/c22bc4821d88d72a56626ad8c827edeac92619d3)

Original commit changeset: 8d6278a9fe1d

fbshipit-source-id: fdc83289f8fe7986bc02181eec55e4e72be2d812
---
 docs/source/fx.rst   | 550 +++++++++++--------------------------------
 torch/fx/__init__.py |  20 +-
 2 files changed, 147 insertions(+), 423 deletions(-)

diff --git a/docs/source/fx.rst b/docs/source/fx.rst
index 5395dc345da..6ebc03b69a0 100644
--- a/docs/source/fx.rst
+++ b/docs/source/fx.rst
@@ -17,387 +17,173 @@ What is an FX transform?
 
 Essentially, it's a function that looks like this. ::
 
-    import torch
-    import torch.fx
+    def transform(m: nn.Module) -> nn.Module:
+        fx_model: GraphModule = fx.symbolic_trace(m)
+        new_model = ...
+        return new_model
 
-    def transform(m: nn.Module,
-                  tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
-        # Step 1: Acquire a Graph representing the code in `m`
-
-        # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
-        # fx.Tracer.trace and constructing a GraphModule. We'll
-        # split that out in our transform to allow the caller to
-        # customize tracing behavior.
-        graph : torch.fx.Graph = tracer_class().trace(m)
-
-        # Step 2: Modify this Graph or create a new one
-        graph = ...
-
-        # Step 3: Construct a Module to return
-        return torch.fx.GraphModule(m, graph)
-
-Your transform will take in a :class:`torch.nn.Module`, acquire a :class:`Graph`
-from it, do some modifications, and return a new
-:class:`torch.nn.Module`. You should think of the :class:`torch.nn.Module` that your FX
-transform returns as identical to a regular :class:`torch.nn.Module` -- you can pass it to another
+Your transform will take in a :class:`torch.nn.Module`, convert it into a
+:class:`GraphModule` with :meth:`symbolic_trace`, and return a new
+``nn.Module``. You should think of the ``nn.Module`` that your FX transform
+returns as identical to a regular ``nn.Module`` -- you can pass it to another
 FX transform, you can pass it to TorchScript, or you can run it.
 Ensuring that the inputs and outputs of your FX transform are a
-:class:`torch.nn.Module` will allow for composability.
+``nn.Module`` will allow for composability.
 
-.. note::
-
-    It is also possible to modify an existing :class:`GraphModule` instead of
-    creating a new one, like so::
-
-        import torch
-        import torch.fx
-
-        def transform(m : nn.Module) -> nn.Module:
-            gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)
-
-            # Modify gm.graph
-            # <...>
-
-            # Recompile the forward() method of `gm` from its Graph
-            gm.recompile()
-
-            return gm
-
-    Note that you MUST call :meth:`GraphModule.recompile` to bring the generated
-    ``forward()`` method on the ``GraphModule`` in sync with the modified :ref:`Graph`.
-
-Given that you’ve passed in a :class:`torch.nn.Module` that has been traced into a
-:class:`Graph`, there are now two primary approaches you can take to building a new
-:class:`Graph`.
-
-A Quick Primer on Graphs
-^^^^^^^^^^^^^^^^^^^^^^^^
-
-Full treatment of the semantics of graphs can be found in the :class:`Graph`
-documentation, but we are going to cover the basics here. A :class:`Graph` is
-a data structure that represents a method on a :class:`GraphModule`. The
-information that this requires is:
-
-- What are the inputs to the method?
-- What are the operations that run inside the method? -- What is the output (i.e. return) value from the method? - -All three of these concepts are represented with :class:`Node` instances. -Let's see what we mean by that with a short example: - -:: - - import torch - import torch.fx - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - return torch.topk(torch.sum( - self.linear(x + self.linear.weight).relu(), dim=-1), 3) - - m = MyModule() - gm = torch.fx.symbolic_trace(m) - - gm.graph.print_tabular() - -Here we define a module ``MyModule`` for demonstration purposes, instantiate it, -symbolically trace it, then call the :meth:`Graph.print_tabular` method to print -out a table showing the nodes of this :class:`Graph`: - - +---------------+---------------+----------------------------+--------------------+-------------+ - | opcode | name | target | args | kwargs | - +===============+===============+============================+====================+=============+ - | placeholder | x | x | () | {} | - +---------------+---------------+----------------------------+--------------------+-------------+ - | get_attr | linear_weight | linear.weight | () | {} | - +---------------+---------------+----------------------------+--------------------+-------------+ - | call_function | add_1 | | (x, linear_weight) | {} | - +---------------+---------------+----------------------------+--------------------+-------------+ - | call_module | linear_1 | linear | (add_1,) | {} | - +---------------+---------------+----------------------------+--------------------+-------------+ - | call_method | relu_1 | relu | (linear_1,) | {} | - +---------------+---------------+----------------------------+--------------------+-------------+ - | call_function | sum_1 | | (relu_1,) | {'dim': -1} | - +---------------+---------------+----------------------------+--------------------+-------------+ - | call_function | topk_1 | | (sum_1, 3) | {} | - +---------------+---------------+----------------------------+--------------------+-------------+ - | output | output | output | (topk_1,) | {} | - +---------------+---------------+----------------------------+--------------------+-------------+ - -We can use this information to answer the questions we posed above. - -- What are the inputs to the method? In FX, method inputs are specified - via special ``placeholder`` nodes. In this case, we have a single - ``placeholder`` node with a ``target`` of ``x``, meaning we have - a single (non-self) argument named x. -- What are the operations within the method? The ``get_attr``, - ``call_function``, ``call_module``, and ``call_method`` nodes - represent the operations in the method. A full treatment of - the semantics of all of these can be found in the :class:`Node` - documentation. -- What is the return value of the method? The return value in a - :class:`Graph` is specified by a special ``output`` node. - -Given that we now know the basics of how code is represented in -FX, we can now explore how we would edit a :class:`Graph`. +Given that you’ve passed in an ``nn.Module`` that has been traced into a +graph, there are now two primary approaches you can take to building a new +graph. Graph Manipulation ^^^^^^^^^^^^^^^^^^ -Direct Graph Manipulation -~~~~~~~~~~~~~~~~~~~~~~~~~ - -One approach to building this new :class:`Graph` is to directly manipulate your old -one. 
To aid in this, we can simply take the :class:`Graph` we obtain from symbolic
+One approach to building this new graph is to directly manipulate your old
+one. To aid in this, we can simply take the graph we obtain from symbolic
 tracing and modify it. For example, let’s say we desire to replace
-:func:`torch.add` calls with :func:`torch.mul` calls.
+``torch.add`` with ``torch.mul``.
 
 ::
 
-    import torch
-    import torch.fx
-
     # Sample module
     class M(torch.nn.Module):
         def forward(self, x, y):
             return torch.add(x, y)
 
-    def transform(m: torch.nn.Module,
-                  tracer_class : type = fx.Tracer) -> torch.nn.Module:
-        graph : fx.Graph = tracer_class().trace(m)
-        # FX represents its Graph as an ordered list of
-        # nodes, so we can iterate through them.
-        for node in graph.nodes:
-            # Checks if we're calling a function (i.e.,
-            # torch.add)
+    def transform(m: nn.Module) -> nn.Module:
+        fx_model: GraphModule = fx.symbolic_trace(m)
+        # FX represents its graph as an ordered list of nodes, so we can
+        # iterate through them.
+        for node in fx_model.graph.nodes:
+            # Checks if we're calling a function (i.e., torch.add)
            if node.op == 'call_function':
-                # The target attribute is the function
-                # that call_function calls.
+                # The target attribute is the function that call_function
+                # calls.
                if node.target == torch.add:
                    node.target = torch.mul
 
-        graph.lint() # Does some checks to make sure the
-                     # Graph is well-formed.
+        fx_model.graph.lint() # Does some checks to make sure the graph is well-formed.
+        # Regenerates the python code that corresponds to fx_model.
+        fx_model.recompile()
+        return fx_model
 
-        return fx.GraphModule(m, graph)
-
-
-We can also do more involved :class:`Graph` rewrites, such as
-deleting or appending nodes. To aid in these transformations,
-FX has utility functions for transforming the graph that can
-be found in the :class:`Graph` documentation. An
-example of using these APIs to append a :func:`torch.relu` call
-can be found below.
+We can also do more involved graph rewrites, such as deleting or appending
+nodes. To aid in these transformations, FX has utility
+functions for transforming the graph that can be found in :class:`Graph`. An
+example of using these APIs to append a relu can be found below.
 
 ::
 
-    # Specifies the insertion point. Any nodes added to the
-    # Graph within this scope will be inserted after `node`
-    with traced.graph.inserting_after(node):
-        # Insert a new `call_function` node calling `torch.relu`
-        new_node = traced.graph.call_function(
-            torch.relu, args=(node,))
-
-        # We want all places that used the value of `node` to
-        # now use that value after the `relu` call we've added.
-        # We use the `replace_all_uses_with` API to do this.
+    with traced.graph.inserting_after(node): # Specifies the insertion point
+        new_node = traced.graph.call_function(torch.relu, args=(node,)) # builds a new relu node
        node.replace_all_uses_with(new_node)
 
+This approach is also a good fit for graph optimizations such as
+`conv/batch norm
+fusion! `__
+
 For simple transformations that only consist of substitutions, you can also
 make use of the `subgraph rewriter. `__
 
-Subgraph Rewriting With replace_pattern()
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-FX also provides another level of automation on top of direct graph manipulation.
-The :func:`replace_pattern` API is essentially a "find/replace" tool for editing
-:class:`Graph`\s. It allows you to specify a ``pattern`` and ``replacement`` function
-and it will trace through those functions, find instances of the group of operations
-in the ``pattern`` graph, and replace those instances with copies of the ``replacement``
-graph. This can greatly automate tedious graph manipulation code, which can
-get unwieldy as the transformations get more complex.
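To make that find/replace behavior concrete, here is a minimal sketch of
:func:`replace_pattern` usage (the module ``M`` below is hypothetical, and the
pattern/replacement pair is chosen purely for illustration):

::

    import torch
    import torch.fx

    # Hypothetical module that computes relu(x + y)
    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.relu(x + y)

    # Every occurrence of the operations traced from `pattern`...
    def pattern(x, y):
        return torch.relu(x + y)

    # ...is replaced with a copy of the operations traced from `replacement`
    def replacement(x, y):
        # clamp with min=0 computes the same values as relu here
        return torch.clamp(x + y, min=0)

    traced = torch.fx.symbolic_trace(M())
    torch.fx.replace_pattern(traced, pattern, replacement)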
+In general, writing your transformation through graph manipulation is a good
+fit if you need to make a few small changes or if you need to match multiple
+nodes at once. However, if you need to entirely rewrite your graph, you may
+want to look at constructing your graph with Proxies (i.e. retracing).
 
 Graph Manipulation Examples
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 - `Replace one
-  op `__
+  op `__
 - `Conv/Batch Norm
   fusion `__
-- `replace_pattern: Basic usage `__
-- `Quantization `__
-- `Invert Transformation `__
+- `Quantization `__
 
 Proxy/Retracing
 ^^^^^^^^^^^^^^^
 
-Another way of manipulating :class:`Graph`\s is by reusing the :class:`Proxy`
-machinery used in symbolic tracing. For example, let’s
-imagine that we wanted to write a transformation that decomposed
-PyTorch functions into smaller operations. It would transform every
-``F.relu(x)`` call into ``(x > 0) * x``. One possibility would be to
-perform the requisite graph rewriting to insert the comparison and
-multiplication after the ``F.relu``, and then clean up the original
-``F.relu``. However, we can automate this process by using :class:`Proxy`
-objects to automatically record operations into the :class:`Graph`.
+Although most transformations can be implemented as graph
+transformations, transformations that involve a lot of graph rewrites
+are often more easily represented through retracing. For example, let’s
+imagine that we wanted to write a pass that decomposed
+PyTorch functions. It would transform every ``F.relu(x)``
+into ``(x > 0)*x``. One possibility would be to perform the requisite
+graph rewriting to insert the comparison and multiplication after the
+``F.relu``, and then clean up the original ``F.relu``. However, graph
+manipulation can be awkward, and it’s often easier to implicitly
+generate the graph by retracing.
 
-To use this method, we write the operations that we want inserted as regular
-PyTorch code and invoke that code with :class:`Proxy` objects as arguments.
-These :class:`Proxy` objects will capture the operations that are performed
-on them and append them to the :class:`Graph`.
+To use this method, we write the graph that we want inserted as regular
+PyTorch code and pass in Proxy objects. These Proxy objects
+will capture the operations that are performed on them and append them to
+the graph.
 
 ::
 
     # Note that this decomposition rule can be read as regular Python
     def relu_decomposition(x):
-        return (x > 0) * x
+        return (x > 0)*x
 
     decomposition_rules = {}
     decomposition_rules[F.relu] = relu_decomposition
 
-    def decompose(model: torch.nn.Module,
-                  tracer_class : type = fx.Tracer) -> torch.nn.Module:
-        """
-        Decompose `model` into smaller constituent operations.
-        Currently, this only supports decomposing ReLU into its
-        mathematical definition: (x > 0) * x
-        """
-        graph : fx.Graph = tracer_class().trace(model)
-        env = {}
+    def decompose(model: torch.nn.Module) -> torch.nn.Module:
+        model = fx.symbolic_trace(model)
         new_graph = fx.Graph()
-        for node in graph.nodes:
+        env = {}
+        for node in model.graph.nodes:
            if node.op == 'call_function' and node.target in decomposition_rules:
-                # By wrapping the arguments with proxies,
-                # we can dispatch to the appropriate
-                # decomposition rule and implicitly add it
-                # to the Graph by symbolically tracing it.
-                proxy_args = [
-                    fx.Proxy(env[x.name]) if isinstance(x, fx.Node) else x for x in node.args]
-                output_proxy = decomposition_rules[node.target](*proxy_args)
-
-                # Operations on `Proxy` always yield new `Proxy`s, and the
-                # return value of our decomposition rule is no exception.
-                # We need to extract the underlying `Node` from the `Proxy`
-                # to use it in subsequent iterations of this transform.
-                new_node = output_proxy.node
+                # By wrapping the arguments with proxies, we can dispatch to
+                # the appropriate decomposition rule and add it to the graph by
+                # symbolically tracing it.
+                proxy_args = [fx.Proxy(env[x.name]) if isinstance(x, fx.Node) else x for x in node.args]
+                new_node = decomposition_rules[node.target](*proxy_args).node
                env[node.name] = new_node
            else:
-                # Default case: we don't have a decomposition rule for this
-                # node, so just copy the node over into the new graph.
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
                env[node.name] = new_node
        return fx.GraphModule(model, new_graph)
 
-In addition to avoiding explicit graph manipulation, using :class:`Proxy`\s
-also allows you to specify your rewrite rules as native Python code.
-For transformations that require a large number of rewrite rules
-(such as vmap or grad), this can often improve readability and
-maintainability of the rules.
+In addition to avoiding explicit graph manipulation, using Proxies also allows you to
+specify your rewrite rules as native Python code. For transformations
+that require a large number of rewrite rules (such as vmap or grad),
+this can often improve readability and maintainability of the rules.
 
-A worked example of using :class:`Proxy`\s for :class:`Graph` manipulation
-can be found
-`here `__.
+TODO: Example transformations (need to be included first)
 
 The Interpreter Pattern
 ^^^^^^^^^^^^^^^^^^^^^^^
 
-A useful code organizational pattern in FX is to loop over all the :class:`Node`\s
-in a :class:`Graph` and execute them. This can be used for several things including
-runtime analysis of values flowing through the graph or transformation of the code
-via retracing with :class:`Proxy`\s. For example, suppose we want to run a
-:class:`GraphModule` and record the :class:`torch.Tensor` shape and dtype
-properties on the nodes as we see them at runtime. That might look like:
+In addition to FX passes that take in a module and return a module,
+there may be other things you wish to do with the FX graph. For example,
+let’s say that you’d like to obtain
+the shape information of tensors in your graph. In this case, instead of
+looping over the FX graph and modifying it, you can write an interpreter
+on top of the FX graph! As the FX IR is quite simple, it’s easy to
+reimplement an interpreter that also captures your desired attributes.
 
-::
+As this pattern is quite useful, we can also use an abstraction of this
+pattern
+-- the `Interpreter
+`__.
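For instance, here is a minimal sketch of an :class:`Interpreter` subclass that
records the runtime shape of every node (``ShapeRecorder`` is our own
illustrative name, not an existing API):

::

    import torch
    import torch.fx

    class ShapeRecorder(torch.fx.Interpreter):
        """Runs a GraphModule node-by-node and records output shapes."""
        def run_node(self, n : torch.fx.Node):
            result = super().run_node(n)
            if isinstance(result, torch.Tensor):
                # Stash the observed shape/dtype on the Node itself
                n.shape = result.shape
                n.dtype = result.dtype
            return result

    # Usage (assuming `traced` is a GraphModule and `x` is an example input):
    # ShapeRecorder(traced).run(x)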
+You can see an example using this for `shape propagation
+`__,
+which reinterprets the FX graph with example inputs while annotating the
+graph with the shapes.
 
-    import torch
-    import torch.fx
-    from torch.fx.node import Node
+Reinterpreting the FX graph is generally most useful when you want
+runtime information that FX typically doesn’t capture (due to being a
+symbolic trace). This can be used for capturing shape information for
+downstream passes, but it can also be used to capture other information
+about execution.
+TODO: Add roofline analysis pass once it gets merged.
 
-    from typing import Dict
-
-    class ShapeProp:
-        """
-        Shape propagation. This class takes a `GraphModule`.
-        Then, its `propagate` method executes the `GraphModule`
-        node-by-node with the given arguments. As each operation
-        executes, the ShapeProp class stores away the shape and
-        element type for the output values of each operation on
-        the `shape` and `dtype` attributes of the operation's
-        `Node`.
-        """
-        def __init__(self, mod):
-            self.mod = mod
-            self.graph = mod.graph
-            self.modules = dict(self.mod.named_modules())
-
-        def propagate(self, *args):
-            args_iter = iter(args)
-            env : Dict[str, Node] = {}
-
-            def load_arg(a):
-                return torch.fx.graph.map_arg(a, lambda n: env[n.name])
-
-            def fetch_attr(target : str):
-                target_atoms = target.split('.')
-                attr_itr = self.mod
-                for i, atom in enumerate(target_atoms):
-                    if not hasattr(attr_itr, atom):
-                        raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
-                    attr_itr = getattr(attr_itr, atom)
-                return attr_itr
-
-            for node in self.graph.nodes:
-                if node.op == 'placeholder':
-                    result = next(args_iter)
-                elif node.op == 'get_attr':
-                    result = fetch_attr(node.target)
-                elif node.op == 'call_function':
-                    result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
-                elif node.op == 'call_method':
-                    self_obj, *args = load_arg(node.args)
-                    kwargs = load_arg(node.kwargs)
-                    result = getattr(self_obj, node.target)(*args, **kwargs)
-                elif node.op == 'call_module':
-                    result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
-
-                # This is the only code specific to shape propagation.
-                # You can delete this `if` branch and this becomes
-                # a generic GraphModule interpreter.
-                if isinstance(result, torch.Tensor):
-                    node.shape = result.shape
-                    node.dtype = result.dtype
-
-                env[node.name] = result
-
-            return load_arg(self.graph.result)
-
-As you can see, a full interpreter for FX is not that complicated
-but it can be very useful. To ease using this pattern, we provide
-the :class:`Interpreter` class, which encompasses the above logic
-in a way that certain aspects of the interpreter's execution can
-be overridden via method overrides.
-
-In addition to executing operations, we can also generate a new
-`Graph` by feeding :class:`Proxy` values through an interpreter.
-Similarly, we provide the :class:`Transformer` class to encompass
-this pattern. :class:`Transformer` behaves similarly to
-:class:`Interpreter`, but instead of calling the ``run`` method to
-get a concrete output value from the Module, you would call the
-:meth:`Transformer.transform` method to return a new
-:class:`GraphModule` which was subject to any transformation rules
-you installed as overridden methods.
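As a rough sketch of that pattern (the pass below, which swaps ``torch.add``
for ``torch.mul``, is our own illustration rather than part of the API):

::

    import torch
    import torch.fx

    class AddToMul(torch.fx.Transformer):
        """Rewrites every call to torch.add into a call to torch.mul."""
        def call_function(self, target, args, kwargs):
            if target == torch.add:
                target = torch.mul
            return super().call_function(target, args, kwargs)

    # Usage (assuming `traced` is an existing GraphModule):
    # transformed : torch.fx.GraphModule = AddToMul(traced).transform()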
-Examples of the Interpreter Pattern
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Examples
+~~~~~~~~
 
 - `Shape
   Propagation `__
-- `Performance Profiler `__
+- `Roofline
+  Analyzer `__
 
 
 Debugging
@@ -406,72 +192,21 @@ Debugging
 
 Introduction
 ^^^^^^^^^^^^^^^^
 
-Often in the course of authoring transformations, our code will not be quite right.
-In this case, we may need to do some debugging. The key is to work
-backwards: first, check the results of invoking the generated module to prove or
-disprove correctness. Then, inspect and debug the generated code. Then, debug the
-process of transformations that led to the generated code.
+After symbolically tracing an ``nn.Module`` and performing some number
+of transformations on the resulting GraphModule, we'll want to verify
+that the proper semantics were preserved after those transforms. If they
+weren't, we may need to do some debugging. The key is to work
+backwards: first, check the results of the generated module, then debug
+the generated code, then debug the process of transformations that led
+to the generated code.
 
 If you’re not familiar with debuggers, please see the auxiliary section
 :ref:`Available Debuggers`.
 
-Checking Correctness of Modules
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Because the output of most deep learning modules consists of floating
-point :class:`torch.Tensor` instances, checking for equivalence between
-the results of two :class:`torch.nn.Module` instances is not as straightforward
-as doing a simple equality check. To motivate this, let's use an
-example:
-
-::
-
-    import torch
-    import torch.fx
-    import torchvision.models as models
-
-    def transform(m : torch.nn.Module) -> torch.nn.Module:
-        gm = torch.fx.symbolic_trace(m)
-
-        # Imagine we're doing some transforms here
-        # <...>
-
-        gm.recompile()
-
-        return gm
-
-    resnet18 = models.resnet18()
-    transformed_resnet18 = transform(resnet18)
-
-    input_image = torch.randn(5, 3, 224, 224)
-
-    assert resnet18(input_image) == transformed_resnet18(input_image)
-    """
-    RuntimeError: Boolean value of Tensor with more than one value is ambiguous
-    """
-
-Here, we've tried to check equality of the values of two deep learning
-models with the ``==`` equality operator. However, this is not well-
-defined, both because that operator returns a tensor and not a bool,
-and because comparison of floating point values should use a margin of
-error (or epsilon) to account for the non-associativity of floating
-point operations (see
-`here `__ for more
-details). We can use :func:`torch.allclose` instead, which will give
-us an approximate comparison taking into account a relative and
-absolute tolerance threshold:
-
-::
-
-    assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))
-
-This is the first tool in our toolbox to check if transformed modules are
-behaving as we expect compared to a reference implementation.
 
 Debugging the Generated Code
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Because FX generates the ``forward()`` function on :class:`GraphModule`\s, using
+Because FX generates the ``forward()`` function on GraphModules, using
 traditional debugging techniques like ``print`` statements or ``pdb`` is
 not as straightforward. Luckily, we have several techniques we can use
 for debugging the generated code.
 
 Use ``pdb``
 ~~~~~~~~~~~~~
 Invoke ``pdb`` to step into the running program. Although the code that
-represents the :class:`Graph` is not in any source file, we can still step
+represents the FX graph is not in any source file, we can still step
 into it manually using ``pdb`` when the forward pass is invoked.
 
 ::
 
-    import torch
-    import torch.fx
-    import torchvision.models as models
-
-    def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
-        graph = tracer_class().trace(inp)
+    def my_pass(inp: torch.nn.Module) -> torch.nn.Module:
+        traced = torch.fx.symbolic_trace(inp)
        # Transformation logic here
-        # <...>
-
-        # Return new Module
-        return fx.GraphModule(inp, graph)
-
-    my_module = models.resnet18()
-    my_module_transformed = my_pass(my_module)
-
-    input_value = torch.randn(5, 3, 224, 224)
+        return traced
 
     # When this line is executed at runtime, we will be dropped into an
     # interactive `pdb` prompt. We can use the `step` or `s` command to
     # step into the execution of the next line
     import pdb; pdb.set_trace()
 
-    my_module_transformed(input_value)
+    my_pass(my_module)
 
 .. _Print the Generated Code:
 
@@ -522,8 +245,6 @@ your code and examine it from there.
 
     # Assume that `traced` is a GraphModule that has undergone some
     # number of transforms
 
-    # Copy this code for later
-    print(traced)
     # Print the code generated from symbolic tracing. This outputs:
     """
     def forward(self, y):
        x = self.x
        add_1 = x + y;  x = y = None
        return add_1
     """
@@ -531,6 +252,8 @@ your code and examine it from there.
+    # Copy this code for later
+    print(traced)
 
     # Subclass the original Module
     class SubclassM(M):
@@ -538,7 +261,7 @@ your code and examine it from there.
            super().__init__()
 
         # Paste the generated `forward` function (the one we printed and
-        # copied above) here
+        # copied on line 22) here
        def forward(self, y):
            x = self.x
            add_1 = x + y;  x = y = None
            return add_1
@@ -574,7 +297,7 @@ Debugging the Transformation
 
 Now that we've identified that a transformation is creating incorrect
 code, it's time to debug the transformation itself. First, we'll check
 the :ref:`Limitations of Symbolic Tracing` section in the documentation.
-Once we verify that tracing is working as expected, the goal
+Once we verify that ``symbolic_trace`` is working as expected, the goal
 becomes figuring out what went wrong during our ``GraphModule``
 transformation. There may be a quick answer in
 :ref:`Writing Transformations`, but, if not, there are several ways to
@@ -596,27 +319,24 @@ examine our traced module:
 
     # sake of brevity.
     traced = symbolic_trace(m)
 
-    # Print the code produced by tracing the module.
-    print(traced)
-    # The generated `forward` function is:
+    # Print the code produced by tracing the module. The generated `forward`
+    # function is:
     """
     def forward(self, x, y):
        add_1 = x + y;  x = y = None
        return add_1
     """
+    print(traced)
 
-    # Print the internal Graph.
-    print(traced.graph)
-    # This print-out returns:
+    # Print the internal Graph. This representation returns:
     """
     graph(x, y):
        %add_1 : [#users=1] = call_function[target=](args = (%x, %y), kwargs = {})
        return add_1
     """
+    print(traced.graph)
 
-    # Print a tabular representation of the internal Graph.
-    traced.graph.print_tabular()
-    # This gives us:
+    # Print a tabular representation of the internal Graph. This gives us:
     """
     opcode         name    target                   args      kwargs
     -------------  ------  -----------------------  --------  --------
     placeholder    x       x                        ()        {}
     placeholder    y       y                        ()        {}
     call_function  add_1                            (x, y)    {}
     """
+    traced.graph.print_tabular()
 
 Using the utility functions above, we can compare our traced Module
 before and after we've applied our transformations. Sometimes, a
 simple visual comparison is enough to trace down a bug. If it's still
 not clear what's going wrong, a debugger like ``pdb`` can be a good
 next step.
 
 Going off of the example above, consider the following code:
 
 ::
 
     # Sample user-defined function
-    def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
+    def transform_graph(gm: GraphModule) -> None:
+
         # Get the Graph from our traced Module
-        g = tracer_class().trace(module)
+        g = gm.graph
 
        """
        Transformations on `g` go here
        """
 
-        return fx.GraphModule(module, g)
+        # Recompile the GraphModule. This must be called after editing
+        # the Graph `g`, otherwise the generated code will still reflect
+        # the old Graph before any transforms
+        gm.recompile()
 
     # Transform the Graph
-    transformed = transform_graph(traced)
+    transform_graph(traced)
 
     # Print the new code after our transforms. Check to see if it was
     # what we expected
-    print(transformed)
+    print(traced)
 
 Using the above example, let’s say that the call to ``print(traced)``
 showed us that there was an error in our transforms. We want to find
@@ -704,10 +429,10 @@ Limitations of Symbolic Tracing
 
 FX uses a system of **symbolic tracing** (a.k.a `symbolic
 execution `__)
 to capture the semantics of programs in a transformable/analyzable form.
-The system is **tracing** in that it executes the program (really a
-:class:`torch.nn.Module` or function) to record operations. It is
+The system is **tracing** in that it executes the program (really an
+``nn.Module`` or function) to gather this information. It is
 **symbolic** in that the data flowing through the program during this
-execution is not real data, but rather symbols (:class:`Proxy` in FX parlance).
+execution is not real data, but rather symbols (“Proxy” in FX parlance).
 
 Although symbolic tracing works for most neural net code, it has some
 limitations.
@@ -817,10 +542,10 @@ symbolic tracing:
 
        fx.symbolic_trace(f) # Fails!
 
-    def wrapper(flag):
+    def g(flag):
        return lambda x: f(x, flag)
 
-    new_f = wrapper(flag=True)
+    new_f = g(flag=True)
     fx.symbolic_trace(new_f)
 
 In the case of truly dynamic control flow, the sections of the program
@@ -841,8 +566,6 @@ them in symbolic tracing. For example:
 
 ::
 
-    import torch
-    import torch.fx
     from math import sqrt
 
     def normalize(x):
@@ -1013,3 +736,4 @@ API Reference
     :members:
 
 .. autofunction:: torch.fx.replace_pattern
+
diff --git a/torch/fx/__init__.py b/torch/fx/__init__.py
index 1bc8037ee3b..a5c614a58df 100644
--- a/torch/fx/__init__.py
+++ b/torch/fx/__init__.py
@@ -48,24 +48,28 @@ demonstration of these components in action:
     return clamp_1
     """
 
-The **symbolic tracer** performs “symbolic execution” of the Python
+The **symbolic tracer** performs “abstract interpretation” of the Python
 code. It feeds fake values, called Proxies, through the code. Operations
 on these Proxies are recorded. More information about symbolic tracing
-can be found in the :func:`symbolic_trace` and :class:`Tracer`
+can be found in the
+`symbolic\_trace `__
+and `Tracer `__
 documentation.
The **intermediate representation** is the container for the operations that were recorded during symbolic tracing. It consists of a list of Nodes that represent function inputs, callsites (to functions, methods, -or :class:`torch.nn.Module` instances), and return values. More information -about the IR can be found in the documentation for :class:`Graph`. The +or ``nn.Module`` instances), and return values. More information about +the IR can be found in the documentation for +`Graph `__. The IR is the format on which transformations are applied. **Python code generation** is what makes FX a Python-to-Python (or Module-to-Module) transformation toolkit. For each Graph IR, we can create valid Python code matching the Graph’s semantics. This -functionality is wrapped up in :class:`GraphModule`, which is a -:class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a +functionality is wrapped up in +`GraphModule `__, +which is an ``nn.Module`` instance that holds a ``Graph`` as well as a ``forward`` method generated from the Graph. Taken together, this pipeline of components (symbolic tracing → @@ -76,10 +80,6 @@ symbolic tracing can be used in isolation to capture a form of the code for analysis (and not transformation) purposes. Code generation can be used for programmatically generating models, for example from a config file. There are many uses for FX! - -Several example transformations can be found at the -`examples `__ -repository. ''' from .graph_module import GraphModule