Add trace_shape_events artifact tracing for ShapeEnv events (#130473)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130473
Approved by: https://github.com/lezcano
This commit is contained in:
Edward Z. Yang 2024-07-11 08:26:25 -07:00 committed by PyTorch MergeBot
parent 3100455b8e
commit 6f54e961ea
3 changed files with 29 additions and 3 deletions

View file

@ -699,6 +699,7 @@ exclusions = {
"verbose_guards",
"sym_node",
"export",
"trace_shape_events",
}
for name in torch._logging._internal.log_registry.artifact_names:
if name not in exclusions:

View file

@ -153,5 +153,10 @@ register_artifact(
"Logs extra info for various SymNode operations",
off_by_default=True,
)
register_artifact(
"trace_shape_events",
"Logs traces for every ShapeEnv operation that we record for replay",
off_by_default=True,
)
register_artifact("custom_format_test_artifact", "Testing only", log_format="")

View file

@ -11,6 +11,9 @@ import torch.utils._pytree as pytree
log = logging.getLogger(__name__)
trace_shape_events_log = torch._logging.getArtifactLogger(
__name__, "trace_shape_events"
)
__all__ = [
@ -175,6 +178,9 @@ class ShapeEnvEvent:
return self.name == "defer_runtime_assert"
NEST = 0
# Extracts a ShapeEnv instance inside args and kwargs.
# Specifically, it looks for:
# 1. ShapeEnv arguments
@ -235,6 +241,17 @@ def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
assert isinstance(args[0], ShapeEnv)
global NEST
trace_shape_events_log.debug(
"%scall %s(*%r, **%r)", " " * NEST, name, args[1:], kwargs
)
NEST += 1
def retlog(r):
trace_shape_events_log.debug("%s-> %s", " " * (NEST - 1), r)
return r
try:
if args[0].is_recording: # type: ignore[has-type]
# If ShapeEnv is already recording an event, call the wrapped
@ -242,7 +259,7 @@ def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
#
# NB: here, we skip the check of whether all ShapeEnv instances
# are equal, in favor of a faster dispatch.
return fn(*args, **kwargs)
return retlog(fn(*args, **kwargs))
# Retrieve an instance of ShapeEnv.
# Assumption: the collection of args and kwargs may not reference
@ -252,7 +269,7 @@ def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
# If we are calling this function without any ShapeEnv instance
# alive in its arguments, we don't record and call the original.
if self is None:
return fn(*args, **kwargs)
return retlog(fn(*args, **kwargs))
# Otherwise, start recording and call the function.
with self._recording():
@ -272,7 +289,7 @@ def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
# the record if an error happened
self.events.append(event)
try:
return event.run(self)
return retlog(event.run(self))
except Exception:
self.events.pop()
raise
@ -287,6 +304,9 @@ def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
)
raise
finally:
NEST -= 1
return wrapper
return decorator