Integrate sympy expression provenance logging with structured logs (#145848)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145848
Approved by: https://github.com/angelayi
This commit is contained in:
bobrenjc93 2025-01-30 21:05:00 +00:00 committed by PyTorch MergeBot
parent 4168982dad
commit 0e49f35e3d
2 changed files with 98 additions and 4 deletions

View file

@ -43,9 +43,12 @@ LOG_ENV_VAR = "TORCH_LOGS"
LOG_OUT_ENV_VAR = "TORCH_LOGS_OUT"
LOG_FORMAT_ENV_VAR = "TORCH_LOGS_FORMAT"
TRACE_ENV_VAR = "TORCH_TRACE"
DTRACE_ENV_VAR = "TORCH_DTRACE"
LOG_TRACE_HANDLER: Optional["LazyTraceHandler"] = None
GET_DTRACE_STRUCTURED = False
@dataclass
class LogRegistry:
@ -934,6 +937,8 @@ def _set_log_state(state):
def _init_logs(log_file_name=None):
global GET_DTRACE_STRUCTURED
_reset_logs()
_update_log_state_from_env()
@ -980,6 +985,10 @@ def _init_logs(log_file_name=None):
# Setup handler for the special trace_log, with different default
# configuration
trace_dir_name = os.environ.get(TRACE_ENV_VAR, None)
if os.environ.get(DTRACE_ENV_VAR, None):
GET_DTRACE_STRUCTURED = True
# This handler may remove itself if trace_dir_name is None and we are not
# actually in an FB environment. This allows us to defer actually
# initializing it until we actually need to log anything. This is
@ -1249,9 +1258,6 @@ def trace_structured(
add_structured_logging_overhead(structured_logging_overhead_s)
GET_DTRACE_STRUCTURED = False
def dtrace_structured(
name: str,
# NB: metadata expected to be dict so adding more info is forward compatible

View file

@ -13,14 +13,18 @@ As this file is imported from within torch/__init__.py we do not want it to depe
to avoid having to load SymPy at import time, as doing so is *very* slow.
"""
import builtins
import functools
import inspect
import itertools
import logging
import math
import operator
import sys
import traceback
from functools import lru_cache, update_wrapper
from typing import Optional, TYPE_CHECKING, Union
from typing import Optional, Set, TYPE_CHECKING, Union
import torch
@ -35,6 +39,9 @@ from torch import ( # noqa: F401
SymFloat,
SymInt,
)
from torch._guards import TracingContext
from torch._logging import dtrace_structured
from torch.utils._traceback import format_frame
if TYPE_CHECKING:
@ -1645,6 +1652,86 @@ def _make_user_magic(method, user_type):
return (method_to_operator(method))(get_constant(self))
return wrap_node(getattr(self.node, method_attr)())
def uninteresting_files() -> Set[str]:
import inspect
import torch
mods = [
torch._dynamo.eval_frame,
torch._dynamo.utils,
torch.fx.experimental.sym_node,
torch,
]
import torch._dynamo.guards
return (
{inspect.getfile(m) for m in mods}
| torch._dynamo.guards.uninteresting_files()
| {"<string>"}
)
def capture_provenance(fn):
@functools.wraps(fn)
def wrapper(self, other):
result = fn(self, other)
if torch._logging._internal.GET_DTRACE_STRUCTURED:
floc = None
user_stack = None
user_top_stack = None
user_bottom_stack = None
if len(TracingContext.extract_stack()) > 0:
user_stack = TracingContext.extract_stack()
user_top_stack = format_frame(user_stack[0], line=True)
user_bottom_stack = format_frame(user_stack[-1], line=True)
frame = inspect.currentframe()
try:
while frame is not None:
if (
floc is None
and frame.f_code.co_filename not in uninteresting_files()
):
floc = format_frame(
traceback.FrameSummary(
frame.f_code.co_filename,
frame.f_lineno,
frame.f_code.co_name,
),
line=True,
)
if frame.f_back is None and user_top_stack is None:
user_top_stack = format_frame(
traceback.FrameSummary(
frame.f_code.co_filename,
frame.f_lineno,
frame.f_code.co_name,
),
line=True,
)
break
frame = frame.f_back
finally:
del frame
dtrace_structured(
"expression_created",
metadata_fn=lambda: {
"method": method,
"arguments": [str(self), str(other)],
"result": str(result),
"user_bottom_stack": str(user_bottom_stack),
"user_top_stack": str(user_top_stack),
"floc": str(floc),
},
)
return result
return wrapper
@capture_provenance
def binary_magic_impl(self, other):
if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
return NotImplemented
@ -1662,6 +1749,7 @@ def _make_user_magic(method, user_type):
ret = wrap_node(getattr(self.node, method_attr)(other_node))
return get_constant(ret) if is_constant(ret) else ret
@capture_provenance
def rbinary_magic_impl(self, other):
if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
return NotImplemented