mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add trace_shape_events artifact tracing for ShapeEnv events (#130473)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/130473 Approved by: https://github.com/lezcano
This commit is contained in:
parent
3100455b8e
commit
6f54e961ea
3 changed files with 29 additions and 3 deletions
|
|
@ -699,6 +699,7 @@ exclusions = {
|
|||
"verbose_guards",
|
||||
"sym_node",
|
||||
"export",
|
||||
"trace_shape_events",
|
||||
}
|
||||
for name in torch._logging._internal.log_registry.artifact_names:
|
||||
if name not in exclusions:
|
||||
|
|
|
|||
|
|
@ -153,5 +153,10 @@ register_artifact(
|
|||
"Logs extra info for various SymNode operations",
|
||||
off_by_default=True,
|
||||
)
|
||||
register_artifact(
|
||||
"trace_shape_events",
|
||||
"Logs traces for every ShapeEnv operation that we record for replay",
|
||||
off_by_default=True,
|
||||
)
|
||||
|
||||
register_artifact("custom_format_test_artifact", "Testing only", log_format="")
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@ import torch.utils._pytree as pytree
|
|||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
trace_shape_events_log = torch._logging.getArtifactLogger(
|
||||
__name__, "trace_shape_events"
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -175,6 +178,9 @@ class ShapeEnvEvent:
|
|||
return self.name == "defer_runtime_assert"
|
||||
|
||||
|
||||
NEST = 0
|
||||
|
||||
|
||||
# Extracts a ShapeEnv instance inside args and kwargs.
|
||||
# Specifically, it looks for:
|
||||
# 1. ShapeEnv arguments
|
||||
|
|
@ -235,6 +241,17 @@ def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
|
|||
|
||||
assert isinstance(args[0], ShapeEnv)
|
||||
|
||||
global NEST
|
||||
|
||||
trace_shape_events_log.debug(
|
||||
"%scall %s(*%r, **%r)", " " * NEST, name, args[1:], kwargs
|
||||
)
|
||||
NEST += 1
|
||||
|
||||
def retlog(r):
|
||||
trace_shape_events_log.debug("%s-> %s", " " * (NEST - 1), r)
|
||||
return r
|
||||
|
||||
try:
|
||||
if args[0].is_recording: # type: ignore[has-type]
|
||||
# If ShapeEnv is already recording an event, call the wrapped
|
||||
|
|
@ -242,7 +259,7 @@ def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
|
|||
#
|
||||
# NB: here, we skip the check of whether all ShapeEnv instances
|
||||
# are equal, in favor of a faster dispatch.
|
||||
return fn(*args, **kwargs)
|
||||
return retlog(fn(*args, **kwargs))
|
||||
|
||||
# Retrieve an instance of ShapeEnv.
|
||||
# Assumption: the collection of args and kwargs may not reference
|
||||
|
|
@ -252,7 +269,7 @@ def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
|
|||
# If we are calling this function without any ShapeEnv instance
|
||||
# alive in its arguments, we don't record and call the original.
|
||||
if self is None:
|
||||
return fn(*args, **kwargs)
|
||||
return retlog(fn(*args, **kwargs))
|
||||
|
||||
# Otherwise, start recording and call the function.
|
||||
with self._recording():
|
||||
|
|
@ -272,7 +289,7 @@ def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
|
|||
# the record if an error happened
|
||||
self.events.append(event)
|
||||
try:
|
||||
return event.run(self)
|
||||
return retlog(event.run(self))
|
||||
except Exception:
|
||||
self.events.pop()
|
||||
raise
|
||||
|
|
@ -287,6 +304,9 @@ def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
|
|||
)
|
||||
raise
|
||||
|
||||
finally:
|
||||
NEST -= 1
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
|
|
|||
Loading…
Reference in a new issue