Beef up error message for pending assert failure (#126212)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126212
Approved by: https://github.com/Skylion007
This commit is contained in:
Edward Z. Yang 2024-05-15 09:22:01 -04:00 committed by PyTorch MergeBot
parent 26f6f98364
commit 44efeac24e
2 changed files with 16 additions and 8 deletions

View file

@ -2796,6 +2796,7 @@ coverage_ignore_classes = [
"ConstraintViolationError",
"DynamicDimConstraintPrinter",
"GuardOnDataDependentSymNode",
"PendingUnbackedSymbolNotFound",
"LoggingShapeGuardPrinter",
"RelaxedUnspecConstraint",
"RuntimeAssert",

View file

@ -83,6 +83,9 @@ log = logging.getLogger(__name__)
class GuardOnDataDependentSymNode(RuntimeError):
pass
class PendingUnbackedSymbolNotFound(RuntimeError):
pass
import sympy
from sympy.printing.str import StrPrinter
from sympy.printing.precedence import precedence, PRECEDENCE
@ -602,15 +605,19 @@ def compute_unbacked_bindings(shape_env, example_value, old_example_value=None,
return r
symbol_to_path = free_unbacked_symbols_with_path(example_value, ())
if not peek:
assert not pending, (
f"pending {pending} not in {example_value} " +
(
repr((example_value.stride(), example_value.storage_offset()))
if isinstance(example_value, torch.Tensor)
else ""
)
if not peek and pending:
extra = (
repr((example_value.stride(), example_value.storage_offset()))
if isinstance(example_value, torch.Tensor)
else ""
)
raise PendingUnbackedSymbolNotFound(
f"Pending unbacked symbols {pending} not in returned outputs {example_value} {extra}.\n"
"Did you accidentally call new_dynamic_size() or item() more times "
"than you needed to in your fake implementation?\n"
"For more help, see https://docs.google.com/document/d/1RWrH-3wLEpzR9kCS6gGBNen_-Fs-8PVbWWFE5AcgeWE/edit"
)
# Why do we have to do some rebinding here? If the original FX node
# wasn't a binding site because you had a memo hit, but post
# translation you aren't a memo hit anymore, there's now a new binding