From f2d6cfa6775601df5a038f7a4d0b37da75a53ed9 Mon Sep 17 00:00:00 2001 From: James Wu Date: Fri, 3 Jan 2025 09:17:37 -0800 Subject: [PATCH] Introduce CompileEventLogger, replace usages of metrics_context and chromium_event with it (#143420) **Problem statement**: I want to be able to centralize and simplify the process by which people add columns/data to existing spans. We have MetricsContext and ChromiumEventLogger, and there are various choices you can make to decide where and when to log different levels of observability for your events. To resolve this, I want a central API for "adding to events under dynamo_timed". **CompileEventLogger** is intended as a frontend for MetricsContext and ChromiumEventLogger so we can use the same class for handling everything. CompileEventLogger is intended to be used within a `dynamo_timed()` context. Its purpose is to 1. log to existing events that are in progress (i.e. within dynamo_timed), and 2. log instant events to chromium that are independent of any specific span. CompileEventLogger has three log levels: - CHROMIUM: Log only to chromium events, visible via tlparse. - PT2_COMPILE: Log to chromium_events + pt2_compile_events - COMPILATION_METRIC: Log to compilation metrics in addition to the toplevel chromium and pt2_compile_event. In addition, we have a function CompileEventLogger.add() that automagically chooses the correct log level. For now, it is conservative, and will never automagically choose to log CompilationMetrics (though I could imagine it figuring out that the metadata are all keys in CompilationMetrics and therefore loggable there). The goal here is to make one single interface to log stuff for observability reasons, and make it as easy as possible. Not included in this diff: - V1 of this diff will not have implementations of `increment` and `add_to_set` which MetricsContext has, so those usages are not replaced yet. But I'll add those in a followup. - We don't handle `RuntimeMetricsContext`. 
It's unclear if I want that to be part of this, because under RuntimeMetricsContext there might not be a toplevel event to log to, so chromium events doesn't make sense in that context. So I might leave that separate for now. Differential Revision: [D67346203](https://our.internmc.facebook.com/intern/diff/D67346203/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/143420 Approved by: https://github.com/aorenste --- torch/_dynamo/convert_frame.py | 3 + torch/_dynamo/pgo.py | 31 ++-- torch/_dynamo/utils.py | 164 ++++++++++++++++++ .../_aot_autograd/autograd_cache.py | 20 ++- .../_aot_autograd/runtime_wrappers.py | 4 +- torch/_functorch/aot_autograd.py | 12 +- torch/_inductor/codecache.py | 12 +- torch/_inductor/compile_fx.py | 15 +- 8 files changed, 210 insertions(+), 51 deletions(-) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 3bbf73a7208..ab6311909e5 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -1148,6 +1148,9 @@ def _compile( dynamo_time_before_restart ), } + # TODO: replace with CompileEventLogger.compilation_metrics + # There are some columns here not in PT2 Compile Events + # so we need to slightly change it metrics_context.update_outer(metrics) # === END WARNING WARNING WARNING === diff --git a/torch/_dynamo/pgo.py b/torch/_dynamo/pgo.py index 42eb96cd082..996f008164e 100644 --- a/torch/_dynamo/pgo.py +++ b/torch/_dynamo/pgo.py @@ -7,7 +7,6 @@ import enum import logging import os import pickle -import time from collections import defaultdict from typing import DefaultDict, Optional, Tuple, TYPE_CHECKING, TypeVar, Union from typing_extensions import Self @@ -17,8 +16,8 @@ import torch._utils_internal import torch.compiler.config import torch.distributed as dist from torch._dynamo.utils import ( + CompileEventLogger, dynamo_timed, - get_chromium_event_logger, set_feature_use, warn_once, ) @@ -327,9 +326,8 @@ def update_automatic_dynamic( entry.scalar, old_entry.scalar, ) 
- get_chromium_event_logger().log_instant_event( + CompileEventLogger.instant( "automatic_dynamic", - time.time_ns(), { "name": name, "dim_changed": "scalar", @@ -366,9 +364,8 @@ def update_automatic_dynamic( entry_tup, old_entry_tup, ) - get_chromium_event_logger().log_instant_event( + CompileEventLogger.instant( "automatic_dynamic", - time.time_ns(), { "name": name, "dim_changed": "all" if i is None else i, @@ -546,8 +543,6 @@ def get_code_state() -> DefaultDict[CodeId, CodeState]: if _CODE_STATE is not None: return _CODE_STATE - chromium_log = get_chromium_event_logger() - # Initialize it (even if we don't look up profile) _CODE_STATE = defaultdict(CodeState) @@ -574,13 +569,13 @@ def get_code_state() -> DefaultDict[CodeId, CodeState]: with dynamo_timed( name := "pgo.get_local_code_state", log_pt2_compile_event=True ): - chromium_log.add_event_data(name, cache_key=cache_key) + CompileEventLogger.pt2_compile(name, cache_key=cache_key) # Read lock not necessary as we always write atomically write to # the actual location with open(path, "rb") as f: try: _CODE_STATE = pickle.load(f) - chromium_log.add_event_data(name, cache_size_bytes=f.tell()) + CompileEventLogger.pt2_compile(name, cache_size_bytes=f.tell()) except Exception: log.warning( "get_code_state failed while reading %s", path, exc_info=True @@ -594,7 +589,7 @@ def get_code_state() -> DefaultDict[CodeId, CodeState]: with dynamo_timed( name := "pgo.get_remote_code_state", log_pt2_compile_event=True ): - chromium_log.add_event_data(name, cache_key=cache_key) + CompileEventLogger.pt2_compile(name, cache_key=cache_key) # TODO: I don't really understand why there's a JSON container format try: cache_data = remote_cache.get(cache_key) @@ -609,7 +604,9 @@ def get_code_state() -> DefaultDict[CodeId, CodeState]: data = cache_data["data"] assert isinstance(data, str) payload = base64.b64decode(data) - chromium_log.add_event_data(name, cache_size_bytes=len(payload)) + CompileEventLogger.pt2_compile( + name, 
cache_size_bytes=len(payload) + ) _CODE_STATE = pickle.loads(payload) except Exception: log.warning( @@ -648,8 +645,7 @@ def put_code_state() -> None: def put_local_code_state(cache_key: str) -> None: with dynamo_timed(name := "pgo.put_local_code_state", log_pt2_compile_event=True): - chromium_log = get_chromium_event_logger() - chromium_log.add_event_data(name, cache_key=cache_key) + CompileEventLogger.pt2_compile(name, cache_key=cache_key) assert _CODE_STATE is not None path = code_state_path(cache_key) @@ -672,7 +668,7 @@ def put_local_code_state(cache_key: str) -> None: with FileLock(lock_path, timeout=LOCK_TIMEOUT): with open(tmp_path, "wb") as f: pickle.dump(_CODE_STATE, f) - chromium_log.add_event_data(name, cache_size_bytes=f.tell()) + CompileEventLogger.pt2_compile(name, cache_size_bytes=f.tell()) os.rename(tmp_path, path) log.info( "put_code_state: wrote local %s, %d entries", path, len(_CODE_STATE) @@ -686,8 +682,7 @@ def put_local_code_state(cache_key: str) -> None: def put_remote_code_state(cache_key: str) -> None: with dynamo_timed(name := "pgo.put_remote_code_state", log_pt2_compile_event=True): - chromium_log = get_chromium_event_logger() - chromium_log.add_event_data(name, cache_key=cache_key) + CompileEventLogger.pt2_compile(name, cache_key=cache_key) assert _CODE_STATE is not None remote_cache = get_remote_cache() @@ -697,7 +692,7 @@ def put_remote_code_state(cache_key: str) -> None: return content = pickle.dumps(_CODE_STATE) - chromium_log.add_event_data(name, cache_size_bytes=len(content)) + CompileEventLogger.pt2_compile(name, cache_size_bytes=len(content)) cache_data: JsonDataTy = { "data": base64.b64encode(content).decode("ascii"), } diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 1f0da9fe8d4..12a9376d704 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -301,6 +301,160 @@ def get_runtime_metrics_context() -> RuntimeMetricsContext: return _RUNTIME_METRICS_CONTEXT +class CompileEventLogLevel(enum.Enum): + 
""" + Enum that loosely corresponds with a "log level" of a given event. + + CHROMIUM_EVENT: Logs only to tlparse. + COMPILE_EVENT: Logs to tlparse + PT2 Compile Events + COMPILATION_METRIC: Logs to tlparse, PT2 Compile Events, and dynamo_compile + """ + + CHROMIUM = 1 + PT2_COMPILE = 2 + COMPILATION_METRIC = 3 + + +class CompileEventLogger: + """ + Helper class for representing adding metadata(i.e. columns) to various compile events. + Use CompileEventLogger to add event data to: + - Chromium events + - PT2 Compile Events + - CompilationMetrics + + This should be used in conjunction with dynamo_timed() and metrics contexts, which create + timed spans and events. CompileEventLogger uses three log levels (described in CompileEventLogLevel), + where each log level logs to all sources below it in the hierarchy. + + Example usages: + - I want to log to an existing chromium event within dynamo timed: + with dynamo_timed("my_event"): + CompileEventLogger.chromium("my_event", foo=bar) + + - I want to log my event to both chromium + pt2_compile_events: + with dynamo_timed("my_event", log_pt2_compile_event=True): + CompileEventLogger.pt2_compile("my_event", foo=bar) + + - I want to add information to dynamo events and dynamo_compile + CompileEventLogger.compilation_metric(foo=bar) + """ + + @staticmethod + def log_instant_event( + event_name: str, + metadata: Dict[str, Any], + time_ns: Optional[int] = None, + log_level: CompileEventLogLevel = CompileEventLogLevel.CHROMIUM, + ): + if time_ns is None: + time_ns = time.time_ns() + chromium_log = get_chromium_event_logger() + if log_level == CompileEventLogLevel.CHROMIUM: + log_pt2_compile_event = False + elif log_level == CompileEventLogLevel.PT2_COMPILE: + log_pt2_compile_event = True + else: + raise RuntimeError( + "Cannot log instant event at COMPILATION_METRIC level. 
Please choose one of CHROMIUM_EVENT or COMPILE_EVENT" + ) + chromium_log.log_instant_event( + event_name, time_ns, metadata, log_pt2_compile_event + ) + + @staticmethod + def add_data(event_name: str, log_level: CompileEventLogLevel, **metadata: object): + """ + Centralized API for adding data to various events + Log an event to a toplevel "dynamo" event or metrics context + depending on log level. + """ + chromium_log = get_chromium_event_logger() + pt2_compile_substack = chromium_log.get_pt2_compile_substack() + + if log_level == CompileEventLogLevel.CHROMIUM: + chromium_log.add_event_data(event_name, **metadata) + elif log_level == CompileEventLogLevel.PT2_COMPILE: + pt2_compile_substack = chromium_log.get_pt2_compile_substack() + if event_name not in pt2_compile_substack: + raise RuntimeError( + "Error: specified log level PT2_COMPILE, but the event %s" + " is not logged to pt2_compile_events. Make sure the event is active and you passed " + "log_pt2_compile_event=True to dynamo_timed", + event_name, + ) + chromium_log.add_event_data(event_name, **metadata) + else: + assert log_level == CompileEventLogLevel.COMPILATION_METRIC + top_event = chromium_log.get_top() + + if event_name != top_event: + raise RuntimeError( + "Log level is COMPILATION_METRIC, but event_name isn't the toplevel event. " + "CompilationMetrics must be logged to the toplevel event. Consider using `log_toplevel_event_data` directly." + ) + metrics_context = get_metrics_context() + if not metrics_context.in_progress(): + raise RuntimeError( + "No metrics context is in progress. Please only call this function within a metrics context." + ) + + # TODO: should we assert that the keys of metadata are in CompilationMetrics? 
+ metrics_context.update(metadata) + chromium_log.add_event_data(event_name, **metadata) + + @staticmethod + def add_toplevel(log_level: CompileEventLogLevel, **metadata: object): + """ + Syntactic sugar for logging to the toplevel event + """ + top_event = get_chromium_event_logger().get_top() + if top_event is None: + raise RuntimeError( + "No toplevel event active. Please only call this function within a dynamo_timed context." + ) + CompileEventLogger.add_data(top_event, log_level, **metadata) + + # MAIN API: These functions are syntactic sugar for the basic operations above without + # needing to use a specific log level. These are easier to use because you don't need + # to import CompileEventLogLevel to use them. + + @staticmethod + def chromium(event_name: str, **metadata: object): + CompileEventLogger.add_data( + event_name, CompileEventLogLevel.CHROMIUM, **metadata + ) + + @staticmethod + def pt2_compile(event_name: str, **metadata: object): + CompileEventLogger.add_data( + event_name, CompileEventLogLevel.PT2_COMPILE, **metadata + ) + + @staticmethod + def compilation_metric(**metadata: object): + CompileEventLogger.add_toplevel( + CompileEventLogLevel.COMPILATION_METRIC, **metadata + ) + + @staticmethod + def instant( + event_name: str, metadata: Dict[str, Any], time_ns: Optional[int] = None + ): + CompileEventLogger.log_instant_event( + event_name, metadata, time_ns, CompileEventLogLevel.CHROMIUM + ) + + @staticmethod + def try_add_pt2_compile(event_name: str, **metadata: object): + """ + Adds to an existing pt2_compile event, but silently returns if the event doesn't exist. + This function is syntactic sugar for chromium_event_logger().try_add_event_data. 
+ """ + chromium_log = get_chromium_event_logger() + chromium_log.try_add_event_data(event_name, **metadata) + + @contextmanager def dynamo_timed( key: str, @@ -1007,6 +1161,16 @@ _compilation_metrics: Deque[CompilationMetrics] = collections.deque( def add_compilation_metrics_to_chromium(c: CompilationMetrics) -> None: + """ + These are the common fields in CompilationMetrics that existed before + metrics_context, and aren't set by MetricsContext.set(). We add the subset + of them that make sense in `dynamo`/toplevel events in PT2 Compile Events + directly. + + If you're tempted to add to this list, consider using CompileEventLogger.compilation_metric() + instead, which will automatically also add it to tlparse and PT2 Compile Events. + TODO: Get rid of this function and replace it with CompileEventLogger directly instead. + """ event_logger = get_chromium_event_logger() event_name = event_logger.get_top() if not event_name: diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index f00f09c544f..7c6bf1e4384 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch -from torch._dynamo.utils import counters, get_chromium_event_logger +from torch._dynamo.utils import CompileEventLogger, counters from torch._functorch import config from torch._inductor.codecache import ( _ident, @@ -500,14 +500,15 @@ class AOTAutogradCacheEntry: compiled_fw_func = self.compiled_fw.load(args, fx_config) compiled_bw_func = None - chromium_log = get_chromium_event_logger() if self.compiled_bw is not None: compiled_bw_func = self.compiled_bw.load(args, fx_config) needs_autograd = True - chromium_log.try_add_event_data("backend_compile", dispatch_mode="autograd") + CompileEventLogger.try_add_pt2_compile( + "backend_compile", 
dispatch_mode="autograd" + ) else: needs_autograd = False - chromium_log.try_add_event_data( + CompileEventLogger.try_add_pt2_compile( "backend_compile", dispatch_mode="inference" ) @@ -522,7 +523,7 @@ class AOTAutogradCacheEntry: ) req_subclass_dispatch = self.maybe_subclass_meta is not None - chromium_log.add_event_data( + CompileEventLogger.pt2_compile( "backend_compile", requires_subclass_dispatch=req_subclass_dispatch ) @@ -760,11 +761,12 @@ class AOTAutogradCache: "components": debug_lines, } ) - chromium_log = get_chromium_event_logger() - chromium_log.log_instant_event( - f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_info + CompileEventLogger.instant( + f"autograd_cache_{cache_state}", + metadata=cache_info, + time_ns=cache_event_time, ) - chromium_log.try_add_event_data( + CompileEventLogger.try_add_pt2_compile( "backend_compile", cache_state=cache_state, cache_event_time=cache_event_time, diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index ea2200ffd72..604d6540849 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -18,7 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch import torch.utils.dlpack from torch import Tensor -from torch._dynamo.utils import dynamo_timed, get_metrics_context +from torch._dynamo.utils import CompileEventLogger, dynamo_timed, get_metrics_context from torch._guards import ( compile_context, CompileContext, @@ -2047,7 +2047,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa log_pt2_compile_event=True, dynamo_compile_column_us="backward_cumulative_compile_time_us", ): - metrics_context.update_outer({"is_forward": False}) + CompileEventLogger.compilation_metric(is_forward=False) CompiledFunction.compiled_bw = aot_config.bw_compiler( bw_module, placeholder_list ) diff --git 
a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 576cf880c6d..00ab67eabe0 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -29,8 +29,8 @@ from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompo from torch._dispatch.python import enable_python_dispatcher from torch._dynamo import compiled_autograd from torch._dynamo.utils import ( + CompileEventLogger, dynamo_timed, - get_chromium_event_logger, preserve_rng_state, set_feature_use, ) @@ -638,7 +638,7 @@ def _create_aot_dispatcher_function( python_dispatcher_mode = ( enable_python_dispatcher() if shape_env is not None else nullcontext() ) - chromium_log = get_chromium_event_logger() + # See NOTE: [Deferring tensor pack/unpack hooks until runtime] # If any saved tensor hooks are active, we **don't** want to trace them. # Instead, we'll let them run at runtime, around the custom autograd.Function @@ -692,7 +692,7 @@ def _create_aot_dispatcher_function( req_subclass_dispatch = requires_subclass_dispatch( fake_flat_args, fw_metadata ) - chromium_log.try_add_event_data( + CompileEventLogger.try_add_pt2_compile( "backend_compile", requires_subclass_dispatch=req_subclass_dispatch ) @@ -812,17 +812,17 @@ or otherwise set torch._functorch.config.functionalize_rng_ops = False.""" if aot_config.is_export: # export uses just the "graph bits", whereas the other # two dispatchers include some extra work around handling a runtime epilogue - chromium_log.try_add_event_data( + CompileEventLogger.try_add_pt2_compile( "backend_compile", dispatch_mode="export" ) return partial(aot_dispatch_export, needs_autograd=needs_autograd) elif needs_autograd and not aot_config.pre_dispatch: - chromium_log.try_add_event_data( + CompileEventLogger.try_add_pt2_compile( "backend_compile", dispatch_mode="autograd" ) return aot_dispatch_autograd else: - chromium_log.try_add_event_data( + CompileEventLogger.try_add_pt2_compile( "backend_compile", 
dispatch_mode="inference" ) return aot_dispatch_base diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 42755fb3a50..557bd7d8dc5 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -52,9 +52,9 @@ import torch import torch.distributed as dist from torch import SymInt, Tensor from torch._dynamo.utils import ( + CompileEventLogger, counters, dynamo_timed, - get_chromium_event_logger, get_metrics_context, ) from torch._inductor import config, exc, metrics @@ -1054,12 +1054,10 @@ class FxGraphCache: triton_bundler_meta = TritonBundler.read_and_emit(bundle) if (meta := triton_bundler_meta) is not None: cache_info["triton_bundler_meta"] = str(meta) - logger = get_chromium_event_logger() - if "inductor_compile" in logger.get_stack(): - # TODO: Clean up autograd cache integration - logger.add_event_data( - "inductor_compile", cached_kernel_names=meta.cached_kernel_names - ) + # TODO: Clean up autograd cache integration + CompileEventLogger.try_add_pt2_compile( + "inductor_compile", cached_kernel_names=meta.cached_kernel_names + ) if len(meta.cached_kernel_names) > 0: get_metrics_context().increment("num_triton_bundles", 1) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index b47e9c49010..147f6861a7a 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -44,11 +44,11 @@ from torch._dynamo import ( from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.repro.after_aot import wrap_compiler_debug from torch._dynamo.utils import ( + CompileEventLogger, counters, detect_fake_mode, dynamo_timed, flatten_graph_inputs, - get_chromium_event_logger, lazy_format_graph_code, set_feature_use, ) @@ -569,12 +569,10 @@ def compile_fx_inner( stack.enter_context(_WaitCounter("pytorch.wait_counter.dynamo_compile").guard()) stack.enter_context(with_fresh_cache_if_config()) stack.enter_context(DebugContext()) - - 
get_chromium_event_logger().add_event_data( + CompileEventLogger.pt2_compile( "inductor_compile", is_backward=kwargs["is_backward"], ) - return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")( gm, example_inputs, @@ -736,7 +734,6 @@ def _compile_fx_inner( # and a tlparse log for every cache action. # In the event of a bypass, we also logged to the remote table earlier # with log_cache_bypass. - chromium_log = get_chromium_event_logger() cache_state = ( cache_info["cache_state"] if cache_info is not None else "disabled" ) @@ -745,14 +742,14 @@ def _compile_fx_inner( # fx_graph_cache_miss # fx_graph_cache_bypass # fx_graph_cache_disabled - chromium_log.log_instant_event( + CompileEventLogger.instant( f"fx_graph_cache_{cache_state}", - start_time, - metadata=cache_info, + metadata=cache_info or {}, + time_ns=start_time, ) # Add event data about cache hits/miss # TODO: add remote cache get/put timings here too - chromium_log.add_event_data( + CompileEventLogger.pt2_compile( "inductor_compile", cache_state=cache_state, cache_event_time=start_time,