[Easy] GraphTransformObserver Refactoring (#139292)

`GraphTransformObserver` now uses `torch._inductor.config.trace.log_url_for_graph_xform` as the default log URL. It was only ever instantiated with this value as its `log_url` argument, so `log_url` becomes a keyword-only parameter and call sites can drop it.
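For illustration, a minimal before/after sketch of the call-site change (the `gm` placeholder and pass names are made up for this example, and enabling logging relies on the optional pydot package used by `FxGraphDrawer`):

```python
import torch
from torch.fx.passes.graph_transform_observer import GraphTransformObserver

gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))  # placeholder GraphModule

# Before: call sites forwarded the config value by hand, e.g.
#   with GraphTransformObserver(gm, "my_pass", config.trace.log_url_for_graph_xform):
#       ...

# After: omit log_url and it falls back to
# torch._inductor.config.trace.log_url_for_graph_xform (no logging when that is None).
with GraphTransformObserver(gm, "my_pass"):
    pass  # run a graph transformation on gm here

# An explicit URL still works, but log_url is now keyword-only.
with GraphTransformObserver(gm, "my_pass", log_url="/tmp/xform_logs"):
    pass
```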

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139292
Approved by: https://github.com/shengfukevin, https://github.com/shunting314
Author: eellison
Date: 2024-10-30 11:07:22 -07:00 (committed by PyTorch MergeBot)
Commit: 4db6b740bc (parent: 8fa0bc3358)
10 changed files with 30 additions and 48 deletions


@@ -34,7 +34,9 @@ class TestGraphTransformObserver(TestCase):
         log_url = tempfile.mkdtemp()

-        with GraphTransformObserver(traced, "replace_neg_with_relu", log_url) as ob:
+        with GraphTransformObserver(
+            traced, "replace_neg_with_relu", log_url=log_url
+        ) as ob:
             subgraph_rewriter.replace_pattern(traced, pattern, replacement)

         self.assertTrue("relu" in ob.created_nodes)


@@ -26,7 +26,6 @@ from torch.fx.passes.graph_transform_observer import GraphTransformObserver
 from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
 from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
-from .. import config
 from ..fx_utils import get_fake_args_kwargs
 from ..virtualized import V
@@ -583,9 +582,7 @@ def fuse_ddp_communication(
 ) -> None:
     for i, pa in enumerate(passes):
         with GraphTransformObserver(
-            graph.owning_module,
-            f"fuse_ddp_communication_pass_{i}",
-            config.trace.log_url_for_graph_xform,
+            graph.owning_module, f"fuse_ddp_communication_pass_{i}"
         ):
             if isinstance(pa, str):
                 func = globals()[pa]


@@ -1399,6 +1399,5 @@ def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
         with GraphTransformObserver(
             graph.owning_module,
             f"group_batch_fusion_{i}",
-            config.trace.log_url_for_graph_xform,
         ):
             apply_group_batch_fusion(graph, rule)  # type: ignore[arg-type]


@@ -439,9 +439,7 @@ def joint_graph_passes(graph: torch.fx.GraphModule):
     lazy_init()
     count = 0
     if config.joint_custom_pre_pass is not None:
-        with GraphTransformObserver(
-            graph, "joint_custom_pre_pass", config.trace.log_url_for_graph_xform
-        ):
+        with GraphTransformObserver(graph, "joint_custom_pre_pass"):
             config.joint_custom_pre_pass(graph.graph)
             count += 1
@@ -450,9 +448,7 @@ def joint_graph_passes(graph: torch.fx.GraphModule):
     remove_noop_ops(graph.graph)
     if config.joint_graph_constant_folding:
-        with GraphTransformObserver(
-            graph, "constant_fold_uniform_value", config.trace.log_url_for_graph_xform
-        ):
+        with GraphTransformObserver(graph, "constant_fold_uniform_value"):
             constant_fold_uniform_value(graph)
     if config.pattern_matcher:
@@ -463,9 +459,7 @@ def joint_graph_passes(graph: torch.fx.GraphModule):
         count += replace_random_passes(graph)
     if config.joint_custom_post_pass is not None:
-        with GraphTransformObserver(
-            graph, "joint_custom_post_pass", config.trace.log_url_for_graph_xform
-        ):
+        with GraphTransformObserver(graph, "joint_custom_post_pass"):
             config.joint_custom_post_pass(graph.graph)
             count += 1


@@ -98,9 +98,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     fake_tensor_updater = FakeTensorUpdater(gm.graph)
     if post_grad_custom_pre_pass := config.post_grad_custom_pre_pass:
-        with GraphTransformObserver(
-            gm, "post_grad_custom_pre_pass", config.trace.log_url_for_graph_xform
-        ):
+        with GraphTransformObserver(gm, "post_grad_custom_pre_pass"):
             apply_pass(
                 lambda: post_grad_custom_pre_pass(gm.graph), "post_grad_custom_pre_pass"
             )
@@ -145,9 +143,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     )
     if post_grad_custom_post_pass := config.post_grad_custom_post_pass:
-        with GraphTransformObserver(
-            gm, "post_grad_custom_post_pass", config.trace.log_url_for_graph_xform
-        ):
+        with GraphTransformObserver(gm, "post_grad_custom_post_pass"):
             apply_pass(
                 lambda: post_grad_custom_post_pass(gm.graph),
                 "post_grad_custom_post_pass",


@@ -280,9 +280,7 @@ def pre_grad_passes(
         efficient_conv_bn_eval_pass.apply(gm.graph)  # type: ignore[arg-type]
     if config.pre_grad_custom_pass is not None:
-        with GraphTransformObserver(
-            gm, "pre_grad_custom_pass", config.trace.log_url_for_graph_xform
-        ):
+        with GraphTransformObserver(gm, "pre_grad_custom_pass"):
             config.pre_grad_custom_pass(gm.graph)
     stable_topological_sort(gm.graph)
@@ -324,30 +322,20 @@ def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
         # For linear permute fusion, we need to check input info to identify
         # and perform proper permutation/transpose
         ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
-        with GraphTransformObserver(
-            gm, "linear_permute_fusion", config.trace.log_url_for_graph_xform
-        ):
+        with GraphTransformObserver(gm, "linear_permute_fusion"):
             gm = linear_permute_fusion(gm)
-        with GraphTransformObserver(
-            gm, "permute_linear_fusion", config.trace.log_url_for_graph_xform
-        ):
+        with GraphTransformObserver(gm, "permute_linear_fusion"):
             gm = permute_linear_fusion(gm)
-        with GraphTransformObserver(
-            gm, "permute_matmul_fusion", config.trace.log_url_for_graph_xform
-        ):
+        with GraphTransformObserver(gm, "permute_matmul_fusion"):
             gm = permute_matmul_fusion(gm)

     # make sure the autograd is disabled.
     if torch.is_grad_enabled() or not is_cpu:
         return gm
     if config.freezing:
-        with GraphTransformObserver(
-            gm, "remove_identity", config.trace.log_url_for_graph_xform
-        ):
+        with GraphTransformObserver(gm, "remove_identity"):
             gm = remove_identity(gm)
-        with GraphTransformObserver(
-            gm, "fuse_conv_bn", config.trace.log_url_for_graph_xform
-        ):
+        with GraphTransformObserver(gm, "fuse_conv_bn"):
             gm = fuse_conv_bn(gm)
     return gm


@@ -27,9 +27,7 @@ def replace_random_passes(gm: torch.fx.GraphModule):
         return 0
     count = patterns.apply(gm)
-    with GraphTransformObserver(
-        gm, "fuse_seed_creation_pass", config.trace.log_url_for_graph_xform
-    ):
+    with GraphTransformObserver(gm, "fuse_seed_creation_pass"):
         count += fuse_seed_creation_pass(gm.graph)
     return count


@@ -78,7 +78,6 @@ import torch.fx
 import torch.utils._pytree as pytree
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.utils import counters
-from torch._inductor.config import trace as trace_config
 from torch._prims_common import is_integer_dtype
 from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -1746,9 +1745,7 @@ class PatternMatcherPass:
         if has_call_module:
             nodes.append(graph.find_nodes(op="call_module", sort=False))
         pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher"
-        with GraphTransformObserver(
-            gm, pass_name, trace_config.log_url_for_graph_xform
-        ):
+        with GraphTransformObserver(gm, pass_name):
             for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
                 target = extract_target(node)
                 if node.op == "call_module":


@@ -1638,7 +1638,7 @@ def pass_execution_and_save(func, gm, inp, msg):
         print(f"Before:\n{gm.graph}", file=f)
         print(gm.graph, file=before_io)
         start_time = datetime.now()
-        with GraphTransformObserver(gm, msg, config.trace.log_url_for_graph_xform):
+        with GraphTransformObserver(gm, msg):
             func(gm.graph)
         time_elapsed = datetime.now() - start_time
         # recompile graph


@@ -15,8 +15,19 @@ __all__ = ["GraphTransformObserver"]
 class GraphTransformObserver:
     __pass_count = 0

-    def __init__(self, gm: GraphModule, passname: str, log_url: Optional[str] = None):
+    def __init__(
+        self, gm: GraphModule, passname: str, *, log_url: Optional[str] = None
+    ):
+        """
+        log_url is inferred to be torch._inductor.config.trace.log_url_for_graph_xform unless otherwise specified
+        """
+
         # If log_url is None, we don't log anything
+        if log_url is None:
+            from torch._inductor.config import trace
+
+            log_url = trace.log_url_for_graph_xform
+
         self.log_url = log_url
         if self.log_url is None:
             return
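To make the new default concrete, here is a hedged usage sketch of the fallback above (the directory and pass names are illustrative; with logging enabled the observer draws graphs via `FxGraphDrawer`, which needs the optional pydot package):

```python
import tempfile

import torch
from torch._inductor.config import trace
from torch.fx.passes.graph_transform_observer import GraphTransformObserver

gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))  # placeholder GraphModule

# With the default config (log_url_for_graph_xform is None) and no explicit
# log_url, the observer records nothing.
with GraphTransformObserver(gm, "noop_pass"):
    pass

# Setting the config value now enables logging for every observer that omits
# log_url, instead of each call site forwarding the URL by hand.
trace.log_url_for_graph_xform = tempfile.mkdtemp()
try:
    with GraphTransformObserver(gm, "observed_pass"):
        pass  # node creations/erasures here would be logged under the directory
finally:
    trace.log_url_for_graph_xform = None  # restore the default
```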