From fbfb9a164800aafc13bd694b19366b6b717e2b73 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 11 Aug 2023 20:46:04 +0000 Subject: [PATCH] [Dynamo] Improve PT2 fbcode logging observability (#106932) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: https://docs.google.com/document/d/1D5K3_ELsda3tIUeSyNL_2yee-M3jVWbirqSQ5BDNvHQ/edit This is the revamped version of D47908299. For each frame, we will record a list of compilation metrics: e.g., backend_compile time, entire_frame_compile time, cache_size, co_filename, co_firstlineno, co_name, guards, graph input_count, graph node_count, graph op_count. With the help of job info: mast_job_name, global_rank, we can satisfy the requirements from `Things I’ve used/wanted to use our logging to determine` in https://docs.google.com/document/d/1D5K3_ELsda3tIUeSyNL_2yee-M3jVWbirqSQ5BDNvHQ/edit (or add more metrics for this framework). Test Plan: ``` buck2 test //caffe2/test:test_dynamo ``` Differential Revision: D48142400 Pull Request resolved: https://github.com/pytorch/pytorch/pull/106932 Approved by: https://github.com/anijain2305 --- torch/_dynamo/__init__.py | 9 +---- torch/_dynamo/convert_frame.py | 71 +++++++++++++++++++++++++++++++--- torch/_dynamo/utils.py | 43 +++++++++++++------- torch/_utils_internal.py | 4 ++ 4 files changed, 100 insertions(+), 27 deletions(-) diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index b9d67c50da4..cf512c7b3cb 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -24,13 +24,7 @@ from .eval_frame import ( reset_code, ) from .external_utils import is_compiling -from .utils import ( - compilation_metrics, - graph_break_reasons, - guard_failures, - orig_code_map, - reset_frame_count, -) +from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count __all__ = [ "allow_in_graph", @@ -71,6 +65,5 @@ def reset() -> None: if hasattr(eval_frame.most_recent_backend, "reset"):
eval_frame.most_recent_backend.reset() eval_frame.most_recent_backend = None - compilation_metrics.clear() reset_frame_count() torch._C._dynamo.compiled_autograd.clear_cache() diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 5eee44f4161..f7b6b41e771 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -5,12 +5,12 @@ import os import random import types import weakref -from typing import Dict, Optional, Set +from typing import Any, Callable, Dict, List, Optional, Set import torch import torch._logging from torch._guards import tracing -from torch._utils_internal import signpost_event +from torch._utils_internal import log_compilation_event, signpost_event from torch.fx.experimental.symbolic_shapes import ( ConstraintViolationError, GuardOnDataDependentSymNode, @@ -24,6 +24,7 @@ from .backends.registry import CompilerFn from .bytecode_analysis import remove_dead_code, remove_pointless_jumps from .bytecode_transformation import ( check_inst_exn_tab_entries_valid, + Instruction, is_generator, propagate_inst_exn_table_entries, transform_code_object, @@ -45,9 +46,11 @@ from .replay_record import ExecutionRecord from .symbolic_convert import InstructionTranslator from .utils import ( CleanupManager, + CompilationMetrics, counters, dynamo_timed, format_bytecode, + frame_phase_timing, gen_record_file_name, guard_failures, increment_frame, @@ -374,6 +377,7 @@ def convert_frame_assert( export, export_constraints, hooks, + cache_size, frame, frame_state=frame_state, ) @@ -382,7 +386,6 @@ def convert_frame_assert( return wrap_convert_context(_convert_frame_assert) -@dynamo_timed(phase_name="entire_frame_compile") def _compile( code: types.CodeType, globals: Dict[str, object], @@ -393,14 +396,14 @@ def _compile( export: bool, export_constraints, hooks: Hooks, + cache_size: int, frame: Optional[types.FrameType] = None, frame_state=None, ) -> Optional[GuardedCode]: output: Optional[OutputGraph] = None # This is 
shared across restarts mutated_closure_cell_contents: Set[str] = set() - - # from .utils import print_once; print_once(code.co_filename) + fail_reason: Optional[str] = None def transform(instructions, code_options): nonlocal output @@ -431,7 +434,14 @@ def _compile( check_inst_exn_tab_entries_valid(instructions) instructions[:] = remove_pointless_jumps(remove_dead_code(instructions)) - try: + @dynamo_timed(phase_name="entire_frame_compile") + def compile_inner( + code: types.CodeType, + one_graph: bool, + hooks: Hooks, + transform: Callable[[List[Instruction], Dict[str, Any]], Any], + ) -> Optional[GuardedCode]: + nonlocal output for attempt in itertools.count(): try: out_code = transform_code_object(code, transform) @@ -521,6 +531,10 @@ def _compile( output.local_scope.clear() return guarded_code + + try: + guarded_code = compile_inner(code, one_graph, hooks, transform) + return guarded_code except ( Unsupported, TorchRuntimeError, @@ -529,11 +543,54 @@ def _compile( ConstraintViolationError, GuardOnDataDependentSymNode, ) as e: + fail_reason = str(e) exception_handler(e, code, frame, export=export) raise except Exception as e: + fail_reason = str(e) exception_handler(e, code, frame, export=export) raise InternalTorchDynamoError(str(e)).with_traceback(e.__traceback__) from None + finally: + from .utils import curr_frame + + frame_key = str(curr_frame) + if ( + fail_reason is None + and output is not None + and frame_key in frame_phase_timing + ): + guard_count = len(output.guards) + graph_op_count = output.count_calls() + graph_node_count = len(output.graph.nodes) + graph_input_count = len(output.placeholders) + entire_frame_compile_time = frame_phase_timing[frame_key].get( + "entire_frame_compile", None + ) + backend_compile_time = frame_phase_timing[frame_key].get( + "backend_compile", None + ) + else: + guard_count = None + graph_op_count = None + graph_node_count = None + graph_input_count = None + entire_frame_compile_time = None + backend_compile_time = None 
+ metrics = CompilationMetrics( + frame_key, + code.co_name, + code.co_filename, + code.co_firstlineno, + cache_size, + guard_count, + graph_op_count, + graph_node_count, + graph_input_count, + entire_frame_compile_time, + backend_compile_time, + fail_reason, + ) + log_compilation_event(metrics) def convert_frame(compiler_fn: CompilerFn, hooks: Hooks): @@ -602,7 +659,9 @@ def replay(filename): compiler_fn=eager, one_graph=False, export=False, + export_constraints=None, hooks=Hooks(), + cache_size=0, frame=None, ) except Exception: diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 7e9bf62bc76..7b2135b5c38 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -63,8 +63,11 @@ nnmodule_doc_url = "https://pytorch.org/docs/master/compile/nn-module.html" nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations." log = logging.getLogger(__name__) -# profiling compilation time -compilation_metrics = collections.OrderedDict() +# profiling compilation time by function +compilation_time_metrics = collections.OrderedDict() + +# profiling compilation time by frame phase +frame_phase_timing = collections.OrderedDict() timer_counter = itertools.count() @@ -101,8 +104,6 @@ def dynamo_profiled(func): return profile_wrapper -frame_phase_timing = collections.OrderedDict() - curr_frame = 0 @@ -116,6 +117,7 @@ def increment_frame(): def reset_frame_count(): global curr_frame frame_phase_timing.clear() + compilation_time_metrics.clear() curr_frame = 0 @@ -151,7 +153,7 @@ def print_time_report(): # dynamo_timed API works as a function decorator -# By wrapping a function in dynamo_timed, we can store a record in compilation_metrics +# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics # where the key is the functions name. 
# For example: # @@ -171,14 +173,13 @@ def dynamo_timed(original_function=None, phase_name=None): @wraps(func) def time_wrapper(*args, **kwargs): key = func.__qualname__ - if key not in compilation_metrics: - compilation_metrics[key] = [] + if key not in compilation_time_metrics: + compilation_time_metrics[key] = [] with torch.profiler.record_function(f"{key} (dynamo_timed)"): t0 = time.time() r = func(*args, **kwargs) time_spent = time.time() - t0 - # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec") - compilation_metrics[key].append(time_spent) + compilation_time_metrics[key].append(time_spent) if phase_name: frame_key = str(curr_frame) if frame_key not in frame_phase_timing: @@ -217,8 +218,8 @@ def compile_times(repr="str", aggregate=False): if repr == "str": rows = [ - (k, fmt_fn(compilation_metrics[k], item_fn=lambda x: f"{x:.4f}")) - for k in compilation_metrics + (k, fmt_fn(compilation_time_metrics[k], item_fn=lambda x: f"{x:.4f}")) + for k in compilation_time_metrics ] out = "TorchDynamo compilation metrics:\n" out += tabulate(rows, headers=("Function", "Runtimes (s)")) @@ -226,9 +227,9 @@ def compile_times(repr="str", aggregate=False): elif repr == "csv": values = [ fmt_fn(v, item_fn=lambda x: f"{x:.6f}") - for v in compilation_metrics.values() + for v in compilation_time_metrics.values() ] - headers = list(compilation_metrics.keys()) + headers = list(compilation_time_metrics.keys()) return headers, values @@ -479,6 +480,22 @@ def proxy_args_kwargs(args, kwargs): ) from e +@dataclasses.dataclass +class CompilationMetrics: + frame_key: str + co_name: str + co_filename: str + co_firstlineno: int + cache_size: int + guard_count: Optional[int] + graph_op_count: Optional[int] + graph_node_count: Optional[int] + graph_input_count: Optional[int] + entire_frame_compile_time_s: Optional[float] + backend_compile_time_s: Optional[float] + fail_reason: Optional[str] + + @dataclasses.dataclass class CleanupHook: """Remove a global variable when hook is 
called""" diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 73be4660265..628e4fed84d 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -64,6 +64,10 @@ def signpost_event(category: str, name: str, parameters: Dict[str, Any]): log.info("%s %s: %r", category, name, parameters) +def log_compilation_event(metrics): + log.info("%s", metrics) + + TEST_MASTER_ADDR = "127.0.0.1" TEST_MASTER_PORT = 29500 # USE_GLOBAL_DEPS controls whether __init__.py tries to load