mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[PT2 Compile Events] Revamp PT2 Compile/chromium event logging [1/?] (#138093)
This diff is the starting steps of https://docs.google.com/document/u/2/d/1kAEBt4AyW7HTAhXHbjoz8FBFHNyyEA2Qo2mPn7v3WUQ/edit?usp=drive_web&ouid=113555078003219714709 It implements the following changes: - Only log spans to scuba, so no start events are ever logged - Log events as the full event name, without "START" or "END" - Only log to scuba major phases from chromium events. These are: - entire_frame_compile (dynamo) - backend_compile (aotdispatch) - inductor_compile (inductor) - codegen (inductor codegen) Tlparse chromium events stay basically the same. But I implemented a few changes to clean that up as well: - When there's a phase name available, log the phase name instead of the function name as the event name. This simplifies the trace to not have two identical rows. The fn_name is available as metadata on the chromium event, if interested - Log new events for pre and post grad passes. These do *not* log to scuba. By making the phases much simpler in Scuba, with only categories for major phases of PT2 Compilation, we pave the way to add **much** more metadata and information to each individual event type. Diffs for that will come later. **IMPLEMENTATION NOTES:** - The logic for `log_chromium_event_internal` (which is the function that logs to Scuba) lives in chromium_events for now, but in the future as we add more metadata, it may belong independently in dynamo_timed or even outside of dynamo_timed. I haven't explored in detail what the refactor will look like. Once we start logging metadata for dynamo, aotdispatch, inductor, I suspect we will call log_pt2_compile_event directly, instead of making chromium event logger handle the pt2_compile_event logic. But that refactor is left for another PR on top of this one. - There's an interesting space after pre grad passes within AOT autograd logic, that's between create_aot_dispatcher_function and pre grad passes. I'm not sure what we're spending time doing in that time, but I'll find out with a profile later. 
Differential Revision: [D64479033](https://our.internmc.facebook.com/intern/diff/D64479033/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/138093 Approved by: https://github.com/ezyang
This commit is contained in:
parent
3c7d9d6c7f
commit
295de00908
3 changed files with 39 additions and 49 deletions
|
|
@ -236,16 +236,6 @@ def add_remote_cache_time_saved(time_saved_ns: int, is_backward: bool = False) -
|
|||
_add_time_spent(key, "remote_cache_time_saved", time_saved)
|
||||
|
||||
|
||||
def get_cache_stats() -> Dict[str, Any]:
|
||||
"""Get a bunch of metadata about cache hits and misses to use in chromium events"""
|
||||
cache_stats = {
|
||||
"fxgraph_cache_hit": counters["inductor"]["fxgraph_cache_hit"],
|
||||
"fxgraph_cache_miss": counters["inductor"]["fxgraph_cache_miss"],
|
||||
"fxgraph_cache_bypass": counters["inductor"]["fxgraph_cache_bypass"],
|
||||
}
|
||||
return cache_stats
|
||||
|
||||
|
||||
# dynamo_timed is a context manager
|
||||
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
|
||||
# where the key is the functions name.
|
||||
|
|
@ -290,9 +280,10 @@ def dynamo_timed(
|
|||
try:
|
||||
with torch.profiler.record_function(f"{key} (dynamo_timed)"):
|
||||
t0 = time.time()
|
||||
chromium_log.log_event_start(key, start, None)
|
||||
if phase_name:
|
||||
chromium_log.log_event_start(phase_name, start)
|
||||
chromium_log.log_event_start(phase_name, start, {"fn_name": key})
|
||||
else:
|
||||
chromium_log.log_event_start(key, start, {})
|
||||
yield
|
||||
time_spent = time.time() - t0
|
||||
compilation_time_metrics[key].append(time_spent)
|
||||
|
|
@ -306,16 +297,15 @@ def dynamo_timed(
|
|||
chromium_log.log_event_end(
|
||||
phase_name,
|
||||
time.time_ns(),
|
||||
{"cache_stats": get_cache_stats()},
|
||||
{},
|
||||
start,
|
||||
)
|
||||
chromium_log.log_event_end(
|
||||
key, time.time_ns(), {"cache_stats": get_cache_stats()}, start
|
||||
)
|
||||
else:
|
||||
chromium_log.log_event_end(key, time.time_ns(), {}, start)
|
||||
# Only record backward compilation metrics if phase_name is not None!
|
||||
if phase_name:
|
||||
frame_key = str(curr_frame)
|
||||
# fwd only compilation stages: entire_frame_compile, backend_compile.
|
||||
# fwd only compilation stages: entire_frame_compile, backend_compile, aotdispatch.
|
||||
# use frame_key as time aggregation key.
|
||||
if fwd_only and fail_type is None:
|
||||
_add_time_spent(frame_key, phase_name, time_spent)
|
||||
|
|
@ -902,7 +892,7 @@ class ChromiumEventLogger:
|
|||
self,
|
||||
event_name: str,
|
||||
time_ns: int,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Logs the start of a single event.
|
||||
|
|
@ -911,19 +901,14 @@ class ChromiumEventLogger:
|
|||
:param metadata: Any extra metadata associated with this event
|
||||
"""
|
||||
|
||||
# Add compile id to metadata
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
compile_id = str(torch._guards.CompileContext.current_compile_id())
|
||||
metadata["compile_id"] = compile_id
|
||||
|
||||
event = self._log_timed_event(
|
||||
self._log_timed_event(
|
||||
event_name,
|
||||
time_ns,
|
||||
"B",
|
||||
metadata,
|
||||
)
|
||||
log_chromium_event_internal(event, self.get_stack(), compile_id, self.id_)
|
||||
self.get_stack().append(event_name)
|
||||
|
||||
def reset(self) -> None:
|
||||
|
|
@ -937,8 +922,8 @@ class ChromiumEventLogger:
|
|||
self,
|
||||
event_name: str,
|
||||
time_ns: int,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
start_time_ns: Optional[int] = None,
|
||||
metadata: Dict[str, Any],
|
||||
start_time_ns: int,
|
||||
) -> None:
|
||||
"""
|
||||
Logs the end of a single event. This function should only be
|
||||
|
|
@ -947,11 +932,14 @@ class ChromiumEventLogger:
|
|||
:param time_ns: Timestamp in nanoseconds
|
||||
:param metadata: Any extra metadata associated with this event
|
||||
"""
|
||||
# Add compile id to metadata
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
compile_id = str(torch._guards.CompileContext.current_compile_id())
|
||||
metadata["compile_id"] = compile_id
|
||||
event = self._log_timed_event(
|
||||
event_name,
|
||||
time_ns,
|
||||
"E",
|
||||
metadata,
|
||||
)
|
||||
|
||||
# These stack health checks currently never happen,
|
||||
# but they're written this way to future proof any weird event
|
||||
|
|
@ -963,13 +951,6 @@ class ChromiumEventLogger:
|
|||
log.warning("ChromiumEventLogger: Start event not in stack, ignoring")
|
||||
return
|
||||
|
||||
event = self._log_timed_event(
|
||||
event_name,
|
||||
time_ns,
|
||||
"E",
|
||||
metadata,
|
||||
)
|
||||
|
||||
while event_name != stack[-1]:
|
||||
# If the event isn't the most recent one to end, pop
|
||||
# off the stack until it is.
|
||||
|
|
@ -1046,7 +1027,9 @@ class ChromiumEventLogger:
|
|||
expect_trace_id=True,
|
||||
)
|
||||
# Log an instant event with the same start and end time
|
||||
log_chromium_event_internal(event, self.get_stack(), compile_id, self.id_)
|
||||
log_chromium_event_internal(
|
||||
event, self.get_stack(), compile_id, self.id_, time_ns
|
||||
)
|
||||
|
||||
|
||||
CHROMIUM_EVENT_LOG: Optional[ChromiumEventLogger] = None
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ from torch._dynamo.repro.after_aot import wrap_compiler_debug
|
|||
from torch._dynamo.utils import (
|
||||
counters,
|
||||
detect_fake_mode,
|
||||
dynamo_timed,
|
||||
flatten_graph_inputs,
|
||||
lazy_format_graph_code,
|
||||
)
|
||||
|
|
@ -281,12 +282,13 @@ def _get_subgraph_names(gm):
|
|||
|
||||
|
||||
def _recursive_pre_grad_passes(gm, example_inputs):
|
||||
for subgraph_name in _get_subgraph_names(gm):
|
||||
subgraph = getattr(gm, subgraph_name)
|
||||
# as we don't have recursive example inputs, passing None here
|
||||
new_subgraph = _recursive_pre_grad_passes(subgraph, example_inputs=None)
|
||||
setattr(gm, subgraph_name, new_subgraph)
|
||||
return pre_grad_passes(gm, example_inputs)
|
||||
with dynamo_timed("_recursive_pre_grad_passes"):
|
||||
for subgraph_name in _get_subgraph_names(gm):
|
||||
subgraph = getattr(gm, subgraph_name)
|
||||
# as we don't have recursive example inputs, passing None here
|
||||
new_subgraph = _recursive_pre_grad_passes(subgraph, example_inputs=None)
|
||||
setattr(gm, subgraph_name, new_subgraph)
|
||||
return pre_grad_passes(gm, example_inputs)
|
||||
|
||||
|
||||
def _recursive_joint_graph_passes(gm):
|
||||
|
|
@ -297,10 +299,11 @@ def _recursive_joint_graph_passes(gm):
|
|||
|
||||
|
||||
def _recursive_post_grad_passes(gm, is_inference: bool = False):
|
||||
for subgraph_name in _get_subgraph_names(gm):
|
||||
subgraph = getattr(gm, subgraph_name)
|
||||
_recursive_post_grad_passes(subgraph, is_inference)
|
||||
post_grad_passes(gm, is_inference)
|
||||
with dynamo_timed("_recursive_post_grad_passes"):
|
||||
for subgraph_name in _get_subgraph_names(gm):
|
||||
subgraph = getattr(gm, subgraph_name)
|
||||
_recursive_post_grad_passes(subgraph, is_inference)
|
||||
post_grad_passes(gm, is_inference)
|
||||
|
||||
|
||||
def split_const_gm(
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import logging
|
|||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch._strobelight.compile_time_profiler import StrobelightCompileTimeProfiler
|
||||
|
|
@ -357,6 +357,10 @@ def maybe_upload_prof_stats_to_manifold(profile_path: str) -> Optional[str]:
|
|||
|
||||
|
||||
def log_chromium_event_internal(
|
||||
event, stack, compile_id, logger_uuid, start_timestamp=None
|
||||
event: Dict[str, Any],
|
||||
stack: List[str],
|
||||
compile_id: Optional[str],
|
||||
logger_uuid: str,
|
||||
start_time_ns: int,
|
||||
):
|
||||
return None
|
||||
|
|
|
|||
Loading…
Reference in a new issue