diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index 4e5c04d399f..e3da411034a 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -110,6 +110,8 @@ class StructuredTraceTestingFormatter(logging.Formatter): metadata["stack"] = "STACK" if "compilation_metrics" in metadata: metadata["compilation_metrics"] = "METRICS" + if "bwd_compilation_metrics" in metadata: + metadata["bwd_compilation_metrics"] = "METRICS" if "describe_storage" in metadata: metadata["describe_storage"]["describer_id"] = "ID" if "describe_tensor" in metadata: @@ -368,7 +370,7 @@ class StructuredTraceTest(TestCase): {"inductor_post_grad_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} -{"bwd_compilation_metrics": {"compile_id": "2/0", "inductor_compile_time_s": , "code_gen_time_s": , "fail_type": null, "fail_reason": null, "remote_cache_time_saved_s": null, "structured_logging_overhead_s": , "is_forward": false, "remote_fx_graph_cache_get_time_ms": null, "remote_fx_graph_cache_put_time_ms": null}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1} +{"bwd_compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 48582f42d85..8e8ea43e1f3 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -120,6 +120,7 @@ from .utils import ( reset_graph_break_dup_checker, setup_compile_debug, to_int_ms, + to_int_us, troubleshooting_url, write_record_to_file, ) @@ -964,7 +965,7 @@ def _compile( ] }, ) - start_time = time.time() + start_time_ns = time.time_ns() fail_type: Optional[str] = None fail_reason: Optional[str] = None fail_user_frame_filename: Optional[str] = None @@ -1021,6 +1022,8 @@ def _compile( if tracer: tracer.output.local_scope = {} + duration_ns = time.time_ns() - start_time_ns + from .utils import curr_frame frame_key = str(curr_frame) @@ -1089,7 +1092,7 @@ def _compile( compliant_custom_ops = set({}) restart_reasons = set() # If compilation failed, the entire time is wasted - dynamo_time_before_restart = time.time() - start_time + dynamo_time_before_restart = duration_ns / 1e9 possibly_missed_reinplacing_opportunities = None remote_cache_time_saved = None remote_fx_graph_cache_get_time = None @@ -1124,7 +1127,7 @@ def _compile( graph_op_count, graph_node_count, graph_input_count, - start_time, + start_time_ns / 1e9, entire_frame_compile_time, backend_compile_time, inductor_compile_time, @@ -1148,6 +1151,29 @@ def _compile( True, # is_forward to_int_ms(remote_fx_graph_cache_get_time), to_int_ms(remote_fx_graph_cache_put_time), + start_time_us=start_time_ns // 1000, + duration_us=duration_ns // 1000, + dynamo_cumulative_compile_time_us=to_int_us(entire_frame_compile_time), + aot_autograd_cumulative_compile_time_us=to_int_us(backend_compile_time), + inductor_cumulative_compile_time_us=to_int_us(inductor_compile_time), + inductor_code_gen_cumulative_compile_time_us=to_int_us(code_gen_time), + triton_compile_time_us=None, # TODO: instrument + runtime_cudagraphify_time_us=None, # TODO: instrument in separate event + runtime_triton_autotune_time_us=None, # TODO: instrument in separate event + dynamo_compile_time_before_restart_us=to_int_us( + dynamo_time_before_restart + ), + cuda_synchronize_time_us=None, # TODO: instrument + distributed_ephemeral_timeout_us=to_int_us( + remote_cache_time_saved + ), # TODO: instrument more accurately + structured_logging_overhead_us=to_int_us(structured_logging_overhead_s), + remote_fx_graph_cache_get_time_us=to_int_us( + remote_fx_graph_cache_get_time + ), + remote_fx_graph_cache_put_time_us=to_int_us( + remote_fx_graph_cache_put_time + ), ) record_compilation_metrics(metrics) torch._dynamo.callback_handler.run_end_callbacks() diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 775ec8b488d..d54d9cf3f82 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -278,14 +278,14 @@ def dynamo_timed( fail_type: Optional[str] = None fail_reason: Optional[str] = None time_spent = float("-inf") - start = time.time_ns() + start_ns = time.time_ns() try: with torch.profiler.record_function(f"{key} (dynamo_timed)"): t0 = time.time() if phase_name: - chromium_log.log_event_start(phase_name, start, {"fn_name": key}) + chromium_log.log_event_start(phase_name, start_ns, {"fn_name": key}) else: - chromium_log.log_event_start(key, start, {}) + chromium_log.log_event_start(key, start_ns, {}) yield time_spent = time.time() - t0 compilation_time_metrics[key].append(time_spent) @@ -294,16 +294,17 @@ def dynamo_timed( fail_reason = str(e) raise finally: + end_ns = time.time_ns() # Always log the end event even on exception if phase_name: chromium_log.log_event_end( phase_name, - time.time_ns(), + end_ns, {}, - start, + start_ns, ) else: - chromium_log.log_event_end(key, time.time_ns(), {}, start) + chromium_log.log_event_end(key, end_ns, {}, start_ns) # Only record backward compilation metrics if phase_name is not None! if phase_name: frame_key = str(curr_frame) @@ -358,17 +359,41 @@ def dynamo_timed( structured_logging_overhead_s = ( torch._logging.get_structured_logging_overhead() ) - metrics = BwdCompilationMetrics( - compile_id, - inductor_compile_time, - code_gen_time, - fail_type, - fail_reason, - remote_cache_time_saved, - structured_logging_overhead_s, - False, # is_forward - to_int_ms(remote_fx_graph_cache_get_time), - to_int_ms(remote_fx_graph_cache_put_time), + metrics = CompilationMetrics( + compile_id=compile_id, + inductor_compile_time_s=inductor_compile_time, + code_gen_time_s=code_gen_time, + fail_type=fail_type, + fail_reason=fail_reason, + remote_cache_time_saved_s=remote_cache_time_saved, + structured_logging_overhead_s=structured_logging_overhead_s, + is_forward=False, # is_forward + remote_fx_graph_cache_get_time_ms=to_int_ms( + remote_fx_graph_cache_get_time + ), + remote_fx_graph_cache_put_time_ms=to_int_ms( + remote_fx_graph_cache_put_time + ), + start_time_us=start_ns // 1000, + duration_us=(end_ns - start_ns) // 1000, + inductor_cumulative_compile_time_us=to_int_us( + inductor_compile_time + ), + inductor_code_gen_cumulative_compile_time_us=to_int_us( + code_gen_time + ), + distributed_ephemeral_timeout_us=to_int_us( + remote_cache_time_saved + ), # TODO: instrument more accurately + structured_logging_overhead_us=to_int_us( + structured_logging_overhead_s + ), + remote_fx_graph_cache_get_time_us=to_int_us( + remote_fx_graph_cache_get_time + ), + remote_fx_graph_cache_put_time_us=to_int_us( + remote_fx_graph_cache_put_time + ), ) record_compilation_metrics(metrics) @@ -779,69 +804,76 @@ def to_int_ms(v: Optional[float]) -> Optional[int]: return None if v is None else int(v * 1000) +# float64 timestamp has a quarter microsecond precision in 2024, so while +# this is suboptimal we shouldn't meaningfully lose precision +def to_int_us(v: Optional[float]) -> Optional[int]: + return None if v is None else int(v * 1_000_000) + + @dataclasses.dataclass class CompilationMetrics: - compile_id: str - frame_key: str - co_name: str - co_filename: str - co_firstlineno: int - cache_size: int - accumulated_cache_size: int - guard_count: Optional[int] - shape_env_guard_count: Optional[int] - graph_op_count: Optional[int] - graph_node_count: Optional[int] - graph_input_count: Optional[int] - start_time: float - entire_frame_compile_time_s: Optional[float] - backend_compile_time_s: Optional[float] - inductor_compile_time_s: Optional[float] - code_gen_time_s: Optional[float] - fail_type: Optional[str] - fail_reason: Optional[str] - fail_user_frame_filename: Optional[str] - fail_user_frame_lineno: Optional[int] - non_compliant_ops: Set[str] - compliant_custom_ops: Set[str] - restart_reasons: Set[str] - dynamo_time_before_restart_s: float + compile_id: Optional[str] = None + frame_key: Optional[str] = None + co_name: Optional[str] = None + co_filename: Optional[str] = None + co_firstlineno: Optional[int] = None + cache_size: Optional[int] = None + accumulated_cache_size: Optional[int] = None + guard_count: Optional[int] = None + shape_env_guard_count: Optional[int] = None + graph_op_count: Optional[int] = None + graph_node_count: Optional[int] = None + graph_input_count: Optional[int] = None + start_time: Optional[float] = None + entire_frame_compile_time_s: Optional[float] = None + backend_compile_time_s: Optional[float] = None + inductor_compile_time_s: Optional[float] = None + code_gen_time_s: Optional[float] = None + fail_type: Optional[str] = None + fail_reason: Optional[str] = None + fail_user_frame_filename: Optional[str] = None + fail_user_frame_lineno: Optional[int] = None + non_compliant_ops: Optional[Set[str]] = None + compliant_custom_ops: Optional[Set[str]] = None + restart_reasons: Optional[Set[str]] = None + dynamo_time_before_restart_s: Optional[float] = None # Sometimes, we will finish analyzing a frame but conclude we don't want # to install any guarded code. True means we actually decided to install # a compiled frame - has_guarded_code: bool - possibly_missed_reinplacing_opportunities: Optional[int] - remote_cache_time_saved_s: Optional[float] - structured_logging_overhead_s: Optional[float] - config_suppress_errors: Optional[bool] - config_inline_inbuilt_nn_modules: Optional[bool] - specialize_float: Optional[bool] - dynamo_config: Optional[str] - is_forward: Optional[bool] - remote_fx_graph_cache_get_time_ms: Optional[int] - remote_fx_graph_cache_put_time_ms: Optional[int] - - -@dataclasses.dataclass -class BwdCompilationMetrics: - compile_id: str - inductor_compile_time_s: Optional[float] - code_gen_time_s: Optional[float] - fail_type: Optional[str] - fail_reason: Optional[str] - remote_cache_time_saved_s: Optional[float] - structured_logging_overhead_s: Optional[float] - is_forward: Optional[bool] - remote_fx_graph_cache_get_time_ms: Optional[int] - remote_fx_graph_cache_put_time_ms: Optional[int] + has_guarded_code: Optional[bool] = None + possibly_missed_reinplacing_opportunities: Optional[int] = None + remote_cache_time_saved_s: Optional[float] = None + structured_logging_overhead_s: Optional[float] = None + config_suppress_errors: Optional[bool] = None + config_inline_inbuilt_nn_modules: Optional[bool] = None + specialize_float: Optional[bool] = None + dynamo_config: Optional[str] = None + is_forward: Optional[bool] = None + remote_fx_graph_cache_get_time_ms: Optional[int] = None + remote_fx_graph_cache_put_time_ms: Optional[int] = None + start_time_us: Optional[int] = None + duration_us: Optional[int] = None + dynamo_cumulative_compile_time_us: Optional[int] = None + aot_autograd_cumulative_compile_time_us: Optional[int] = None + inductor_cumulative_compile_time_us: Optional[int] = None + inductor_code_gen_cumulative_compile_time_us: Optional[int] = None + triton_compile_time_us: Optional[int] = None + runtime_cudagraphify_time_us: Optional[int] = None + runtime_triton_autotune_time_us: Optional[int] = None + dynamo_compile_time_before_restart_us: Optional[int] = None + cuda_synchronize_time_us: Optional[int] = None + distributed_ephemeral_timeout_us: Optional[int] = None + structured_logging_overhead_us: Optional[int] = None + remote_fx_graph_cache_get_time_us: Optional[int] = None + remote_fx_graph_cache_put_time_us: Optional[int] = None DEFAULT_COMPILATION_METRICS_LIMIT = 64 -_compilation_metrics: Deque[ - Union[CompilationMetrics, BwdCompilationMetrics] -] = collections.deque(maxlen=DEFAULT_COMPILATION_METRICS_LIMIT) +_compilation_metrics: Deque[CompilationMetrics] = collections.deque( + maxlen=DEFAULT_COMPILATION_METRICS_LIMIT +) def add_compilation_metrics_to_chromium(c: CompilationMetrics): @@ -866,21 +898,25 @@ def add_compilation_metrics_to_chromium(c: CompilationMetrics): fail_user_frame_filename=c.fail_user_frame_filename, fail_user_frame_lineno=c.fail_user_frame_lineno, # Sets aren't JSON serializable - non_compliant_ops=list(c.non_compliant_ops), - compliant_custom_ops=list(c.compliant_custom_ops), - restart_reasons=list(c.restart_reasons), + non_compliant_ops=list(c.non_compliant_ops) + if c.non_compliant_ops is not None + else None, + compliant_custom_ops=list(c.compliant_custom_ops) + if c.compliant_custom_ops is not None + else None, + restart_reasons=list(c.restart_reasons) + if c.restart_reasons is not None + else None, dynamo_time_before_restart_s=c.dynamo_time_before_restart_s, has_guarded_code=c.has_guarded_code, dynamo_config=c.dynamo_config, ) -def record_compilation_metrics( - compilation_metrics: Union[CompilationMetrics, BwdCompilationMetrics] -): +def record_compilation_metrics(compilation_metrics: CompilationMetrics): global _compilation_metrics _compilation_metrics.append(compilation_metrics) - if isinstance(compilation_metrics, CompilationMetrics): + if compilation_metrics.is_forward: name = "compilation_metrics" add_compilation_metrics_to_chromium(compilation_metrics) else: @@ -914,7 +950,7 @@ def clear_compilation_metrics() -> None: _compilation_metrics.clear() -def get_compilation_metrics() -> List[Union[CompilationMetrics, BwdCompilationMetrics]]: +def get_compilation_metrics() -> List[CompilationMetrics]: return list(_compilation_metrics) @@ -957,7 +993,8 @@ class ChromiumEventLogger: """ if event_name not in self.get_stack(): raise RuntimeError( - "Cannot add metadata to events that aren't in progress." + f"Event {repr(event_name)} not in {self.get_stack()}. " + "Cannot add metadata to events that aren't in progress. " "Please make sure the event has started and hasn't ended." ) event_data = self.get_event_data()