mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Easy] GraphTransformObserver Refactoring (#139292)
Uses `torch._inductor.config.trace.log_url_for_graph_xform` by default as the log url. It was only ever instantiated with this as the log_url argument. Pull Request resolved: https://github.com/pytorch/pytorch/pull/139292 Approved by: https://github.com/shengfukevin, https://github.com/shunting314
This commit is contained in:
parent
8fa0bc3358
commit
4db6b740bc
10 changed files with 30 additions and 48 deletions
|
|
@ -34,7 +34,9 @@ class TestGraphTransformObserver(TestCase):
|
|||
|
||||
log_url = tempfile.mkdtemp()
|
||||
|
||||
with GraphTransformObserver(traced, "replace_neg_with_relu", log_url) as ob:
|
||||
with GraphTransformObserver(
|
||||
traced, "replace_neg_with_relu", log_url=log_url
|
||||
) as ob:
|
||||
subgraph_rewriter.replace_pattern(traced, pattern, replacement)
|
||||
|
||||
self.assertTrue("relu" in ob.created_nodes)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ from torch.fx.passes.graph_transform_observer import GraphTransformObserver
|
|||
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
|
||||
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
from .. import config
|
||||
from ..fx_utils import get_fake_args_kwargs
|
||||
from ..virtualized import V
|
||||
|
||||
|
|
@ -583,9 +582,7 @@ def fuse_ddp_communication(
|
|||
) -> None:
|
||||
for i, pa in enumerate(passes):
|
||||
with GraphTransformObserver(
|
||||
graph.owning_module,
|
||||
f"fuse_ddp_communication_pass_{i}",
|
||||
config.trace.log_url_for_graph_xform,
|
||||
graph.owning_module, f"fuse_ddp_communication_pass_{i}"
|
||||
):
|
||||
if isinstance(pa, str):
|
||||
func = globals()[pa]
|
||||
|
|
|
|||
|
|
@ -1399,6 +1399,5 @@ def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
|
|||
with GraphTransformObserver(
|
||||
graph.owning_module,
|
||||
f"group_batch_fusion_{i}",
|
||||
config.trace.log_url_for_graph_xform,
|
||||
):
|
||||
apply_group_batch_fusion(graph, rule) # type: ignore[arg-type]
|
||||
|
|
|
|||
|
|
@ -439,9 +439,7 @@ def joint_graph_passes(graph: torch.fx.GraphModule):
|
|||
lazy_init()
|
||||
count = 0
|
||||
if config.joint_custom_pre_pass is not None:
|
||||
with GraphTransformObserver(
|
||||
graph, "joint_custom_pre_pass", config.trace.log_url_for_graph_xform
|
||||
):
|
||||
with GraphTransformObserver(graph, "joint_custom_pre_pass"):
|
||||
config.joint_custom_pre_pass(graph.graph)
|
||||
count += 1
|
||||
|
||||
|
|
@ -450,9 +448,7 @@ def joint_graph_passes(graph: torch.fx.GraphModule):
|
|||
remove_noop_ops(graph.graph)
|
||||
|
||||
if config.joint_graph_constant_folding:
|
||||
with GraphTransformObserver(
|
||||
graph, "constant_fold_uniform_value", config.trace.log_url_for_graph_xform
|
||||
):
|
||||
with GraphTransformObserver(graph, "constant_fold_uniform_value"):
|
||||
constant_fold_uniform_value(graph)
|
||||
|
||||
if config.pattern_matcher:
|
||||
|
|
@ -463,9 +459,7 @@ def joint_graph_passes(graph: torch.fx.GraphModule):
|
|||
count += replace_random_passes(graph)
|
||||
|
||||
if config.joint_custom_post_pass is not None:
|
||||
with GraphTransformObserver(
|
||||
graph, "joint_custom_post_pass", config.trace.log_url_for_graph_xform
|
||||
):
|
||||
with GraphTransformObserver(graph, "joint_custom_post_pass"):
|
||||
config.joint_custom_post_pass(graph.graph)
|
||||
count += 1
|
||||
|
||||
|
|
|
|||
|
|
@ -98,9 +98,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
|
|||
fake_tensor_updater = FakeTensorUpdater(gm.graph)
|
||||
|
||||
if post_grad_custom_pre_pass := config.post_grad_custom_pre_pass:
|
||||
with GraphTransformObserver(
|
||||
gm, "post_grad_custom_pre_pass", config.trace.log_url_for_graph_xform
|
||||
):
|
||||
with GraphTransformObserver(gm, "post_grad_custom_pre_pass"):
|
||||
apply_pass(
|
||||
lambda: post_grad_custom_pre_pass(gm.graph), "post_grad_custom_pre_pass"
|
||||
)
|
||||
|
|
@ -145,9 +143,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
|
|||
)
|
||||
|
||||
if post_grad_custom_post_pass := config.post_grad_custom_post_pass:
|
||||
with GraphTransformObserver(
|
||||
gm, "post_grad_custom_post_pass", config.trace.log_url_for_graph_xform
|
||||
):
|
||||
with GraphTransformObserver(gm, "post_grad_custom_post_pass"):
|
||||
apply_pass(
|
||||
lambda: post_grad_custom_post_pass(gm.graph),
|
||||
"post_grad_custom_post_pass",
|
||||
|
|
|
|||
|
|
@ -280,9 +280,7 @@ def pre_grad_passes(
|
|||
efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type]
|
||||
|
||||
if config.pre_grad_custom_pass is not None:
|
||||
with GraphTransformObserver(
|
||||
gm, "pre_grad_custom_pass", config.trace.log_url_for_graph_xform
|
||||
):
|
||||
with GraphTransformObserver(gm, "pre_grad_custom_pass"):
|
||||
config.pre_grad_custom_pass(gm.graph)
|
||||
stable_topological_sort(gm.graph)
|
||||
|
||||
|
|
@ -324,30 +322,20 @@ def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
|
|||
# For linear permute fusion, we need to check input info to identify
|
||||
# and perform proper permutation/transpose
|
||||
ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
|
||||
with GraphTransformObserver(
|
||||
gm, "linear_permute_fusion", config.trace.log_url_for_graph_xform
|
||||
):
|
||||
with GraphTransformObserver(gm, "linear_permute_fusion"):
|
||||
gm = linear_permute_fusion(gm)
|
||||
with GraphTransformObserver(
|
||||
gm, "permute_linear_fusion", config.trace.log_url_for_graph_xform
|
||||
):
|
||||
with GraphTransformObserver(gm, "permute_linear_fusion"):
|
||||
gm = permute_linear_fusion(gm)
|
||||
with GraphTransformObserver(
|
||||
gm, "permute_matmul_fusion", config.trace.log_url_for_graph_xform
|
||||
):
|
||||
with GraphTransformObserver(gm, "permute_matmul_fusion"):
|
||||
gm = permute_matmul_fusion(gm)
|
||||
|
||||
# make sure the autograd is disabled.
|
||||
if torch.is_grad_enabled() or not is_cpu:
|
||||
return gm
|
||||
if config.freezing:
|
||||
with GraphTransformObserver(
|
||||
gm, "remove_identity", config.trace.log_url_for_graph_xform
|
||||
):
|
||||
with GraphTransformObserver(gm, "remove_identity"):
|
||||
gm = remove_identity(gm)
|
||||
with GraphTransformObserver(
|
||||
gm, "fuse_conv_bn", config.trace.log_url_for_graph_xform
|
||||
):
|
||||
with GraphTransformObserver(gm, "fuse_conv_bn"):
|
||||
gm = fuse_conv_bn(gm)
|
||||
return gm
|
||||
|
||||
|
|
|
|||
|
|
@ -27,9 +27,7 @@ def replace_random_passes(gm: torch.fx.GraphModule):
|
|||
return 0
|
||||
|
||||
count = patterns.apply(gm)
|
||||
with GraphTransformObserver(
|
||||
gm, "fuse_seed_creation_pass", config.trace.log_url_for_graph_xform
|
||||
):
|
||||
with GraphTransformObserver(gm, "fuse_seed_creation_pass"):
|
||||
count += fuse_seed_creation_pass(gm.graph)
|
||||
|
||||
return count
|
||||
|
|
|
|||
|
|
@ -78,7 +78,6 @@ import torch.fx
|
|||
import torch.utils._pytree as pytree
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor.config import trace as trace_config
|
||||
from torch._prims_common import is_integer_dtype
|
||||
from torch._subclasses.fake_tensor import unset_fake_temporarily
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
|
@ -1746,9 +1745,7 @@ class PatternMatcherPass:
|
|||
if has_call_module:
|
||||
nodes.append(graph.find_nodes(op="call_module", sort=False))
|
||||
pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher"
|
||||
with GraphTransformObserver(
|
||||
gm, pass_name, trace_config.log_url_for_graph_xform
|
||||
):
|
||||
with GraphTransformObserver(gm, pass_name):
|
||||
for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
|
||||
target = extract_target(node)
|
||||
if node.op == "call_module":
|
||||
|
|
|
|||
|
|
@ -1638,7 +1638,7 @@ def pass_execution_and_save(func, gm, inp, msg):
|
|||
print(f"Before:\n{gm.graph}", file=f)
|
||||
print(gm.graph, file=before_io)
|
||||
start_time = datetime.now()
|
||||
with GraphTransformObserver(gm, msg, config.trace.log_url_for_graph_xform):
|
||||
with GraphTransformObserver(gm, msg):
|
||||
func(gm.graph)
|
||||
time_elapsed = datetime.now() - start_time
|
||||
# recompile graph
|
||||
|
|
|
|||
|
|
@ -15,8 +15,19 @@ __all__ = ["GraphTransformObserver"]
|
|||
class GraphTransformObserver:
|
||||
__pass_count = 0
|
||||
|
||||
def __init__(self, gm: GraphModule, passname: str, log_url: Optional[str] = None):
|
||||
def __init__(
|
||||
self, gm: GraphModule, passname: str, *, log_url: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
log_url is inferred to be torch._inductor.config.trace.log_url_for_graph_xform unless otherwise specified
|
||||
"""
|
||||
|
||||
# If log_url is None, we don't log anything
|
||||
if log_url is None:
|
||||
from torch._inductor.config import trace
|
||||
|
||||
log_url = trace.log_url_for_graph_xform
|
||||
|
||||
self.log_url = log_url
|
||||
if self.log_url is None:
|
||||
return
|
||||
|
|
|
|||
Loading…
Reference in a new issue