[Dynamo] Improve PT2 fbcode logging observability (#106932)

Summary:
https://docs.google.com/document/d/1D5K3_ELsda3tIUeSyNL_2yee-M3jVWbirqSQ5BDNvHQ/edit

This is the revamped version of D47908299.

For each frame, we will record a list of compilation metrics, e.g., backend_compile time, entire_frame_compile time, cache_size, co_filename, co_firstlineno, co_name, guards, graph input_count, graph node_count, graph op_count.

With the help of job info (mast_job_name, global_rank), we can satisfy the requirements from `Things I’ve used/wanted to use our logging to determine` in https://docs.google.com/document/d/1D5K3_ELsda3tIUeSyNL_2yee-M3jVWbirqSQ5BDNvHQ/edit (or add more metrics to this framework)

Test Plan:
```
buck2 test //caffe2/test:test_dynamo
```

Differential Revision: D48142400

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106932
Approved by: https://github.com/anijain2305
This commit is contained in:
Yanbo Liang 2023-08-11 20:46:04 +00:00 committed by PyTorch MergeBot
parent 1cfe292061
commit fbfb9a1648
4 changed files with 100 additions and 27 deletions

View file

@ -24,13 +24,7 @@ from .eval_frame import (
reset_code,
)
from .external_utils import is_compiling
from .utils import (
compilation_metrics,
graph_break_reasons,
guard_failures,
orig_code_map,
reset_frame_count,
)
from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count
__all__ = [
"allow_in_graph",
@ -71,6 +65,5 @@ def reset() -> None:
if hasattr(eval_frame.most_recent_backend, "reset"):
eval_frame.most_recent_backend.reset()
eval_frame.most_recent_backend = None
compilation_metrics.clear()
reset_frame_count()
torch._C._dynamo.compiled_autograd.clear_cache()

View file

@ -5,12 +5,12 @@ import os
import random
import types
import weakref
from typing import Dict, Optional, Set
from typing import Any, Callable, Dict, List, Optional, Set
import torch
import torch._logging
from torch._guards import tracing
from torch._utils_internal import signpost_event
from torch._utils_internal import log_compilation_event, signpost_event
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
GuardOnDataDependentSymNode,
@ -24,6 +24,7 @@ from .backends.registry import CompilerFn
from .bytecode_analysis import remove_dead_code, remove_pointless_jumps
from .bytecode_transformation import (
check_inst_exn_tab_entries_valid,
Instruction,
is_generator,
propagate_inst_exn_table_entries,
transform_code_object,
@ -45,9 +46,11 @@ from .replay_record import ExecutionRecord
from .symbolic_convert import InstructionTranslator
from .utils import (
CleanupManager,
CompilationMetrics,
counters,
dynamo_timed,
format_bytecode,
frame_phase_timing,
gen_record_file_name,
guard_failures,
increment_frame,
@ -374,6 +377,7 @@ def convert_frame_assert(
export,
export_constraints,
hooks,
cache_size,
frame,
frame_state=frame_state,
)
@ -382,7 +386,6 @@ def convert_frame_assert(
return wrap_convert_context(_convert_frame_assert)
@dynamo_timed(phase_name="entire_frame_compile")
def _compile(
code: types.CodeType,
globals: Dict[str, object],
@ -393,14 +396,14 @@ def _compile(
export: bool,
export_constraints,
hooks: Hooks,
cache_size: int,
frame: Optional[types.FrameType] = None,
frame_state=None,
) -> Optional[GuardedCode]:
output: Optional[OutputGraph] = None
# This is shared across restarts
mutated_closure_cell_contents: Set[str] = set()
# from .utils import print_once; print_once(code.co_filename)
fail_reason: Optional[str] = None
def transform(instructions, code_options):
nonlocal output
@ -431,7 +434,14 @@ def _compile(
check_inst_exn_tab_entries_valid(instructions)
instructions[:] = remove_pointless_jumps(remove_dead_code(instructions))
try:
@dynamo_timed(phase_name="entire_frame_compile")
def compile_inner(
code: types.CodeType,
one_graph: bool,
hooks: Hooks,
transform: Callable[[List[Instruction], Dict[str, Any]], Any],
) -> Optional[GuardedCode]:
nonlocal output
for attempt in itertools.count():
try:
out_code = transform_code_object(code, transform)
@ -521,6 +531,10 @@ def _compile(
output.local_scope.clear()
return guarded_code
try:
guarded_code = compile_inner(code, one_graph, hooks, transform)
return guarded_code
except (
Unsupported,
TorchRuntimeError,
@ -529,11 +543,54 @@ def _compile(
ConstraintViolationError,
GuardOnDataDependentSymNode,
) as e:
fail_reason = str(e)
exception_handler(e, code, frame, export=export)
raise
except Exception as e:
fail_reason = str(e)
exception_handler(e, code, frame, export=export)
raise InternalTorchDynamoError(str(e)).with_traceback(e.__traceback__) from None
finally:
from .utils import curr_frame
frame_key = str(curr_frame)
if (
fail_reason is None
and output is not None
and frame_key in frame_phase_timing
):
guard_count = len(output.guards)
graph_op_count = output.count_calls()
graph_node_count = len(output.graph.nodes)
graph_input_count = len(output.placeholders)
entire_frame_compile_time = frame_phase_timing[frame_key].get(
"entire_frame_compile", None
)
backend_compile_time = frame_phase_timing[frame_key].get(
"backend_compile", None
)
else:
guard_count = None
graph_op_count = None
graph_node_count = None
graph_input_count = None
entire_frame_compile_time = None
backend_compile_time = None
metrics = CompilationMetrics(
frame_key,
code.co_name,
code.co_filename,
code.co_firstlineno,
cache_size,
guard_count,
graph_op_count,
graph_node_count,
graph_input_count,
entire_frame_compile_time,
backend_compile_time,
fail_reason,
)
log_compilation_event(metrics)
def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
@ -602,7 +659,9 @@ def replay(filename):
compiler_fn=eager,
one_graph=False,
export=False,
export_constraints=None,
hooks=Hooks(),
cache_size=0,
frame=None,
)
except Exception:

View file

@ -63,8 +63,11 @@ nnmodule_doc_url = "https://pytorch.org/docs/master/compile/nn-module.html"
nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations."
log = logging.getLogger(__name__)
# profiling compilation time
compilation_metrics = collections.OrderedDict()
# profiling compilation time by function
compilation_time_metrics = collections.OrderedDict()
# profiling compilation time by frame phase
frame_phase_timing = collections.OrderedDict()
timer_counter = itertools.count()
@ -101,8 +104,6 @@ def dynamo_profiled(func):
return profile_wrapper
frame_phase_timing = collections.OrderedDict()
curr_frame = 0
@ -116,6 +117,7 @@ def increment_frame():
def reset_frame_count():
    """Reset the global frame counter and drop all recorded timing data."""
    global curr_frame
    curr_frame = 0
    # Clear both timing tables so compile_times() / phase reports start
    # fresh alongside the reset counter.
    compilation_time_metrics.clear()
    frame_phase_timing.clear()
@ -151,7 +153,7 @@ def print_time_report():
# dynamo_timed API works as a function decorator
# By wrapping a function in dynamo_timed, we can store a record in compilation_metrics
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
# where the key is the functions name.
# For example:
#
@ -171,14 +173,13 @@ def dynamo_timed(original_function=None, phase_name=None):
@wraps(func)
def time_wrapper(*args, **kwargs):
key = func.__qualname__
if key not in compilation_metrics:
compilation_metrics[key] = []
if key not in compilation_time_metrics:
compilation_time_metrics[key] = []
with torch.profiler.record_function(f"{key} (dynamo_timed)"):
t0 = time.time()
r = func(*args, **kwargs)
time_spent = time.time() - t0
# print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")
compilation_metrics[key].append(time_spent)
compilation_time_metrics[key].append(time_spent)
if phase_name:
frame_key = str(curr_frame)
if frame_key not in frame_phase_timing:
@ -217,8 +218,8 @@ def compile_times(repr="str", aggregate=False):
if repr == "str":
rows = [
(k, fmt_fn(compilation_metrics[k], item_fn=lambda x: f"{x:.4f}"))
for k in compilation_metrics
(k, fmt_fn(compilation_time_metrics[k], item_fn=lambda x: f"{x:.4f}"))
for k in compilation_time_metrics
]
out = "TorchDynamo compilation metrics:\n"
out += tabulate(rows, headers=("Function", "Runtimes (s)"))
@ -226,9 +227,9 @@ def compile_times(repr="str", aggregate=False):
elif repr == "csv":
values = [
fmt_fn(v, item_fn=lambda x: f"{x:.6f}")
for v in compilation_metrics.values()
for v in compilation_time_metrics.values()
]
headers = list(compilation_metrics.keys())
headers = list(compilation_time_metrics.keys())
return headers, values
@ -479,6 +480,22 @@ def proxy_args_kwargs(args, kwargs):
) from e
@dataclasses.dataclass
class CompilationMetrics:
    """Per-frame compilation statistics gathered in dynamo's `_compile` and
    emitted through `torch._utils_internal.log_compilation_event`.

    Fields are constructed positionally in `_compile`; do not reorder them.
    The Optional fields are None when compilation failed (see `fail_reason`)
    or when no phase timing was recorded for the frame.
    """

    # str(curr_frame) at compile time; also the key into frame_phase_timing.
    frame_key: str
    # Identity of the code object being compiled.
    co_name: str
    co_filename: str
    co_firstlineno: int
    # Cache size for this code object when compilation was attempted.
    cache_size: int
    # Guard/graph statistics taken from the OutputGraph on success.
    guard_count: Optional[int]
    graph_op_count: Optional[int]
    graph_node_count: Optional[int]
    graph_input_count: Optional[int]
    # Wall-clock phase timings in seconds, read from frame_phase_timing.
    entire_frame_compile_time_s: Optional[float]
    backend_compile_time_s: Optional[float]
    # str(exception) when compilation raised; None on success.
    fail_reason: Optional[str]
@dataclasses.dataclass
class CleanupHook:
"""Remove a global variable when hook is called"""

View file

@ -64,6 +64,10 @@ def signpost_event(category: str, name: str, parameters: Dict[str, Any]):
log.info("%s %s: %r", category, name, parameters)
def log_compilation_event(metrics):
    """Record one compilation-metrics event (a CompilationMetrics record).

    Default implementation just logs the record; presumably internal builds
    override this to route metrics to their own telemetry — confirm against
    the fbcode counterpart.
    """
    # %-style lazy formatting: metrics is only stringified if INFO is enabled.
    log.info("%s", metrics)
TEST_MASTER_ADDR = "127.0.0.1"
TEST_MASTER_PORT = 29500
# USE_GLOBAL_DEPS controls whether __init__.py tries to load