mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Integrate sympy expression provenance logging with structured logs (#145848)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145848 Approved by: https://github.com/angelayi
This commit is contained in:
parent
4168982dad
commit
0e49f35e3d
2 changed files with 98 additions and 4 deletions
|
|
@ -43,9 +43,12 @@ LOG_ENV_VAR = "TORCH_LOGS"
|
|||
LOG_OUT_ENV_VAR = "TORCH_LOGS_OUT"
|
||||
LOG_FORMAT_ENV_VAR = "TORCH_LOGS_FORMAT"
|
||||
TRACE_ENV_VAR = "TORCH_TRACE"
|
||||
DTRACE_ENV_VAR = "TORCH_DTRACE"
|
||||
|
||||
LOG_TRACE_HANDLER: Optional["LazyTraceHandler"] = None
|
||||
|
||||
GET_DTRACE_STRUCTURED = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class LogRegistry:
|
||||
|
|
@ -934,6 +937,8 @@ def _set_log_state(state):
|
|||
|
||||
|
||||
def _init_logs(log_file_name=None):
|
||||
global GET_DTRACE_STRUCTURED
|
||||
|
||||
_reset_logs()
|
||||
_update_log_state_from_env()
|
||||
|
||||
|
|
@ -980,6 +985,10 @@ def _init_logs(log_file_name=None):
|
|||
# Setup handler for the special trace_log, with different default
|
||||
# configuration
|
||||
trace_dir_name = os.environ.get(TRACE_ENV_VAR, None)
|
||||
|
||||
if os.environ.get(DTRACE_ENV_VAR, None):
|
||||
GET_DTRACE_STRUCTURED = True
|
||||
|
||||
# This handler may remove itself if trace_dir_name is None and we are not
|
||||
# actually in an FB environment. This allows us to defer actually
|
||||
# initializing it until we actually need to log anything. This is
|
||||
|
|
@ -1249,9 +1258,6 @@ def trace_structured(
|
|||
add_structured_logging_overhead(structured_logging_overhead_s)
|
||||
|
||||
|
||||
GET_DTRACE_STRUCTURED = False
|
||||
|
||||
|
||||
def dtrace_structured(
|
||||
name: str,
|
||||
# NB: metadata expected to be dict so adding more info is forward compatible
|
||||
|
|
|
|||
|
|
@ -13,14 +13,18 @@ As this file is imported from within torch/__init__.py we do not want it to depe
|
|||
to avoid having to load SymPy at import time, as doing so is *very* slow.
|
||||
"""
|
||||
|
||||
|
||||
import builtins
|
||||
import functools
|
||||
import inspect
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import operator
|
||||
import sys
|
||||
import traceback
|
||||
from functools import lru_cache, update_wrapper
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
from typing import Optional, Set, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
|
@ -35,6 +39,9 @@ from torch import ( # noqa: F401
|
|||
SymFloat,
|
||||
SymInt,
|
||||
)
|
||||
from torch._guards import TracingContext
|
||||
from torch._logging import dtrace_structured
|
||||
from torch.utils._traceback import format_frame
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -1645,6 +1652,86 @@ def _make_user_magic(method, user_type):
|
|||
return (method_to_operator(method))(get_constant(self))
|
||||
return wrap_node(getattr(self.node, method_attr)())
|
||||
|
||||
def uninteresting_files() -> Set[str]:
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
|
||||
mods = [
|
||||
torch._dynamo.eval_frame,
|
||||
torch._dynamo.utils,
|
||||
torch.fx.experimental.sym_node,
|
||||
torch,
|
||||
]
|
||||
import torch._dynamo.guards
|
||||
|
||||
return (
|
||||
{inspect.getfile(m) for m in mods}
|
||||
| torch._dynamo.guards.uninteresting_files()
|
||||
| {"<string>"}
|
||||
)
|
||||
|
||||
def capture_provenance(fn):
|
||||
@functools.wraps(fn)
|
||||
def wrapper(self, other):
|
||||
result = fn(self, other)
|
||||
if torch._logging._internal.GET_DTRACE_STRUCTURED:
|
||||
floc = None
|
||||
user_stack = None
|
||||
user_top_stack = None
|
||||
user_bottom_stack = None
|
||||
|
||||
if len(TracingContext.extract_stack()) > 0:
|
||||
user_stack = TracingContext.extract_stack()
|
||||
user_top_stack = format_frame(user_stack[0], line=True)
|
||||
user_bottom_stack = format_frame(user_stack[-1], line=True)
|
||||
|
||||
frame = inspect.currentframe()
|
||||
try:
|
||||
while frame is not None:
|
||||
if (
|
||||
floc is None
|
||||
and frame.f_code.co_filename not in uninteresting_files()
|
||||
):
|
||||
floc = format_frame(
|
||||
traceback.FrameSummary(
|
||||
frame.f_code.co_filename,
|
||||
frame.f_lineno,
|
||||
frame.f_code.co_name,
|
||||
),
|
||||
line=True,
|
||||
)
|
||||
if frame.f_back is None and user_top_stack is None:
|
||||
user_top_stack = format_frame(
|
||||
traceback.FrameSummary(
|
||||
frame.f_code.co_filename,
|
||||
frame.f_lineno,
|
||||
frame.f_code.co_name,
|
||||
),
|
||||
line=True,
|
||||
)
|
||||
break
|
||||
frame = frame.f_back
|
||||
finally:
|
||||
del frame
|
||||
|
||||
dtrace_structured(
|
||||
"expression_created",
|
||||
metadata_fn=lambda: {
|
||||
"method": method,
|
||||
"arguments": [str(self), str(other)],
|
||||
"result": str(result),
|
||||
"user_bottom_stack": str(user_bottom_stack),
|
||||
"user_top_stack": str(user_top_stack),
|
||||
"floc": str(floc),
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
@capture_provenance
|
||||
def binary_magic_impl(self, other):
|
||||
if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
|
||||
return NotImplemented
|
||||
|
|
@ -1662,6 +1749,7 @@ def _make_user_magic(method, user_type):
|
|||
ret = wrap_node(getattr(self.node, method_attr)(other_node))
|
||||
return get_constant(ret) if is_constant(ret) else ret
|
||||
|
||||
@capture_provenance
|
||||
def rbinary_magic_impl(self, other):
|
||||
if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
|
||||
return NotImplemented
|
||||
|
|
|
|||
Loading…
Reference in a new issue