[PT2][Optimus][Observability] Log the optimus graph transformation to the scuba (#119745)
Summary: The current everstore upload logging can cause excessive compilation time when the model has many graph breaks (post: https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/), so we now log the transformation only when the graph has actually changed.

Test Plan: timeout flows: f528209775 f530084719

Differential Revision: D53692344

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119745
Approved by: https://github.com/jackiexu1992
This commit is contained in:
parent
006eead7d2
commit
7b1f5c874f
10 changed files with 93 additions and 55 deletions
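
The change described in the summary reduces to one pattern, applied around each Optimus pass in the pre-grad and post-grad hooks: snapshot the inductor counters, run the pass, and upload the transformed graph only when the counters actually moved. The sketch below is a minimal, self-contained illustration of that pattern; run_logged_pass, upload_graph, scuba_log, and the two toy passes are illustrative stand-ins, not the real helpers touched by this diff.

    import copy
    from collections import Counter, defaultdict
    from typing import Any, Callable, Dict, List

    # Stand-ins for torch._dynamo.utils.counters and optimus_scuba_log.
    counters: Dict[str, Counter] = defaultdict(Counter)
    scuba_log: Dict[str, Any] = {}

    def upload_graph(graph: List[str]) -> str:
        # Stand-in for the internal everstore upload; returns a handle to the dump.
        return f"<uploaded graph with {len(graph)} nodes>"

    def run_logged_pass(key: str, graph: List[str], run_pass: Callable) -> None:
        # Snapshot the counters, run the pass, and log only if something changed.
        before = copy.deepcopy(counters["inductor"])
        run_pass(graph)
        if counters["inductor"] != before:
            scuba_log[key] = upload_graph(graph)

    def noop_pass(graph):
        pass  # no counter bump, so nothing gets uploaded

    def fusion_pass(graph):
        counters["inductor"]["batch_fusion"] += 1  # a pass that fired bumps its counter

    run_logged_pass("noop_pre_grad", ["add", "cat"], noop_pass)
    run_logged_pass("group_batch_fusion_pre_grad", ["add", "cat"], fusion_pass)
    assert list(scuba_log) == ["group_batch_fusion_pre_grad"]

This is why a model with many graph breaks no longer pays an upload for every recompile: graphs that no pass actually changed are never sent to everstore.
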
@@ -6,7 +6,7 @@ import unittest
 import torch
 import torch._inductor
 from torch._dynamo.test_case import run_tests, TestCase
-from torch._dynamo.utils import counters
+from torch._dynamo.utils import counters, optimus_scuba_log
 from torch.testing._internal.inductor_utils import HAS_CUDA

 try:
@@ -285,6 +285,7 @@ class TestGroupBatchFusion(TestCase):
             counters["inductor"]["batch_fusion"],
             0,
         )
+        self.assertNotIn("group_batch_fusion_pre_grad", optimus_scuba_log)
         ref.sum().backward()
         res.sum().backward()
         self.compare_parameters(module, traced)
@@ -297,6 +298,7 @@ class TestGroupBatchFusion(TestCase):
             counters["inductor"]["batch_fusion"],
             3,
         )
+        self.assertIn("group_batch_fusion_post_grad", optimus_scuba_log)
         counters.clear()

     @unittest.skipIf(not has_fbgemm, "requires fbgemm")
@@ -468,6 +470,8 @@ class TestPostGradBatchLinearFusion(TestCase):
             counters["inductor"]["batch_fusion"],
             2,
         )
+        self.assertNotIn("group_batch_fusion_pre_grad", optimus_scuba_log)
+        self.assertIn("group_batch_fusion_post_grad", optimus_scuba_log)


 class TestFindIndependentSubsetGreedy(TestCase):

@@ -2,7 +2,7 @@

 import torch
 from torch._dynamo.test_case import run_tests, TestCase
-from torch._dynamo.utils import counters
+from torch._dynamo.utils import counters, optimus_scuba_log
 from torch._inductor.fx_passes.misc_patterns import numpy_compat_normalization
 from torch.testing._internal.common_utils import IS_LINUX
 from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -90,6 +90,10 @@ class TestSplitCatFxPasses(TestCase):
             counters["inductor"]["split_cat_norm"],
             expected_split_norm_count,
         )
+        if expected_split_norm_count > 0:
+            self.assertIn(
+                "split_cat_pattern_normalization_pass_pre_grad", optimus_scuba_log
+            )
         counters.clear()

     @patch
@@ -251,6 +255,10 @@ class TestSplitCatFxPasses(TestCase):
             counters["inductor"]["consecutive_split_merged"],
             expected_split_merged,
         )
+        if expected_split_merged > 0:
+            self.assertIn(
+                "split_cat_pattern_merge_splits_pass_pre_grad", optimus_scuba_log
+            )
         counters.clear()

     @patch
@@ -1063,6 +1071,9 @@ class TestSplitCatFxPasses(TestCase):
             counters["inductor"]["stack_tahn_unbind_merged"],
             expected_stack_tahn_unbind_merged,
         )
+        self.assertIn(
+            "split_cat_pattern_merge_getitem_cat_pass_pre_grad", optimus_scuba_log
+        )
         counters.clear()

     def test_numpy_compat_normalization(self):

@@ -101,6 +101,7 @@ from torch.utils._pytree import tree_map_only


 counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter)
+optimus_scuba_log: Dict[str, Any] = {}
 troubleshooting_url = "https://pytorch.org/docs/master/compile/troubleshooting.html"
 nnmodule_doc_url = "https://pytorch.org/docs/master/compile/nn-module.html"
 nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations."
@@ -1154,10 +1155,7 @@ def dict_keys_repr(const_keys, *, local) -> str:
 GLOBAL_KEY_PREFIX = "__dict_key"


-from torch._subclasses import (  # noqa: F401
-    FakeTensorMode,
-    UnsupportedFakeTensorException,
-)
+from torch._subclasses import UnsupportedFakeTensorException  # noqa: F401


 def wrap_fake_exception(fn):

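
The optimus_scuba_log added in the hunk above is a plain module-level dict in torch._dynamo.utils, so Inductor passes and the tests reach it with an ordinary import. Assuming a build that contains this commit, the tests shown earlier consume it roughly like this (a usage note, not new functionality):

    from torch._dynamo.utils import counters, optimus_scuba_log

    # After compiling a model that triggers post-grad batch fusion, the tests
    # expect the uploaded graph to be keyed by the pass that changed it:
    #     self.assertIn("group_batch_fusion_post_grad", optimus_scuba_log)
    # and expect no entry for a stage whose counters did not move:
    #     self.assertNotIn("group_batch_fusion_pre_grad", optimus_scuba_log)
    print(type(optimus_scuba_log))  # a plain dict, empty until some pass fires
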
@@ -33,13 +33,19 @@ from torch._dynamo import (
     logging as dynamo_logging,
     utils as dynamo_utils,
 )
-from torch._dynamo.utils import counters, detect_fake_mode, lazy_format_graph_code
+from torch._dynamo.utils import (
+    counters,
+    detect_fake_mode,
+    lazy_format_graph_code,
+    optimus_scuba_log,
+)
 from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
 from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache

 from torch._inductor.debug import save_args_for_compile_fx_inner
 from torch._ops import OpOverload
 from torch._subclasses.fake_tensor import FakeTensor
+from torch._utils_internal import signpost_event
 from torch.fx.passes.fake_tensor_prop import FakeTensorProp

 from .._dynamo.backends.common import aot_autograd
@@ -621,9 +627,11 @@ def fx_codegen_and_compile(
             post_grad_passes(gm, is_inference=is_inference)
             V.debug.fx_graph_transformed(gm, example_inputs)
             post_grad_graphs_log.debug("%s", lazy_format_graph_code("AFTER POST GRAD", gm))
-            log.debug(
-                "counters of inductor dict after apply passes on the input FX graph in the post grad pass: %s",
-                counters["inductor"],
-            )
+            optimus_scuba_log["inductor_post_grad"] = counters["inductor"]
+            signpost_event(
+                "optimus",
+                "compile_fx.post_grad_passes",
+                optimus_scuba_log,
+            )

     with V.set_fake_mode(fake_mode):
@@ -1159,9 +1167,11 @@ def compile_fx(
         )

     model_ = pre_grad_passes(model_, example_inputs_)
-    log.debug(
-        "counters of inductor dict after apply passes on the input FX graph in the pre grad pass: %s",
-        counters["inductor"],
-    )
+    optimus_scuba_log["inductor_pre_grad"] = counters["inductor"]
+    signpost_event(
+        "optimus",
+        "compile_fx.pre_grad_passes",
+        optimus_scuba_log,
+    )

     if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):

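
For context on the two compile_fx hunks above: instead of dumping the raw counter dict at debug level, fx_codegen_and_compile and compile_fx now record the inductor counter dict in optimus_scuba_log under "inductor_post_grad" / "inductor_pre_grad" and emit the accumulated dict once per stage through signpost_event (the internal build routes this to Scuba; the open-source stub effectively just logs it). A rough sketch of what the payload can look like by the time the post-grad signpost fires; the concrete values below are invented for illustration only.

    from collections import Counter

    # Hypothetical contents of optimus_scuba_log at the post-grad signpost;
    # graph-upload keys exist only for passes that actually changed the graph.
    example_log = {
        "inductor_pre_grad": Counter({"batch_fusion": 3, "split_cat_norm": 2}),
        "group_batch_fusion_pre_grad": "<handle of the uploaded pre-grad graph>",
        "split_cat_pattern_normalization_pass_pre_grad": "<handle>",
        "inductor_post_grad": Counter({"batch_fusion": 5}),
        "group_batch_fusion_post_grad": "<handle of the uploaded post-grad graph>",
    }

    # The diff then hands the whole dict to the structured-logging hook:
    #     signpost_event("optimus", "compile_fx.post_grad_passes", optimus_scuba_log)
    for key, value in example_log.items():
        print(f"{key}: {value}")
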
@@ -17,7 +17,6 @@ from typing import (

 import torch
 from torch._dynamo.utils import counters
-from torch._utils_internal import print_graph

 from .. import config
 from ..pattern_matcher import (
@@ -936,7 +935,6 @@ def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True):


 def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
-    print_graph(graph, "Before group_batch fusion in pre grad pass.")
     fusions: List[GroupBatchFusionBase] = []
     # we keep all current pre grad fusions to keep
     # current implementation, will remove this later
@@ -965,4 +963,3 @@ def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):

     for rule in fusions:
         apply_group_batch_fusion(graph, rule)  # type: ignore[arg-type]
-        print_graph(graph, f"Apply fusion {rule.__class__.__name__}.")

@@ -8,7 +8,6 @@ import numpy

 import torch
 import torch.optim as optim
-from torch._utils_internal import print_graph

 from .. import config

@@ -43,8 +42,8 @@ def clean_memory() -> None:
 def compare_dict_tensors(dict_base, dict_control, precision):
     if len(set(dict_base.keys())) != len(set(dict_control.keys())):
         logger.warning("Mismatch keys found before and after pre/post grad fx passes.")
-        print_graph(dict_base.keys(), "keys before pre/post grad fx passes.")
-        print_graph(dict_control.keys(), "keys after pre/post grad fx passes.")
+        logger.debug("keys before pre/post grad fx passes %s", dict_base.keys())
+        logger.debug("keys after pre/post grad fx passes %s", dict_control.keys())
         return False
     is_allclose = True
     for key in dict_base.keys():
@@ -66,8 +65,8 @@ def compare_dict_tensors(dict_base, dict_control, precision):
             logger.warning(
                 "Mismatch parameter values found before and after pre/post grad fx passes."
             )
-            print_graph(dict_base[key], "value before pre/post grad fx passes.")
-            print_graph(dict_control[key], "value after pre/post grad fx passes.")
+            logger.debug("value before pre/post grad fx passes %s", dict_base[key])
+            logger.debug("value after pre/post grad fx passes %s", dict_control[key])
             is_allclose = False
     return is_allclose

@@ -92,9 +91,11 @@ def compare_tuple_tensors(tuple_base, tuple_control, precision):
             atol=precision,
             equal_nan=True,
         ):
-            print_graph(tuple_base[i], "forward output before pre/post grad fx passes.")
-            print_graph(
-                tuple_control[i], "forward output after pre/post grad fx passes."
-            )
+            logger.debug(
+                "forward output before pre/post grad fx passes %s", tuple_base[i]
+            )
+            logger.debug(
+                "forward output after pre/post grad fx passes %s", tuple_control[i]
+            )
             is_allclose = False
     return is_allclose

@@ -1,3 +1,4 @@
+import copy
 import functools
 import itertools
 import logging
@@ -12,10 +13,11 @@ import torch._inductor as inductor
 import torch.utils._pytree as pytree
 from torch import fx
 from torch._decomp import register_decomposition
+from torch._dynamo.utils import counters, optimus_scuba_log

 from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype

-from torch._utils_internal import print_graph
+from torch._utils_internal import upload_graph
 from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq

 from .. import config, ir, pattern_matcher
@@ -80,18 +82,13 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):

     if config.pattern_matcher:
         lazy_init()
-
-        print_graph(gm.graph, "Before group batch fusion in post grad pass.")
+        inductor_before_change = copy.deepcopy(counters["inductor"])
         group_batch_fusion_passes(gm.graph, pre_grad=False)
-        print_graph(gm.graph, "After group batch fusion in post grad pass.")
+        if counters["inductor"] != inductor_before_change:
+            optimus_scuba_log["group_batch_fusion_post_grad"] = upload_graph(gm.graph)
         remove_noop_ops(gm.graph)
-        print_graph(gm.graph, "Before split cat in post grad pass.")
         for patterns in pass_patterns:
             patterns.apply(gm.graph)  # type: ignore[arg-type]
-            print_graph(
-                gm.graph,
-                "Apply split cat pattern matcher PatternMatcherPass in post grad.",
-            )
         if is_inference:
             inference_patterns.apply(gm.graph)  # type: ignore[arg-type]

@@ -112,8 +109,6 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     gm.recompile()
     gm.graph.lint()

-    print_graph(gm.graph, "After recompile in post grad pass.")
-

 @init_once_fakemode
 def lazy_init():

@@ -4,8 +4,8 @@ from typing import List, Optional

 import torch
 import torch.nn as nn
-from torch._dynamo.utils import detect_fake_mode
-from torch._utils_internal import print_graph
+from torch._dynamo.utils import counters, detect_fake_mode, optimus_scuba_log
+from torch._utils_internal import upload_graph
 from torch.fx.experimental.optimization import (
     matches_module_pattern,
     replace_node_module,
@@ -13,6 +13,7 @@ from torch.fx.experimental.optimization import (
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.nn import functional as F
 from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights

 from .. import config
+
 from ..fx_utils import matches_module_function_pattern
@@ -27,13 +28,27 @@ from .misc_patterns import numpy_compat_normalization

 log = logging.getLogger(__name__)

-normalization_pass = PatternMatcherPass(prevent_match_across_mutations=True)
-merge_splits_pass = PatternMatcherPass(prevent_match_across_mutations=True)
-split_cat_pass = PatternMatcherPass(prevent_match_across_mutations=True)
-unbind_stack_pass = PatternMatcherPass(prevent_match_across_mutations=True)
-efficient_conv_bn_eval_pass = PatternMatcherPass(prevent_match_across_mutations=True)
-merge_getitem_cat_pass = PatternMatcherPass(prevent_match_across_mutations=True)
-predispatch_pass = PatternMatcherPass(prevent_match_across_mutations=True)
+normalization_pass = PatternMatcherPass(
+    prevent_match_across_mutations=True, pass_name="normalization_pass"
+)
+merge_splits_pass = PatternMatcherPass(
+    prevent_match_across_mutations=True, pass_name="merge_splits_pass"
+)
+split_cat_pass = PatternMatcherPass(
+    prevent_match_across_mutations=True, pass_name="split_cat_pass"
+)
+unbind_stack_pass = PatternMatcherPass(
+    prevent_match_across_mutations=True, pass_name="unbind_stack_pass"
+)
+efficient_conv_bn_eval_pass = PatternMatcherPass(
+    prevent_match_across_mutations=True, pass_name="efficient_conv_bn_eval_pass"
+)
+merge_getitem_cat_pass = PatternMatcherPass(
+    prevent_match_across_mutations=True, pass_name="merge_getitem_cat_pass"
+)
+predispatch_pass = PatternMatcherPass(
+    prevent_match_across_mutations=True, pass_name="predispatch_pass"
+)
 # based on predispatch aten IR
 normalization_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
 merge_splits_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
@@ -114,17 +129,23 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
                 f"[Pre grad(predispatch IR)]Apply split_cat, index: {ind}",
             )
     else:
+        # We only log the graph with changes to avoid the excessive compilation time
+        # https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/
         gm = fuse_fx(gm, example_inputs)
         numpy_compat_normalization(gm.graph)
-        print_graph(gm.graph, "Before group batch fusion in pre grad pass.")
+        inductor_before_change = copy.deepcopy(counters["inductor"])
         group_batch_fusion_passes(gm.graph, pre_grad=True)
-        print_graph(gm.graph, "Before split cat in pre grad pass.")
-        for pattern_matcher_pass in pattern_matcher_passes:
-            pattern_matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
-            print_graph(
-                gm.graph,
-                "Apply split cat pattern matcher PatternMatcherPass in pre grad.",
-            )
+        if counters["inductor"] != inductor_before_change:
+            optimus_scuba_log["group_batch_fusion_pre_grad"] = upload_graph(
+                gm.graph
+            )
+        for pattern_matcher_pass in pattern_matcher_passes:
+            inductor_before_change = copy.deepcopy(counters["inductor"])
+            pattern_matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
+            if counters["inductor"] != inductor_before_change:
+                optimus_scuba_log[
+                    f"split_cat_pattern_{pattern_matcher_pass.pass_name}_pre_grad"
+                ] = upload_graph(gm.graph)

     if config.pre_grad_custom_pass is not None:
         config.pre_grad_custom_pass(gm.graph)
@@ -148,8 +169,6 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
             config.fx_passes_numeric_check.get("precision", 1e-4),
         )

-    print_graph(gm.graph, "After recompile in pre grad pass.")
-
     return gm

@@ -1214,12 +1214,15 @@ def compute_mutation_region_ids(graph: torch.fx.GraphModule):


 class PatternMatcherPass:
-    def __init__(self, prevent_match_across_mutations=False):
+    def __init__(
+        self, prevent_match_across_mutations=False, pass_name: Optional[str] = None
+    ):
         super().__init__()
         self.patterns: DefaultDict[
             torch.fx.node.Target, List[PatternEntry]
         ] = defaultdict(list)
         self.prevent_match_across_mutations = prevent_match_across_mutations
+        self.pass_name = pass_name

     def __getitem__(self, item: torch.fx.node.Target) -> List[PatternEntry]:
         return self.patterns[item]

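
The pass_name added to PatternMatcherPass above is what turns each pre-grad split-cat pass into a distinct Scuba key: pre_grad_passes formats it as split_cat_pattern_{pass_name}_pre_grad, which is exactly the string the tests assert on. A small sketch of that relationship, using a simplified stand-in class rather than the real PatternMatcherPass:

    from typing import Optional

    class PassSketch:
        # Simplified stand-in mirroring the new constructor surface.
        def __init__(
            self,
            prevent_match_across_mutations: bool = False,
            pass_name: Optional[str] = None,
        ):
            self.prevent_match_across_mutations = prevent_match_across_mutations
            self.pass_name = pass_name

    normalization_pass = PassSketch(
        prevent_match_across_mutations=True, pass_name="normalization_pass"
    )

    # pre_grad.py builds the per-pass log key from the pass name:
    key = f"split_cat_pattern_{normalization_pass.pass_name}_pre_grad"
    assert key == "split_cat_pattern_normalization_pass_pre_grad"
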
@@ -83,7 +83,7 @@ def log_compilation_event(metrics):
     log.info("%s", metrics)


-def print_graph(graph, msg: str):
+def upload_graph(graph):
     pass
