Introduce CompileEventLogger, replace usages of metrics_context and chromium_event with it (#143420)

**Problem statement**: I want to be able to centralize and simplify the process by which people add columns/data to existing spans. We have MetricsContext and ChromiumEventLogger, and there's various choices you can make to decide where and when to log different levels of observability for your events. To resolve this, I want a central API for "adding to events under dynamo_timed".

**CompileEventLogger** is intended as a frontend for MetricsContext and ChromiumEventLogger so we can use the same class for handling everything.

CompileEventLogger is intended to be used within a `dynamo_timed()` context. Its purpose is to 1. log to existing events that are in progress (i.e. within dynamo_timed), and 2. log instant events to chromium that are independent of any specific span.

CompileEventLogger has three log levels:

- CHROMIUM: Log only to chromium events, visible via tlparse.
- PT2_COMPILE: Log to chromium_events + pt2_compile_events
- COMPILATION_METRIC: Log to compilation metrics in addition to the toplevel chromium and pt2_compile_event.

In addition, we have a function CompileEventLogger.add() that automagically chooses the correct log level. For now, it is conservative, and will never automagically choose to log CompilationMetrics (though I could imagine it figuring out that the metadata are all fields of CompilationMetrics and are therefore loggable there).

The goal here is to make one single interface to log stuff for observability reasons, and make it as easy as possible.

Not included in this diff:
- V1 of this diff will not have implementations of `increment` and `add_to_set` which MetricsContext has, so those usages are not replaced yet. But I'll add those in a followup.

- We don't handle `RuntimeMetricsContext`. It's unclear if I want that to be part of this, because under RuntimeMetricsContext there might not be a toplevel event to log to, so chromium events doesn't make sense in that context. So I might leave that separate for now.

Differential Revision: [D67346203](https://our.internmc.facebook.com/intern/diff/D67346203/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143420
Approved by: https://github.com/aorenste
This commit is contained in:
James Wu 2025-01-03 09:17:37 -08:00 committed by PyTorch MergeBot
parent 68d30c6a25
commit f2d6cfa677
8 changed files with 210 additions and 51 deletions

View file

@ -1148,6 +1148,9 @@ def _compile(
dynamo_time_before_restart
),
}
# TODO: replace with CompileEventLogger.compilation_metrics
# There are some columns here not in PT2 Compile Events
# so we need to slightly change it
metrics_context.update_outer(metrics)
# === END WARNING WARNING WARNING ===

View file

@ -7,7 +7,6 @@ import enum
import logging
import os
import pickle
import time
from collections import defaultdict
from typing import DefaultDict, Optional, Tuple, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Self
@ -17,8 +16,8 @@ import torch._utils_internal
import torch.compiler.config
import torch.distributed as dist
from torch._dynamo.utils import (
CompileEventLogger,
dynamo_timed,
get_chromium_event_logger,
set_feature_use,
warn_once,
)
@ -327,9 +326,8 @@ def update_automatic_dynamic(
entry.scalar,
old_entry.scalar,
)
get_chromium_event_logger().log_instant_event(
CompileEventLogger.instant(
"automatic_dynamic",
time.time_ns(),
{
"name": name,
"dim_changed": "scalar",
@ -366,9 +364,8 @@ def update_automatic_dynamic(
entry_tup,
old_entry_tup,
)
get_chromium_event_logger().log_instant_event(
CompileEventLogger.instant(
"automatic_dynamic",
time.time_ns(),
{
"name": name,
"dim_changed": "all" if i is None else i,
@ -546,8 +543,6 @@ def get_code_state() -> DefaultDict[CodeId, CodeState]:
if _CODE_STATE is not None:
return _CODE_STATE
chromium_log = get_chromium_event_logger()
# Initialize it (even if we don't look up profile)
_CODE_STATE = defaultdict(CodeState)
@ -574,13 +569,13 @@ def get_code_state() -> DefaultDict[CodeId, CodeState]:
with dynamo_timed(
name := "pgo.get_local_code_state", log_pt2_compile_event=True
):
chromium_log.add_event_data(name, cache_key=cache_key)
CompileEventLogger.pt2_compile(name, cache_key=cache_key)
# Read lock not necessary as we always write atomically write to
# the actual location
with open(path, "rb") as f:
try:
_CODE_STATE = pickle.load(f)
chromium_log.add_event_data(name, cache_size_bytes=f.tell())
CompileEventLogger.pt2_compile(name, cache_size_bytes=f.tell())
except Exception:
log.warning(
"get_code_state failed while reading %s", path, exc_info=True
@ -594,7 +589,7 @@ def get_code_state() -> DefaultDict[CodeId, CodeState]:
with dynamo_timed(
name := "pgo.get_remote_code_state", log_pt2_compile_event=True
):
chromium_log.add_event_data(name, cache_key=cache_key)
CompileEventLogger.pt2_compile(name, cache_key=cache_key)
# TODO: I don't really understand why there's a JSON container format
try:
cache_data = remote_cache.get(cache_key)
@ -609,7 +604,9 @@ def get_code_state() -> DefaultDict[CodeId, CodeState]:
data = cache_data["data"]
assert isinstance(data, str)
payload = base64.b64decode(data)
chromium_log.add_event_data(name, cache_size_bytes=len(payload))
CompileEventLogger.pt2_compile(
name, cache_size_bytes=len(payload)
)
_CODE_STATE = pickle.loads(payload)
except Exception:
log.warning(
@ -648,8 +645,7 @@ def put_code_state() -> None:
def put_local_code_state(cache_key: str) -> None:
with dynamo_timed(name := "pgo.put_local_code_state", log_pt2_compile_event=True):
chromium_log = get_chromium_event_logger()
chromium_log.add_event_data(name, cache_key=cache_key)
CompileEventLogger.pt2_compile(name, cache_key=cache_key)
assert _CODE_STATE is not None
path = code_state_path(cache_key)
@ -672,7 +668,7 @@ def put_local_code_state(cache_key: str) -> None:
with FileLock(lock_path, timeout=LOCK_TIMEOUT):
with open(tmp_path, "wb") as f:
pickle.dump(_CODE_STATE, f)
chromium_log.add_event_data(name, cache_size_bytes=f.tell())
CompileEventLogger.pt2_compile(name, cache_size_bytes=f.tell())
os.rename(tmp_path, path)
log.info(
"put_code_state: wrote local %s, %d entries", path, len(_CODE_STATE)
@ -686,8 +682,7 @@ def put_local_code_state(cache_key: str) -> None:
def put_remote_code_state(cache_key: str) -> None:
with dynamo_timed(name := "pgo.put_remote_code_state", log_pt2_compile_event=True):
chromium_log = get_chromium_event_logger()
chromium_log.add_event_data(name, cache_key=cache_key)
CompileEventLogger.pt2_compile(name, cache_key=cache_key)
assert _CODE_STATE is not None
remote_cache = get_remote_cache()
@ -697,7 +692,7 @@ def put_remote_code_state(cache_key: str) -> None:
return
content = pickle.dumps(_CODE_STATE)
chromium_log.add_event_data(name, cache_size_bytes=len(content))
CompileEventLogger.pt2_compile(name, cache_size_bytes=len(content))
cache_data: JsonDataTy = {
"data": base64.b64encode(content).decode("ascii"),
}

View file

@ -301,6 +301,160 @@ def get_runtime_metrics_context() -> RuntimeMetricsContext:
return _RUNTIME_METRICS_CONTEXT
class CompileEventLogLevel(enum.Enum):
    """
    Enum that loosely corresponds with a "log level" of a given event.

    CHROMIUM: Logs only to chromium events, visible via tlparse.
    PT2_COMPILE: Logs to chromium events + PT2 Compile Events.
    COMPILATION_METRIC: Logs to chromium events, PT2 Compile Events,
        and CompilationMetrics (dynamo_compile).
    """

    CHROMIUM = 1
    PT2_COMPILE = 2
    COMPILATION_METRIC = 3
class CompileEventLogger:
    """
    Helper class for adding metadata (i.e. columns) to various compile events.

    Use CompileEventLogger to add event data to:
    - Chromium events (visible via tlparse)
    - PT2 Compile Events
    - CompilationMetrics

    This should be used in conjunction with dynamo_timed() and metrics contexts,
    which create timed spans and events. CompileEventLogger uses three log levels
    (described in CompileEventLogLevel), where each log level logs to all sources
    below it in the hierarchy.

    Example usages:
    - I want to log to an existing chromium event within dynamo_timed:
        with dynamo_timed("my_event"):
            CompileEventLogger.chromium("my_event", foo=bar)
    - I want to log my event to both chromium + pt2_compile_events:
        with dynamo_timed("my_event", log_pt2_compile_event=True):
            CompileEventLogger.pt2_compile("my_event", foo=bar)
    - I want to add information to the toplevel event and dynamo_compile:
        CompileEventLogger.compilation_metric(foo=bar)
    """

    @staticmethod
    def log_instant_event(
        event_name: str,
        metadata: Dict[str, Any],
        time_ns: Optional[int] = None,
        log_level: CompileEventLogLevel = CompileEventLogLevel.CHROMIUM,
    ) -> None:
        """
        Log a point-in-time (instant) event, independent of any active span.

        Args:
            event_name: name of the instant event.
            metadata: key/value metadata attached to the event.
            time_ns: event timestamp in nanoseconds; defaults to time.time_ns().
            log_level: CHROMIUM or PT2_COMPILE. COMPILATION_METRIC is invalid
                here because CompilationMetrics attach to a toplevel span, which
                an instant event does not have.

        Raises:
            RuntimeError: if log_level is COMPILATION_METRIC.
        """
        if time_ns is None:
            time_ns = time.time_ns()
        chromium_log = get_chromium_event_logger()
        if log_level == CompileEventLogLevel.CHROMIUM:
            log_pt2_compile_event = False
        elif log_level == CompileEventLogLevel.PT2_COMPILE:
            log_pt2_compile_event = True
        else:
            # Fixed message: refer to the actual enum member names.
            raise RuntimeError(
                "Cannot log instant event at COMPILATION_METRIC level. "
                "Please choose one of CHROMIUM or PT2_COMPILE"
            )
        chromium_log.log_instant_event(
            event_name, time_ns, metadata, log_pt2_compile_event
        )

    @staticmethod
    def add_data(
        event_name: str, log_level: CompileEventLogLevel, **metadata: object
    ) -> None:
        """
        Centralized API for adding data to various events.

        Logs metadata to an in-progress event (and, at COMPILATION_METRIC level,
        to the current metrics context) depending on log_level.

        Raises:
            RuntimeError: if the target event is not being logged at the
                requested level, or (for COMPILATION_METRIC) if event_name is
                not the toplevel event or no metrics context is in progress.
        """
        chromium_log = get_chromium_event_logger()
        if log_level == CompileEventLogLevel.CHROMIUM:
            chromium_log.add_event_data(event_name, **metadata)
        elif log_level == CompileEventLogLevel.PT2_COMPILE:
            pt2_compile_substack = chromium_log.get_pt2_compile_substack()
            if event_name not in pt2_compile_substack:
                # Use an f-string: RuntimeError does not perform logging-style
                # %s interpolation, so passing args separately would raise a
                # tuple instead of a formatted message.
                raise RuntimeError(
                    f"Error: specified log level PT2_COMPILE, but the event {event_name}"
                    " is not logged to pt2_compile_events. Make sure the event is active and you passed "
                    "log_pt2_compile_event=True to dynamo_timed"
                )
            chromium_log.add_event_data(event_name, **metadata)
        else:
            assert log_level == CompileEventLogLevel.COMPILATION_METRIC
            top_event = chromium_log.get_top()
            if event_name != top_event:
                raise RuntimeError(
                    "Log level is COMPILATION_METRIC, but event_name isn't the toplevel event. "
                    "CompilationMetrics must be logged to the toplevel event. "
                    "Consider using CompileEventLogger.add_toplevel instead."
                )
            metrics_context = get_metrics_context()
            if not metrics_context.in_progress():
                raise RuntimeError(
                    "No metrics context is in progress. Please only call this function within a metrics context."
                )
            # TODO: should we assert that the keys of metadata are in CompilationMetrics?
            metrics_context.update(metadata)
            chromium_log.add_event_data(event_name, **metadata)

    @staticmethod
    def add_toplevel(log_level: CompileEventLogLevel, **metadata: object) -> None:
        """
        Syntactic sugar for logging to the toplevel event.

        Raises:
            RuntimeError: if no toplevel event is active (i.e. not within a
                dynamo_timed context).
        """
        top_event = get_chromium_event_logger().get_top()
        if top_event is None:
            raise RuntimeError(
                "No toplevel event active. Please only call this function within a dynamo_timed context."
            )
        CompileEventLogger.add_data(top_event, log_level, **metadata)

    # MAIN API: These functions are syntactic sugar for the basic operations above without
    # needing to use a specific log level. These are easier to use because you don't need
    # to import CompileEventLogLevel to use them.

    @staticmethod
    def chromium(event_name: str, **metadata: object) -> None:
        """Add metadata to an in-progress event at CHROMIUM level (tlparse only)."""
        CompileEventLogger.add_data(
            event_name, CompileEventLogLevel.CHROMIUM, **metadata
        )

    @staticmethod
    def pt2_compile(event_name: str, **metadata: object) -> None:
        """Add metadata to an in-progress event at PT2_COMPILE level."""
        CompileEventLogger.add_data(
            event_name, CompileEventLogLevel.PT2_COMPILE, **metadata
        )

    @staticmethod
    def compilation_metric(**metadata: object) -> None:
        """Add metadata to the toplevel event and to CompilationMetrics."""
        CompileEventLogger.add_toplevel(
            CompileEventLogLevel.COMPILATION_METRIC, **metadata
        )

    @staticmethod
    def instant(
        event_name: str, metadata: Dict[str, Any], time_ns: Optional[int] = None
    ) -> None:
        """Log an instant (point-in-time) chromium event, visible via tlparse."""
        CompileEventLogger.log_instant_event(
            event_name, metadata, time_ns, CompileEventLogLevel.CHROMIUM
        )

    @staticmethod
    def try_add_pt2_compile(event_name: str, **metadata: object) -> None:
        """
        Adds to an existing pt2_compile event, but silently returns if the event doesn't exist.
        This function is syntactic sugar for chromium_event_logger().try_add_event_data.
        """
        chromium_log = get_chromium_event_logger()
        chromium_log.try_add_event_data(event_name, **metadata)
@contextmanager
def dynamo_timed(
key: str,
@ -1007,6 +1161,16 @@ _compilation_metrics: Deque[CompilationMetrics] = collections.deque(
def add_compilation_metrics_to_chromium(c: CompilationMetrics) -> None:
"""
These are the common fields in CompilationMetrics that existed before
metrics_context, and aren't set by MetricsContext.set(). We add the subset
of them that make sense in `dynamo`/toplevel events in PT2 Compile Events
directly.
If you're tempted to add to this list, consider using CompileEventLogger.compilation_metric()
instead, which will automatically also add it to tlparse and PT2 Compile Events.
TODO: Get rid of this function and replace it with CompileEventLogger directly instead.
"""
event_logger = get_chromium_event_logger()
event_name = event_logger.get_top()
if not event_name:

View file

@ -17,7 +17,7 @@ from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
from torch._dynamo.utils import counters, get_chromium_event_logger
from torch._dynamo.utils import CompileEventLogger, counters
from torch._functorch import config
from torch._inductor.codecache import (
_ident,
@ -500,14 +500,15 @@ class AOTAutogradCacheEntry:
compiled_fw_func = self.compiled_fw.load(args, fx_config)
compiled_bw_func = None
chromium_log = get_chromium_event_logger()
if self.compiled_bw is not None:
compiled_bw_func = self.compiled_bw.load(args, fx_config)
needs_autograd = True
chromium_log.try_add_event_data("backend_compile", dispatch_mode="autograd")
CompileEventLogger.try_add_pt2_compile(
"backend_compile", dispatch_mode="autograd"
)
else:
needs_autograd = False
chromium_log.try_add_event_data(
CompileEventLogger.try_add_pt2_compile(
"backend_compile", dispatch_mode="inference"
)
@ -522,7 +523,7 @@ class AOTAutogradCacheEntry:
)
req_subclass_dispatch = self.maybe_subclass_meta is not None
chromium_log.add_event_data(
CompileEventLogger.pt2_compile(
"backend_compile", requires_subclass_dispatch=req_subclass_dispatch
)
@ -760,11 +761,12 @@ class AOTAutogradCache:
"components": debug_lines,
}
)
chromium_log = get_chromium_event_logger()
chromium_log.log_instant_event(
f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_info
CompileEventLogger.instant(
f"autograd_cache_{cache_state}",
metadata=cache_info,
time_ns=cache_event_time,
)
chromium_log.try_add_event_data(
CompileEventLogger.try_add_pt2_compile(
"backend_compile",
cache_state=cache_state,
cache_event_time=cache_event_time,

View file

@ -18,7 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.utils.dlpack
from torch import Tensor
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._dynamo.utils import CompileEventLogger, dynamo_timed, get_metrics_context
from torch._guards import (
compile_context,
CompileContext,
@ -2047,7 +2047,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
log_pt2_compile_event=True,
dynamo_compile_column_us="backward_cumulative_compile_time_us",
):
metrics_context.update_outer({"is_forward": False})
CompileEventLogger.compilation_metric(is_forward=False)
CompiledFunction.compiled_bw = aot_config.bw_compiler(
bw_module, placeholder_list
)

View file

@ -29,8 +29,8 @@ from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompo
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo import compiled_autograd
from torch._dynamo.utils import (
CompileEventLogger,
dynamo_timed,
get_chromium_event_logger,
preserve_rng_state,
set_feature_use,
)
@ -638,7 +638,7 @@ def _create_aot_dispatcher_function(
python_dispatcher_mode = (
enable_python_dispatcher() if shape_env is not None else nullcontext()
)
chromium_log = get_chromium_event_logger()
# See NOTE: [Deferring tensor pack/unpack hooks until runtime]
# If any saved tensor hooks are active, we **don't** want to trace them.
# Instead, we'll let them run at runtime, around the custom autograd.Function
@ -692,7 +692,7 @@ def _create_aot_dispatcher_function(
req_subclass_dispatch = requires_subclass_dispatch(
fake_flat_args, fw_metadata
)
chromium_log.try_add_event_data(
CompileEventLogger.try_add_pt2_compile(
"backend_compile", requires_subclass_dispatch=req_subclass_dispatch
)
@ -812,17 +812,17 @@ or otherwise set torch._functorch.config.functionalize_rng_ops = False."""
if aot_config.is_export:
# export uses just the "graph bits", whereas the other
# two dispatchers include some extra work around handling a runtime epilogue
chromium_log.try_add_event_data(
CompileEventLogger.try_add_pt2_compile(
"backend_compile", dispatch_mode="export"
)
return partial(aot_dispatch_export, needs_autograd=needs_autograd)
elif needs_autograd and not aot_config.pre_dispatch:
chromium_log.try_add_event_data(
CompileEventLogger.try_add_pt2_compile(
"backend_compile", dispatch_mode="autograd"
)
return aot_dispatch_autograd
else:
chromium_log.try_add_event_data(
CompileEventLogger.try_add_pt2_compile(
"backend_compile", dispatch_mode="inference"
)
return aot_dispatch_base

View file

@ -52,9 +52,9 @@ import torch
import torch.distributed as dist
from torch import SymInt, Tensor
from torch._dynamo.utils import (
CompileEventLogger,
counters,
dynamo_timed,
get_chromium_event_logger,
get_metrics_context,
)
from torch._inductor import config, exc, metrics
@ -1054,12 +1054,10 @@ class FxGraphCache:
triton_bundler_meta = TritonBundler.read_and_emit(bundle)
if (meta := triton_bundler_meta) is not None:
cache_info["triton_bundler_meta"] = str(meta)
logger = get_chromium_event_logger()
if "inductor_compile" in logger.get_stack():
# TODO: Clean up autograd cache integration
logger.add_event_data(
"inductor_compile", cached_kernel_names=meta.cached_kernel_names
)
# TODO: Clean up autograd cache integration
CompileEventLogger.try_add_pt2_compile(
"inductor_compile", cached_kernel_names=meta.cached_kernel_names
)
if len(meta.cached_kernel_names) > 0:
get_metrics_context().increment("num_triton_bundles", 1)

View file

@ -44,11 +44,11 @@ from torch._dynamo import (
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.repro.after_aot import wrap_compiler_debug
from torch._dynamo.utils import (
CompileEventLogger,
counters,
detect_fake_mode,
dynamo_timed,
flatten_graph_inputs,
get_chromium_event_logger,
lazy_format_graph_code,
set_feature_use,
)
@ -569,12 +569,10 @@ def compile_fx_inner(
stack.enter_context(_WaitCounter("pytorch.wait_counter.dynamo_compile").guard())
stack.enter_context(with_fresh_cache_if_config())
stack.enter_context(DebugContext())
get_chromium_event_logger().add_event_data(
CompileEventLogger.pt2_compile(
"inductor_compile",
is_backward=kwargs["is_backward"],
)
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
gm,
example_inputs,
@ -736,7 +734,6 @@ def _compile_fx_inner(
# and a tlparse log for every cache action.
# In the event of a bypass, we also logged to the remote table earlier
# with log_cache_bypass.
chromium_log = get_chromium_event_logger()
cache_state = (
cache_info["cache_state"] if cache_info is not None else "disabled"
)
@ -745,14 +742,14 @@ def _compile_fx_inner(
# fx_graph_cache_miss
# fx_graph_cache_bypass
# fx_graph_cache_disabled
chromium_log.log_instant_event(
CompileEventLogger.instant(
f"fx_graph_cache_{cache_state}",
start_time,
metadata=cache_info,
metadata=cache_info or {},
time_ns=start_time,
)
# Add event data about cache hits/miss
# TODO: add remote cache get/put timings here too
chromium_log.add_event_data(
CompileEventLogger.pt2_compile(
"inductor_compile",
cache_state=cache_state,
cache_event_time=start_time,