Switch times to us in CompilationMetrics and improvements (#138975)

Companion logger diff: https://www.internalfb.com/diff/D65012523

* Using float seconds for timestamps is bad because our internal system defaults to float32 precision and you don't even get second precision for timestamps in float32
* We decided to use microseconds instead of milliseconds because with millisecond granularity you can end up with the same timestamp if compilation is happening very quickly; it is much better to force non-overlapping spans
* Because there are so many new fields and I don't feel like reimplementing each on BwdCompilationMetrics, BwdCompilationMetrics is no more; instead, everything in CompilationMetrics is now optional.
* The actual frame compile times collection is not modified (still float) to reduce blast radius, so I just convert to microseconds before making the record. At float64 precision (Python's default), you get about microsecond precision on timestamps, so this shouldn't be a data problem (https://www.leebutterman.com/2021/02/01/store-your-unix-epoch-times-as-float64.html)
* I rename some entries for clarity. In particular, whenever a timing contains all of its lower phases (e.g., how Inductor also contains Triton compilation) we put "cumulative" in its name. If something doesn't happen at compile time but is delayed until we have actual real inputs, we put "runtime" in its name.

Test plan:

```
buck2 run @mode/opt @mode/inplace //scripts/oulgen:runner
```

And then inspect https://fburl.com/scuba/dynamo_compile/sandbox/mslu7f5w and verify the us columns are populated and meaningful.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138975
Approved by: https://github.com/masnesral
This commit is contained in:
Edward Z. Yang 2024-10-28 09:47:47 -04:00 committed by PyTorch MergeBot
parent 9b2c99d731
commit bca696ae81
3 changed files with 148 additions and 83 deletions

View file

@ -110,6 +110,8 @@ class StructuredTraceTestingFormatter(logging.Formatter):
metadata["stack"] = "STACK"
if "compilation_metrics" in metadata:
metadata["compilation_metrics"] = "METRICS"
if "bwd_compilation_metrics" in metadata:
metadata["bwd_compilation_metrics"] = "METRICS"
if "describe_storage" in metadata:
metadata["describe_storage"]["describer_id"] = "ID"
if "describe_tensor" in metadata:
@ -368,7 +370,7 @@ class StructuredTraceTest(TestCase):
{"inductor_post_grad_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"bwd_compilation_metrics": {"compile_id": "2/0", "inductor_compile_time_s": <dynamic>, "code_gen_time_s": <dynamic>, "fail_type": null, "fail_reason": null, "remote_cache_time_saved_s": null, "structured_logging_overhead_s": <dynamic>, "is_forward": false, "remote_fx_graph_cache_get_time_ms": null, "remote_fx_graph_cache_put_time_ms": null}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1}
{"bwd_compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1}
{"dynamo_start": {"stack": "STACK"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0}
{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0}

View file

@ -120,6 +120,7 @@ from .utils import (
reset_graph_break_dup_checker,
setup_compile_debug,
to_int_ms,
to_int_us,
troubleshooting_url,
write_record_to_file,
)
@ -964,7 +965,7 @@ def _compile(
]
},
)
start_time = time.time()
start_time_ns = time.time_ns()
fail_type: Optional[str] = None
fail_reason: Optional[str] = None
fail_user_frame_filename: Optional[str] = None
@ -1021,6 +1022,8 @@ def _compile(
if tracer:
tracer.output.local_scope = {}
duration_ns = time.time_ns() - start_time_ns
from .utils import curr_frame
frame_key = str(curr_frame)
@ -1089,7 +1092,7 @@ def _compile(
compliant_custom_ops = set({})
restart_reasons = set()
# If compilation failed, the entire time is wasted
dynamo_time_before_restart = time.time() - start_time
dynamo_time_before_restart = duration_ns / 1e9
possibly_missed_reinplacing_opportunities = None
remote_cache_time_saved = None
remote_fx_graph_cache_get_time = None
@ -1124,7 +1127,7 @@ def _compile(
graph_op_count,
graph_node_count,
graph_input_count,
start_time,
start_time_ns / 1e9,
entire_frame_compile_time,
backend_compile_time,
inductor_compile_time,
@ -1148,6 +1151,29 @@ def _compile(
True, # is_forward
to_int_ms(remote_fx_graph_cache_get_time),
to_int_ms(remote_fx_graph_cache_put_time),
start_time_us=start_time_ns // 1000,
duration_us=duration_ns // 1000,
dynamo_cumulative_compile_time_us=to_int_us(entire_frame_compile_time),
aot_autograd_cumulative_compile_time_us=to_int_us(backend_compile_time),
inductor_cumulative_compile_time_us=to_int_us(inductor_compile_time),
inductor_code_gen_cumulative_compile_time_us=to_int_us(code_gen_time),
triton_compile_time_us=None, # TODO: instrument
runtime_cudagraphify_time_us=None, # TODO: instrument in separate event
runtime_triton_autotune_time_us=None, # TODO: instrument in separate event
dynamo_compile_time_before_restart_us=to_int_us(
dynamo_time_before_restart
),
cuda_synchronize_time_us=None, # TODO: instrument
distributed_ephemeral_timeout_us=to_int_us(
remote_cache_time_saved
), # TODO: instrument more accurately
structured_logging_overhead_us=to_int_us(structured_logging_overhead_s),
remote_fx_graph_cache_get_time_us=to_int_us(
remote_fx_graph_cache_get_time
),
remote_fx_graph_cache_put_time_us=to_int_us(
remote_fx_graph_cache_put_time
),
)
record_compilation_metrics(metrics)
torch._dynamo.callback_handler.run_end_callbacks()

View file

@ -278,14 +278,14 @@ def dynamo_timed(
fail_type: Optional[str] = None
fail_reason: Optional[str] = None
time_spent = float("-inf")
start = time.time_ns()
start_ns = time.time_ns()
try:
with torch.profiler.record_function(f"{key} (dynamo_timed)"):
t0 = time.time()
if phase_name:
chromium_log.log_event_start(phase_name, start, {"fn_name": key})
chromium_log.log_event_start(phase_name, start_ns, {"fn_name": key})
else:
chromium_log.log_event_start(key, start, {})
chromium_log.log_event_start(key, start_ns, {})
yield
time_spent = time.time() - t0
compilation_time_metrics[key].append(time_spent)
@ -294,16 +294,17 @@ def dynamo_timed(
fail_reason = str(e)
raise
finally:
end_ns = time.time_ns()
# Always log the end event even on exception
if phase_name:
chromium_log.log_event_end(
phase_name,
time.time_ns(),
end_ns,
{},
start,
start_ns,
)
else:
chromium_log.log_event_end(key, time.time_ns(), {}, start)
chromium_log.log_event_end(key, end_ns, {}, start_ns)
# Only record backward compilation metrics if phase_name is not None!
if phase_name:
frame_key = str(curr_frame)
@ -358,17 +359,41 @@ def dynamo_timed(
structured_logging_overhead_s = (
torch._logging.get_structured_logging_overhead()
)
metrics = BwdCompilationMetrics(
compile_id,
inductor_compile_time,
code_gen_time,
fail_type,
fail_reason,
remote_cache_time_saved,
structured_logging_overhead_s,
False, # is_forward
to_int_ms(remote_fx_graph_cache_get_time),
to_int_ms(remote_fx_graph_cache_put_time),
metrics = CompilationMetrics(
compile_id=compile_id,
inductor_compile_time_s=inductor_compile_time,
code_gen_time_s=code_gen_time,
fail_type=fail_type,
fail_reason=fail_reason,
remote_cache_time_saved_s=remote_cache_time_saved,
structured_logging_overhead_s=structured_logging_overhead_s,
is_forward=False, # is_forward
remote_fx_graph_cache_get_time_ms=to_int_ms(
remote_fx_graph_cache_get_time
),
remote_fx_graph_cache_put_time_ms=to_int_ms(
remote_fx_graph_cache_put_time
),
start_time_us=start_ns // 1000,
duration_us=(end_ns - start_ns) // 1000,
inductor_cumulative_compile_time_us=to_int_us(
inductor_compile_time
),
inductor_code_gen_cumulative_compile_time_us=to_int_us(
code_gen_time
),
distributed_ephemeral_timeout_us=to_int_us(
remote_cache_time_saved
), # TODO: instrument more accurately
structured_logging_overhead_us=to_int_us(
structured_logging_overhead_s
),
remote_fx_graph_cache_get_time_us=to_int_us(
remote_fx_graph_cache_get_time
),
remote_fx_graph_cache_put_time_us=to_int_us(
remote_fx_graph_cache_put_time
),
)
record_compilation_metrics(metrics)
@ -779,69 +804,76 @@ def to_int_ms(v: Optional[float]) -> Optional[int]:
return None if v is None else int(v * 1000)
# float64 timestamp has a quarter microsecond precision in 2024, so while
# this is suboptimal we shouldn't meaningfully lose precision
def to_int_us(v: Optional[float]) -> Optional[int]:
return None if v is None else int(v * 1_000_000)
@dataclasses.dataclass
class CompilationMetrics:
compile_id: str
frame_key: str
co_name: str
co_filename: str
co_firstlineno: int
cache_size: int
accumulated_cache_size: int
guard_count: Optional[int]
shape_env_guard_count: Optional[int]
graph_op_count: Optional[int]
graph_node_count: Optional[int]
graph_input_count: Optional[int]
start_time: float
entire_frame_compile_time_s: Optional[float]
backend_compile_time_s: Optional[float]
inductor_compile_time_s: Optional[float]
code_gen_time_s: Optional[float]
fail_type: Optional[str]
fail_reason: Optional[str]
fail_user_frame_filename: Optional[str]
fail_user_frame_lineno: Optional[int]
non_compliant_ops: Set[str]
compliant_custom_ops: Set[str]
restart_reasons: Set[str]
dynamo_time_before_restart_s: float
compile_id: Optional[str] = None
frame_key: Optional[str] = None
co_name: Optional[str] = None
co_filename: Optional[str] = None
co_firstlineno: Optional[int] = None
cache_size: Optional[int] = None
accumulated_cache_size: Optional[int] = None
guard_count: Optional[int] = None
shape_env_guard_count: Optional[int] = None
graph_op_count: Optional[int] = None
graph_node_count: Optional[int] = None
graph_input_count: Optional[int] = None
start_time: Optional[float] = None
entire_frame_compile_time_s: Optional[float] = None
backend_compile_time_s: Optional[float] = None
inductor_compile_time_s: Optional[float] = None
code_gen_time_s: Optional[float] = None
fail_type: Optional[str] = None
fail_reason: Optional[str] = None
fail_user_frame_filename: Optional[str] = None
fail_user_frame_lineno: Optional[int] = None
non_compliant_ops: Optional[Set[str]] = None
compliant_custom_ops: Optional[Set[str]] = None
restart_reasons: Optional[Set[str]] = None
dynamo_time_before_restart_s: Optional[float] = None
# Sometimes, we will finish analyzing a frame but conclude we don't want
# to install any guarded code. True means we actually decided to install
# a compiled frame
has_guarded_code: bool
possibly_missed_reinplacing_opportunities: Optional[int]
remote_cache_time_saved_s: Optional[float]
structured_logging_overhead_s: Optional[float]
config_suppress_errors: Optional[bool]
config_inline_inbuilt_nn_modules: Optional[bool]
specialize_float: Optional[bool]
dynamo_config: Optional[str]
is_forward: Optional[bool]
remote_fx_graph_cache_get_time_ms: Optional[int]
remote_fx_graph_cache_put_time_ms: Optional[int]
@dataclasses.dataclass
class BwdCompilationMetrics:
compile_id: str
inductor_compile_time_s: Optional[float]
code_gen_time_s: Optional[float]
fail_type: Optional[str]
fail_reason: Optional[str]
remote_cache_time_saved_s: Optional[float]
structured_logging_overhead_s: Optional[float]
is_forward: Optional[bool]
remote_fx_graph_cache_get_time_ms: Optional[int]
remote_fx_graph_cache_put_time_ms: Optional[int]
has_guarded_code: Optional[bool] = None
possibly_missed_reinplacing_opportunities: Optional[int] = None
remote_cache_time_saved_s: Optional[float] = None
structured_logging_overhead_s: Optional[float] = None
config_suppress_errors: Optional[bool] = None
config_inline_inbuilt_nn_modules: Optional[bool] = None
specialize_float: Optional[bool] = None
dynamo_config: Optional[str] = None
is_forward: Optional[bool] = None
remote_fx_graph_cache_get_time_ms: Optional[int] = None
remote_fx_graph_cache_put_time_ms: Optional[int] = None
start_time_us: Optional[int] = None
duration_us: Optional[int] = None
dynamo_cumulative_compile_time_us: Optional[int] = None
aot_autograd_cumulative_compile_time_us: Optional[int] = None
inductor_cumulative_compile_time_us: Optional[int] = None
inductor_code_gen_cumulative_compile_time_us: Optional[int] = None
triton_compile_time_us: Optional[int] = None
runtime_cudagraphify_time_us: Optional[int] = None
runtime_triton_autotune_time_us: Optional[int] = None
dynamo_compile_time_before_restart_us: Optional[int] = None
cuda_synchronize_time_us: Optional[int] = None
distributed_ephemeral_timeout_us: Optional[int] = None
structured_logging_overhead_us: Optional[int] = None
remote_fx_graph_cache_get_time_us: Optional[int] = None
remote_fx_graph_cache_put_time_us: Optional[int] = None
DEFAULT_COMPILATION_METRICS_LIMIT = 64
_compilation_metrics: Deque[
Union[CompilationMetrics, BwdCompilationMetrics]
] = collections.deque(maxlen=DEFAULT_COMPILATION_METRICS_LIMIT)
_compilation_metrics: Deque[CompilationMetrics] = collections.deque(
maxlen=DEFAULT_COMPILATION_METRICS_LIMIT
)
def add_compilation_metrics_to_chromium(c: CompilationMetrics):
@ -866,21 +898,25 @@ def add_compilation_metrics_to_chromium(c: CompilationMetrics):
fail_user_frame_filename=c.fail_user_frame_filename,
fail_user_frame_lineno=c.fail_user_frame_lineno,
# Sets aren't JSON serializable
non_compliant_ops=list(c.non_compliant_ops),
compliant_custom_ops=list(c.compliant_custom_ops),
restart_reasons=list(c.restart_reasons),
non_compliant_ops=list(c.non_compliant_ops)
if c.non_compliant_ops is not None
else None,
compliant_custom_ops=list(c.compliant_custom_ops)
if c.compliant_custom_ops is not None
else None,
restart_reasons=list(c.restart_reasons)
if c.restart_reasons is not None
else None,
dynamo_time_before_restart_s=c.dynamo_time_before_restart_s,
has_guarded_code=c.has_guarded_code,
dynamo_config=c.dynamo_config,
)
def record_compilation_metrics(
compilation_metrics: Union[CompilationMetrics, BwdCompilationMetrics]
):
def record_compilation_metrics(compilation_metrics: CompilationMetrics):
global _compilation_metrics
_compilation_metrics.append(compilation_metrics)
if isinstance(compilation_metrics, CompilationMetrics):
if compilation_metrics.is_forward:
name = "compilation_metrics"
add_compilation_metrics_to_chromium(compilation_metrics)
else:
@ -914,7 +950,7 @@ def clear_compilation_metrics() -> None:
_compilation_metrics.clear()
def get_compilation_metrics() -> List[Union[CompilationMetrics, BwdCompilationMetrics]]:
def get_compilation_metrics() -> List[CompilationMetrics]:
return list(_compilation_metrics)
@ -957,7 +993,8 @@ class ChromiumEventLogger:
"""
if event_name not in self.get_stack():
raise RuntimeError(
"Cannot add metadata to events that aren't in progress."
f"Event {repr(event_name)} not in {self.get_stack()}. "
"Cannot add metadata to events that aren't in progress. "
"Please make sure the event has started and hasn't ended."
)
event_data = self.get_event_data()