speculation_log: Raise a unique error for divergence issues (#144785)

This is primarily sent for discussion and to see what tests fail due to
this. The idea is that rather than capturing this as a regex on the
fail_reason, just give it a unique failure type

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144785
Approved by: https://github.com/ezyang
This commit is contained in:
Colin L. Rice 2025-01-15 14:36:27 -07:00 committed by PyTorch MergeBot
parent b90231a189
commit 926f9056a9
2 changed files with 18 additions and 2 deletions

View file

@ -6,6 +6,8 @@ import torch._dynamo.test_case
import torch._functorch.config
import torch.nn
import torch.utils.checkpoint
from torch._dynamo.bytecode_transformation import Instruction
from torch._dynamo.symbolic_convert import SpeculationLog, SpeculationLogDivergence
class ExceptionTests(torch._dynamo.test_case.TestCase):
@ -402,6 +404,13 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref[0], res[0])
self.assertEqual(ref[1], res[1])
def test_speculation_exception(self):
log = SpeculationLog()
log.next("fake", 555, "fake", Instruction(1, "fake", 1, 1))
log.restart()
with self.assertRaises(SpeculationLogDivergence):
log.next("bad", 58, "bad", Instruction(2, "different", 2, 2))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View file

@ -213,11 +213,13 @@ class SpeculationLog:
f"Previous instruction: {prev_entry.filename}:{prev_entry.lineno}"
f"({prev_entry.inst.opname} @ {prev_entry.instruction_pointer})\n"
)
assert (
if not (
entry.instruction_pointer == instruction_pointer
and entry.filename == filename
and entry.lineno == lineno
), f"""
):
raise SpeculationLogDivergence(
f"""
SpeculationLog diverged at index {self.index} (log had {len(self.entries)} entries):
- Expected: {entry.filename}:{entry.lineno} ({entry.inst.opname} at ip={entry.instruction_pointer})
- Actual: {filename}:{lineno} ({inst.opname} at ip={instruction_pointer})
@ -235,6 +237,7 @@ do this for graph breaks, you will infinite loop).
Otherwise, please submit a bug report, ideally including the contents of TORCH_LOGS=+dynamo
"""
)
self.index += 1
return entry
@ -319,6 +322,10 @@ class BlockStackEntry:
return self.with_context.exit(tx)
class SpeculationLogDivergence(AssertionError):
pass
class ReturnValueOp(Exception):
pass