mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Dynamo] Improve PT2 fbcode logging observability (#106932)
Summary: https://docs.google.com/document/d/1D5K3_ELsda3tIUeSyNL_2yee-M3jVWbirqSQ5BDNvHQ/edit This is the revamped version of D47908299. For each frame, we will record a list of compilation metrics: e.g, backend_compile time, entire_frame_compile time, cache_size, co_filename, co_firstlineno, co_name, guards, graph input_count, graph node_count, graph op_count. With the help of job info: mast_job_name, global_rank, we can satisfy the requirements from `Things I’ve used/wanted to use our logging to determine` in https://docs.google.com/document/d/1D5K3_ELsda3tIUeSyNL_2yee-M3jVWbirqSQ5BDNvHQ/edit (or add more metrics for this framework) Test Plan: ``` buck2 test //caffe2/test:test_dynamo ``` Differential Revision: D48142400 Pull Request resolved: https://github.com/pytorch/pytorch/pull/106932 Approved by: https://github.com/anijain2305
This commit is contained in:
parent
1cfe292061
commit
fbfb9a1648
4 changed files with 100 additions and 27 deletions
|
|
@ -24,13 +24,7 @@ from .eval_frame import (
|
|||
reset_code,
|
||||
)
|
||||
from .external_utils import is_compiling
|
||||
from .utils import (
|
||||
compilation_metrics,
|
||||
graph_break_reasons,
|
||||
guard_failures,
|
||||
orig_code_map,
|
||||
reset_frame_count,
|
||||
)
|
||||
from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count
|
||||
|
||||
__all__ = [
|
||||
"allow_in_graph",
|
||||
|
|
@ -71,6 +65,5 @@ def reset() -> None:
|
|||
if hasattr(eval_frame.most_recent_backend, "reset"):
|
||||
eval_frame.most_recent_backend.reset()
|
||||
eval_frame.most_recent_backend = None
|
||||
compilation_metrics.clear()
|
||||
reset_frame_count()
|
||||
torch._C._dynamo.compiled_autograd.clear_cache()
|
||||
|
|
|
|||
|
|
@ -5,12 +5,12 @@ import os
|
|||
import random
|
||||
import types
|
||||
import weakref
|
||||
from typing import Dict, Optional, Set
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
import torch
|
||||
import torch._logging
|
||||
from torch._guards import tracing
|
||||
from torch._utils_internal import signpost_event
|
||||
from torch._utils_internal import log_compilation_event, signpost_event
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
ConstraintViolationError,
|
||||
GuardOnDataDependentSymNode,
|
||||
|
|
@ -24,6 +24,7 @@ from .backends.registry import CompilerFn
|
|||
from .bytecode_analysis import remove_dead_code, remove_pointless_jumps
|
||||
from .bytecode_transformation import (
|
||||
check_inst_exn_tab_entries_valid,
|
||||
Instruction,
|
||||
is_generator,
|
||||
propagate_inst_exn_table_entries,
|
||||
transform_code_object,
|
||||
|
|
@ -45,9 +46,11 @@ from .replay_record import ExecutionRecord
|
|||
from .symbolic_convert import InstructionTranslator
|
||||
from .utils import (
|
||||
CleanupManager,
|
||||
CompilationMetrics,
|
||||
counters,
|
||||
dynamo_timed,
|
||||
format_bytecode,
|
||||
frame_phase_timing,
|
||||
gen_record_file_name,
|
||||
guard_failures,
|
||||
increment_frame,
|
||||
|
|
@ -374,6 +377,7 @@ def convert_frame_assert(
|
|||
export,
|
||||
export_constraints,
|
||||
hooks,
|
||||
cache_size,
|
||||
frame,
|
||||
frame_state=frame_state,
|
||||
)
|
||||
|
|
@ -382,7 +386,6 @@ def convert_frame_assert(
|
|||
return wrap_convert_context(_convert_frame_assert)
|
||||
|
||||
|
||||
@dynamo_timed(phase_name="entire_frame_compile")
|
||||
def _compile(
|
||||
code: types.CodeType,
|
||||
globals: Dict[str, object],
|
||||
|
|
@ -393,14 +396,14 @@ def _compile(
|
|||
export: bool,
|
||||
export_constraints,
|
||||
hooks: Hooks,
|
||||
cache_size: int,
|
||||
frame: Optional[types.FrameType] = None,
|
||||
frame_state=None,
|
||||
) -> Optional[GuardedCode]:
|
||||
output: Optional[OutputGraph] = None
|
||||
# This is shared across restarts
|
||||
mutated_closure_cell_contents: Set[str] = set()
|
||||
|
||||
# from .utils import print_once; print_once(code.co_filename)
|
||||
fail_reason: Optional[str] = None
|
||||
|
||||
def transform(instructions, code_options):
|
||||
nonlocal output
|
||||
|
|
@ -431,7 +434,14 @@ def _compile(
|
|||
check_inst_exn_tab_entries_valid(instructions)
|
||||
instructions[:] = remove_pointless_jumps(remove_dead_code(instructions))
|
||||
|
||||
try:
|
||||
@dynamo_timed(phase_name="entire_frame_compile")
|
||||
def compile_inner(
|
||||
code: types.CodeType,
|
||||
one_graph: bool,
|
||||
hooks: Hooks,
|
||||
transform: Callable[[List[Instruction], Dict[str, Any]], Any],
|
||||
) -> Optional[GuardedCode]:
|
||||
nonlocal output
|
||||
for attempt in itertools.count():
|
||||
try:
|
||||
out_code = transform_code_object(code, transform)
|
||||
|
|
@ -521,6 +531,10 @@ def _compile(
|
|||
|
||||
output.local_scope.clear()
|
||||
return guarded_code
|
||||
|
||||
try:
|
||||
guarded_code = compile_inner(code, one_graph, hooks, transform)
|
||||
return guarded_code
|
||||
except (
|
||||
Unsupported,
|
||||
TorchRuntimeError,
|
||||
|
|
@ -529,11 +543,54 @@ def _compile(
|
|||
ConstraintViolationError,
|
||||
GuardOnDataDependentSymNode,
|
||||
) as e:
|
||||
fail_reason = str(e)
|
||||
exception_handler(e, code, frame, export=export)
|
||||
raise
|
||||
except Exception as e:
|
||||
fail_reason = str(e)
|
||||
exception_handler(e, code, frame, export=export)
|
||||
raise InternalTorchDynamoError(str(e)).with_traceback(e.__traceback__) from None
|
||||
finally:
|
||||
from .utils import curr_frame
|
||||
|
||||
frame_key = str(curr_frame)
|
||||
if (
|
||||
fail_reason is None
|
||||
and output is not None
|
||||
and frame_key in frame_phase_timing
|
||||
):
|
||||
guard_count = len(output.guards)
|
||||
graph_op_count = output.count_calls()
|
||||
graph_node_count = len(output.graph.nodes)
|
||||
graph_input_count = len(output.placeholders)
|
||||
entire_frame_compile_time = frame_phase_timing[frame_key].get(
|
||||
"entire_frame_compile", None
|
||||
)
|
||||
backend_compile_time = frame_phase_timing[frame_key].get(
|
||||
"backend_compile", None
|
||||
)
|
||||
else:
|
||||
guard_count = None
|
||||
graph_op_count = None
|
||||
graph_node_count = None
|
||||
graph_input_count = None
|
||||
entire_frame_compile_time = None
|
||||
backend_compile_time = None
|
||||
metrics = CompilationMetrics(
|
||||
frame_key,
|
||||
code.co_name,
|
||||
code.co_filename,
|
||||
code.co_firstlineno,
|
||||
cache_size,
|
||||
guard_count,
|
||||
graph_op_count,
|
||||
graph_node_count,
|
||||
graph_input_count,
|
||||
entire_frame_compile_time,
|
||||
backend_compile_time,
|
||||
fail_reason,
|
||||
)
|
||||
log_compilation_event(metrics)
|
||||
|
||||
|
||||
def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
|
||||
|
|
@ -602,7 +659,9 @@ def replay(filename):
|
|||
compiler_fn=eager,
|
||||
one_graph=False,
|
||||
export=False,
|
||||
export_constraints=None,
|
||||
hooks=Hooks(),
|
||||
cache_size=0,
|
||||
frame=None,
|
||||
)
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -63,8 +63,11 @@ nnmodule_doc_url = "https://pytorch.org/docs/master/compile/nn-module.html"
|
|||
nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations."
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# profiling compilation time
|
||||
compilation_metrics = collections.OrderedDict()
|
||||
# profiling compilation time by function
|
||||
compilation_time_metrics = collections.OrderedDict()
|
||||
|
||||
# profiling compilation time by frame phase
|
||||
frame_phase_timing = collections.OrderedDict()
|
||||
|
||||
timer_counter = itertools.count()
|
||||
|
||||
|
|
@ -101,8 +104,6 @@ def dynamo_profiled(func):
|
|||
return profile_wrapper
|
||||
|
||||
|
||||
frame_phase_timing = collections.OrderedDict()
|
||||
|
||||
curr_frame = 0
|
||||
|
||||
|
||||
|
|
@ -116,6 +117,7 @@ def increment_frame():
|
|||
def reset_frame_count():
|
||||
global curr_frame
|
||||
frame_phase_timing.clear()
|
||||
compilation_time_metrics.clear()
|
||||
curr_frame = 0
|
||||
|
||||
|
||||
|
|
@ -151,7 +153,7 @@ def print_time_report():
|
|||
|
||||
|
||||
# dynamo_timed API works as a function decorator
|
||||
# By wrapping a function in dynamo_timed, we can store a record in compilation_metrics
|
||||
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
|
||||
# where the key is the functions name.
|
||||
# For example:
|
||||
#
|
||||
|
|
@ -171,14 +173,13 @@ def dynamo_timed(original_function=None, phase_name=None):
|
|||
@wraps(func)
|
||||
def time_wrapper(*args, **kwargs):
|
||||
key = func.__qualname__
|
||||
if key not in compilation_metrics:
|
||||
compilation_metrics[key] = []
|
||||
if key not in compilation_time_metrics:
|
||||
compilation_time_metrics[key] = []
|
||||
with torch.profiler.record_function(f"{key} (dynamo_timed)"):
|
||||
t0 = time.time()
|
||||
r = func(*args, **kwargs)
|
||||
time_spent = time.time() - t0
|
||||
# print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")
|
||||
compilation_metrics[key].append(time_spent)
|
||||
compilation_time_metrics[key].append(time_spent)
|
||||
if phase_name:
|
||||
frame_key = str(curr_frame)
|
||||
if frame_key not in frame_phase_timing:
|
||||
|
|
@ -217,8 +218,8 @@ def compile_times(repr="str", aggregate=False):
|
|||
|
||||
if repr == "str":
|
||||
rows = [
|
||||
(k, fmt_fn(compilation_metrics[k], item_fn=lambda x: f"{x:.4f}"))
|
||||
for k in compilation_metrics
|
||||
(k, fmt_fn(compilation_time_metrics[k], item_fn=lambda x: f"{x:.4f}"))
|
||||
for k in compilation_time_metrics
|
||||
]
|
||||
out = "TorchDynamo compilation metrics:\n"
|
||||
out += tabulate(rows, headers=("Function", "Runtimes (s)"))
|
||||
|
|
@ -226,9 +227,9 @@ def compile_times(repr="str", aggregate=False):
|
|||
elif repr == "csv":
|
||||
values = [
|
||||
fmt_fn(v, item_fn=lambda x: f"{x:.6f}")
|
||||
for v in compilation_metrics.values()
|
||||
for v in compilation_time_metrics.values()
|
||||
]
|
||||
headers = list(compilation_metrics.keys())
|
||||
headers = list(compilation_time_metrics.keys())
|
||||
return headers, values
|
||||
|
||||
|
||||
|
|
@ -479,6 +480,22 @@ def proxy_args_kwargs(args, kwargs):
|
|||
) from e
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CompilationMetrics:
|
||||
frame_key: str
|
||||
co_name: str
|
||||
co_filename: str
|
||||
co_firstlineno: int
|
||||
cache_size: int
|
||||
guard_count: Optional[int]
|
||||
graph_op_count: Optional[int]
|
||||
graph_node_count: Optional[int]
|
||||
graph_input_count: Optional[int]
|
||||
entire_frame_compile_time_s: Optional[float]
|
||||
backend_compile_time_s: Optional[float]
|
||||
fail_reason: Optional[str]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CleanupHook:
|
||||
"""Remove a global variable when hook is called"""
|
||||
|
|
|
|||
|
|
@ -64,6 +64,10 @@ def signpost_event(category: str, name: str, parameters: Dict[str, Any]):
|
|||
log.info("%s %s: %r", category, name, parameters)
|
||||
|
||||
|
||||
def log_compilation_event(metrics):
|
||||
log.info("%s", metrics)
|
||||
|
||||
|
||||
TEST_MASTER_ADDR = "127.0.0.1"
|
||||
TEST_MASTER_PORT = 29500
|
||||
# USE_GLOBAL_DEPS controls whether __init__.py tries to load
|
||||
|
|
|
|||
Loading…
Reference in a new issue