diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 3bbf73a7208..ab6311909e5 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -1148,6 +1148,9 @@ def _compile( dynamo_time_before_restart ), } + # TODO: replace with CompileEventLogger.compilation_metrics + # There are some columns here not in PT2 Compile Events + # so we need to slightly change it metrics_context.update_outer(metrics) # === END WARNING WARNING WARNING === diff --git a/torch/_dynamo/pgo.py b/torch/_dynamo/pgo.py index 42eb96cd082..996f008164e 100644 --- a/torch/_dynamo/pgo.py +++ b/torch/_dynamo/pgo.py @@ -7,7 +7,6 @@ import enum import logging import os import pickle -import time from collections import defaultdict from typing import DefaultDict, Optional, Tuple, TYPE_CHECKING, TypeVar, Union from typing_extensions import Self @@ -17,8 +16,8 @@ import torch._utils_internal import torch.compiler.config import torch.distributed as dist from torch._dynamo.utils import ( + CompileEventLogger, dynamo_timed, - get_chromium_event_logger, set_feature_use, warn_once, ) @@ -327,9 +326,8 @@ def update_automatic_dynamic( entry.scalar, old_entry.scalar, ) - get_chromium_event_logger().log_instant_event( + CompileEventLogger.instant( "automatic_dynamic", - time.time_ns(), { "name": name, "dim_changed": "scalar", @@ -366,9 +364,8 @@ def update_automatic_dynamic( entry_tup, old_entry_tup, ) - get_chromium_event_logger().log_instant_event( + CompileEventLogger.instant( "automatic_dynamic", - time.time_ns(), { "name": name, "dim_changed": "all" if i is None else i, @@ -546,8 +543,6 @@ def get_code_state() -> DefaultDict[CodeId, CodeState]: if _CODE_STATE is not None: return _CODE_STATE - chromium_log = get_chromium_event_logger() - # Initialize it (even if we don't look up profile) _CODE_STATE = defaultdict(CodeState) @@ -574,13 +569,13 @@ def get_code_state() -> DefaultDict[CodeId, CodeState]: with dynamo_timed( name := "pgo.get_local_code_state", 
log_pt2_compile_event=True ): - chromium_log.add_event_data(name, cache_key=cache_key) + CompileEventLogger.pt2_compile(name, cache_key=cache_key) # Read lock not necessary as we always write atomically write to # the actual location with open(path, "rb") as f: try: _CODE_STATE = pickle.load(f) - chromium_log.add_event_data(name, cache_size_bytes=f.tell()) + CompileEventLogger.pt2_compile(name, cache_size_bytes=f.tell()) except Exception: log.warning( "get_code_state failed while reading %s", path, exc_info=True @@ -594,7 +589,7 @@ def get_code_state() -> DefaultDict[CodeId, CodeState]: with dynamo_timed( name := "pgo.get_remote_code_state", log_pt2_compile_event=True ): - chromium_log.add_event_data(name, cache_key=cache_key) + CompileEventLogger.pt2_compile(name, cache_key=cache_key) # TODO: I don't really understand why there's a JSON container format try: cache_data = remote_cache.get(cache_key) @@ -609,7 +604,9 @@ def get_code_state() -> DefaultDict[CodeId, CodeState]: data = cache_data["data"] assert isinstance(data, str) payload = base64.b64decode(data) - chromium_log.add_event_data(name, cache_size_bytes=len(payload)) + CompileEventLogger.pt2_compile( + name, cache_size_bytes=len(payload) + ) _CODE_STATE = pickle.loads(payload) except Exception: log.warning( @@ -648,8 +645,7 @@ def put_code_state() -> None: def put_local_code_state(cache_key: str) -> None: with dynamo_timed(name := "pgo.put_local_code_state", log_pt2_compile_event=True): - chromium_log = get_chromium_event_logger() - chromium_log.add_event_data(name, cache_key=cache_key) + CompileEventLogger.pt2_compile(name, cache_key=cache_key) assert _CODE_STATE is not None path = code_state_path(cache_key) @@ -672,7 +668,7 @@ def put_local_code_state(cache_key: str) -> None: with FileLock(lock_path, timeout=LOCK_TIMEOUT): with open(tmp_path, "wb") as f: pickle.dump(_CODE_STATE, f) - chromium_log.add_event_data(name, cache_size_bytes=f.tell()) + CompileEventLogger.pt2_compile(name, 
cache_size_bytes=f.tell()) os.rename(tmp_path, path) log.info( "put_code_state: wrote local %s, %d entries", path, len(_CODE_STATE) ) @@ -686,8 +682,7 @@ def put_local_code_state(cache_key: str) -> None: def put_remote_code_state(cache_key: str) -> None: with dynamo_timed(name := "pgo.put_remote_code_state", log_pt2_compile_event=True): - chromium_log = get_chromium_event_logger() - chromium_log.add_event_data(name, cache_key=cache_key) + CompileEventLogger.pt2_compile(name, cache_key=cache_key) assert _CODE_STATE is not None remote_cache = get_remote_cache() @@ -697,7 +692,7 @@ def put_remote_code_state(cache_key: str) -> None: return content = pickle.dumps(_CODE_STATE) - chromium_log.add_event_data(name, cache_size_bytes=len(content)) + CompileEventLogger.pt2_compile(name, cache_size_bytes=len(content)) cache_data: JsonDataTy = { "data": base64.b64encode(content).decode("ascii"), } diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 1f0da9fe8d4..12a9376d704 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -301,6 +301,160 @@ def get_runtime_metrics_context() -> RuntimeMetricsContext: return _RUNTIME_METRICS_CONTEXT +class CompileEventLogLevel(enum.Enum): + """ + Enum that loosely corresponds with a "log level" of a given event. + + CHROMIUM: Logs only to tlparse. + PT2_COMPILE: Logs to tlparse + PT2 Compile Events + COMPILATION_METRIC: Logs to tlparse, PT2 Compile Events, and dynamo_compile + """ + + CHROMIUM = 1 + PT2_COMPILE = 2 + COMPILATION_METRIC = 3 + + +class CompileEventLogger: + """ + Helper class for adding metadata (i.e. columns) to various compile events. + Use CompileEventLogger to add event data to: + - Chromium events + - PT2 Compile Events + - CompilationMetrics + + This should be used in conjunction with dynamo_timed() and metrics contexts, which create + timed spans and events.
CompileEventLogger uses three log levels (described in CompileEventLogLevel), + where each log level logs to all sources below it in the hierarchy. + + Example usages: + - I want to log to an existing chromium event within dynamo timed: + with dynamo_timed("my_event"): + CompileEventLogger.chromium("my_event", foo=bar) + + - I want to log my event to both chromium + pt2_compile_events: + with dynamo_timed("my_event", log_pt2_compile_event=True): + CompileEventLogger.pt2_compile("my_event", foo=bar) + + - I want to add information to dynamo events and dynamo_compile + CompileEventLogger.compilation_metric(foo=bar) + """ + + @staticmethod + def log_instant_event( + event_name: str, + metadata: Dict[str, Any], + time_ns: Optional[int] = None, + log_level: CompileEventLogLevel = CompileEventLogLevel.CHROMIUM, + ): + if time_ns is None: + time_ns = time.time_ns() + chromium_log = get_chromium_event_logger() + if log_level == CompileEventLogLevel.CHROMIUM: + log_pt2_compile_event = False + elif log_level == CompileEventLogLevel.PT2_COMPILE: + log_pt2_compile_event = True + else: + raise RuntimeError( + "Cannot log instant event at COMPILATION_METRIC level. Please choose one of CHROMIUM or PT2_COMPILE" + ) + chromium_log.log_instant_event( + event_name, time_ns, metadata, log_pt2_compile_event + ) + + @staticmethod + def add_data(event_name: str, log_level: CompileEventLogLevel, **metadata: object): + """ + Centralized API for adding data to various events + Log an event to a toplevel "dynamo" event or metrics context + depending on log level.
+ """ + chromium_log = get_chromium_event_logger() + pt2_compile_substack = chromium_log.get_pt2_compile_substack() + + if log_level == CompileEventLogLevel.CHROMIUM: + chromium_log.add_event_data(event_name, **metadata) + elif log_level == CompileEventLogLevel.PT2_COMPILE: + pt2_compile_substack = chromium_log.get_pt2_compile_substack() + if event_name not in pt2_compile_substack: + raise RuntimeError( + "Error: specified log level PT2_COMPILE, but the event %s" + " is not logged to pt2_compile_events. Make sure the event is active and you passed " + "log_pt2_compile_event=True to dynamo_timed", + event_name, + ) + chromium_log.add_event_data(event_name, **metadata) + else: + assert log_level == CompileEventLogLevel.COMPILATION_METRIC + top_event = chromium_log.get_top() + + if event_name != top_event: + raise RuntimeError( + "Log level is COMPILATION_METRIC, but event_name isn't the toplevel event. " + "CompilationMetrics must be logged to the toplevel event. Consider using `log_toplevel_event_data` directly." + ) + metrics_context = get_metrics_context() + if not metrics_context.in_progress(): + raise RuntimeError( + "No metrics context is in progress. Please only call this function within a metrics context." + ) + + # TODO: should we assert that the keys of metadata are in CompilationMetrics? + metrics_context.update(metadata) + chromium_log.add_event_data(event_name, **metadata) + + @staticmethod + def add_toplevel(log_level: CompileEventLogLevel, **metadata: object): + """ + Syntactic sugar for logging to the toplevel event + """ + top_event = get_chromium_event_logger().get_top() + if top_event is None: + raise RuntimeError( + "No toplevel event active. Please only call this function within a dynamo_timed context." + ) + CompileEventLogger.add_data(top_event, log_level, **metadata) + + # MAIN API: These functions are syntactic sugar for the basic operations above without + # needing to use a specific log level. 
These are easier to use because you don't need + # to import CompileEventLogLevel to use them. + + @staticmethod + def chromium(event_name: str, **metadata: object): + CompileEventLogger.add_data( + event_name, CompileEventLogLevel.CHROMIUM, **metadata + ) + + @staticmethod + def pt2_compile(event_name: str, **metadata: object): + CompileEventLogger.add_data( + event_name, CompileEventLogLevel.PT2_COMPILE, **metadata + ) + + @staticmethod + def compilation_metric(**metadata: object): + CompileEventLogger.add_toplevel( + CompileEventLogLevel.COMPILATION_METRIC, **metadata + ) + + @staticmethod + def instant( + event_name: str, metadata: Dict[str, Any], time_ns: Optional[int] = None + ): + CompileEventLogger.log_instant_event( + event_name, metadata, time_ns, CompileEventLogLevel.CHROMIUM + ) + + @staticmethod + def try_add_pt2_compile(event_name: str, **metadata: object): + """ + Adds to an existing pt2_compile event, but silently returns if the event doesn't exist. + This function is syntactic sugar for chromium_event_logger().try_add_event_data. + """ + chromium_log = get_chromium_event_logger() + chromium_log.try_add_event_data(event_name, **metadata) + + @contextmanager def dynamo_timed( key: str, @@ -1007,6 +1161,16 @@ _compilation_metrics: Deque[CompilationMetrics] = collections.deque( def add_compilation_metrics_to_chromium(c: CompilationMetrics) -> None: + """ + These are the common fields in CompilationMetrics that existed before + metrics_context, and aren't set by MetricsContext.set(). We add the subset + of them that make sense in `dynamo`/toplevel events in PT2 Compile Events + directly. + + If you're tempted to add to this list, consider using CompileEventLogger.compilation_metric() + instead, which will automatically also add it to tlparse and PT2 Compile Events. + TODO: Get rid of this function and replace it with CompileEventLogger directly instead. 
+ """ event_logger = get_chromium_event_logger() event_name = event_logger.get_top() if not event_name: diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index f00f09c544f..7c6bf1e4384 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch -from torch._dynamo.utils import counters, get_chromium_event_logger +from torch._dynamo.utils import CompileEventLogger, counters from torch._functorch import config from torch._inductor.codecache import ( _ident, @@ -500,14 +500,15 @@ class AOTAutogradCacheEntry: compiled_fw_func = self.compiled_fw.load(args, fx_config) compiled_bw_func = None - chromium_log = get_chromium_event_logger() if self.compiled_bw is not None: compiled_bw_func = self.compiled_bw.load(args, fx_config) needs_autograd = True - chromium_log.try_add_event_data("backend_compile", dispatch_mode="autograd") + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="autograd" + ) else: needs_autograd = False - chromium_log.try_add_event_data( + CompileEventLogger.try_add_pt2_compile( "backend_compile", dispatch_mode="inference" ) @@ -522,7 +523,7 @@ class AOTAutogradCacheEntry: ) req_subclass_dispatch = self.maybe_subclass_meta is not None - chromium_log.add_event_data( + CompileEventLogger.pt2_compile( "backend_compile", requires_subclass_dispatch=req_subclass_dispatch ) @@ -760,11 +761,12 @@ class AOTAutogradCache: "components": debug_lines, } ) - chromium_log = get_chromium_event_logger() - chromium_log.log_instant_event( - f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_info + CompileEventLogger.instant( + f"autograd_cache_{cache_state}", + metadata=cache_info, + time_ns=cache_event_time, ) - chromium_log.try_add_event_data( + 
CompileEventLogger.try_add_pt2_compile( "backend_compile", cache_state=cache_state, cache_event_time=cache_event_time, diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index ea2200ffd72..604d6540849 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -18,7 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch import torch.utils.dlpack from torch import Tensor -from torch._dynamo.utils import dynamo_timed, get_metrics_context +from torch._dynamo.utils import CompileEventLogger, dynamo_timed, get_metrics_context from torch._guards import ( compile_context, CompileContext, @@ -2047,7 +2047,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa log_pt2_compile_event=True, dynamo_compile_column_us="backward_cumulative_compile_time_us", ): - metrics_context.update_outer({"is_forward": False}) + CompileEventLogger.compilation_metric(is_forward=False) CompiledFunction.compiled_bw = aot_config.bw_compiler( bw_module, placeholder_list ) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 576cf880c6d..00ab67eabe0 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -29,8 +29,8 @@ from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompo from torch._dispatch.python import enable_python_dispatcher from torch._dynamo import compiled_autograd from torch._dynamo.utils import ( + CompileEventLogger, dynamo_timed, - get_chromium_event_logger, preserve_rng_state, set_feature_use, ) @@ -638,7 +638,7 @@ def _create_aot_dispatcher_function( python_dispatcher_mode = ( enable_python_dispatcher() if shape_env is not None else nullcontext() ) - chromium_log = get_chromium_event_logger() + # See NOTE: [Deferring tensor pack/unpack hooks until runtime] # If any saved tensor hooks are 
active, we **don't** want to trace them. # Instead, we'll let them run at runtime, around the custom autograd.Function @@ -692,7 +692,7 @@ def _create_aot_dispatcher_function( req_subclass_dispatch = requires_subclass_dispatch( fake_flat_args, fw_metadata ) - chromium_log.try_add_event_data( + CompileEventLogger.try_add_pt2_compile( "backend_compile", requires_subclass_dispatch=req_subclass_dispatch ) @@ -812,17 +812,17 @@ or otherwise set torch._functorch.config.functionalize_rng_ops = False.""" if aot_config.is_export: # export uses just the "graph bits", whereas the other # two dispatchers include some extra work around handling a runtime epilogue - chromium_log.try_add_event_data( + CompileEventLogger.try_add_pt2_compile( "backend_compile", dispatch_mode="export" ) return partial(aot_dispatch_export, needs_autograd=needs_autograd) elif needs_autograd and not aot_config.pre_dispatch: - chromium_log.try_add_event_data( + CompileEventLogger.try_add_pt2_compile( "backend_compile", dispatch_mode="autograd" ) return aot_dispatch_autograd else: - chromium_log.try_add_event_data( + CompileEventLogger.try_add_pt2_compile( "backend_compile", dispatch_mode="inference" ) return aot_dispatch_base diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 42755fb3a50..557bd7d8dc5 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -52,9 +52,9 @@ import torch import torch.distributed as dist from torch import SymInt, Tensor from torch._dynamo.utils import ( + CompileEventLogger, counters, dynamo_timed, - get_chromium_event_logger, get_metrics_context, ) from torch._inductor import config, exc, metrics @@ -1054,12 +1054,10 @@ class FxGraphCache: triton_bundler_meta = TritonBundler.read_and_emit(bundle) if (meta := triton_bundler_meta) is not None: cache_info["triton_bundler_meta"] = str(meta) - logger = get_chromium_event_logger() - if "inductor_compile" in logger.get_stack(): - # TODO: Clean up autograd cache integration - 
logger.add_event_data( - "inductor_compile", cached_kernel_names=meta.cached_kernel_names - ) + # TODO: Clean up autograd cache integration + CompileEventLogger.try_add_pt2_compile( + "inductor_compile", cached_kernel_names=meta.cached_kernel_names + ) if len(meta.cached_kernel_names) > 0: get_metrics_context().increment("num_triton_bundles", 1) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index b47e9c49010..147f6861a7a 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -44,11 +44,11 @@ from torch._dynamo import ( from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.repro.after_aot import wrap_compiler_debug from torch._dynamo.utils import ( + CompileEventLogger, counters, detect_fake_mode, dynamo_timed, flatten_graph_inputs, - get_chromium_event_logger, lazy_format_graph_code, set_feature_use, ) @@ -569,12 +569,10 @@ def compile_fx_inner( stack.enter_context(_WaitCounter("pytorch.wait_counter.dynamo_compile").guard()) stack.enter_context(with_fresh_cache_if_config()) stack.enter_context(DebugContext()) - - get_chromium_event_logger().add_event_data( + CompileEventLogger.pt2_compile( "inductor_compile", is_backward=kwargs["is_backward"], ) - return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")( gm, example_inputs, @@ -736,7 +734,6 @@ def _compile_fx_inner( # and a tlparse log for every cache action. # In the event of a bypass, we also logged to the remote table earlier # with log_cache_bypass. 
- chromium_log = get_chromium_event_logger() cache_state = ( cache_info["cache_state"] if cache_info is not None else "disabled" ) @@ -745,14 +742,14 @@ def _compile_fx_inner( # fx_graph_cache_miss # fx_graph_cache_bypass # fx_graph_cache_disabled - chromium_log.log_instant_event( + CompileEventLogger.instant( f"fx_graph_cache_{cache_state}", - start_time, - metadata=cache_info, + metadata=cache_info or {}, + time_ns=start_time, ) # Add event data about cache hits/miss # TODO: add remote cache get/put timings here too - chromium_log.add_event_data( + CompileEventLogger.pt2_compile( "inductor_compile", cache_state=cache_state, cache_event_time=start_time,