diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index 40f9a3b8395..52f2b30e2da 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -6,6 +6,8 @@ import torch._dynamo.test_case import torch._functorch.config import torch.nn import torch.utils.checkpoint +from torch._dynamo.bytecode_transformation import Instruction +from torch._dynamo.symbolic_convert import SpeculationLog, SpeculationLogDivergence class ExceptionTests(torch._dynamo.test_case.TestCase): @@ -402,6 +404,13 @@ class ExceptionTests(torch._dynamo.test_case.TestCase): self.assertEqual(ref[0], res[0]) self.assertEqual(ref[1], res[1]) + def test_speculation_exception(self): + log = SpeculationLog() + log.next("fake", 555, "fake", Instruction(1, "fake", 1, 1)) + log.restart() + with self.assertRaises(SpeculationLogDivergence): + log.next("bad", 58, "bad", Instruction(2, "different", 2, 2)) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 4b870977434..64675d3c732 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -213,11 +213,13 @@ class SpeculationLog: f"Previous instruction: {prev_entry.filename}:{prev_entry.lineno}" f"({prev_entry.inst.opname} @ {prev_entry.instruction_pointer})\n" ) - assert ( + if not ( entry.instruction_pointer == instruction_pointer and entry.filename == filename and entry.lineno == lineno - ), f""" + ): + raise SpeculationLogDivergence( + f""" SpeculationLog diverged at index {self.index} (log had {len(self.entries)} entries): - Expected: {entry.filename}:{entry.lineno} ({entry.inst.opname} at ip={entry.instruction_pointer}) - Actual: {filename}:{lineno} ({inst.opname} at ip={instruction_pointer}) @@ -235,6 +237,7 @@ do this for graph breaks, you will infinite loop). Otherwise, please submit a bug report, ideally including the contents of TORCH_LOGS=+dynamo """ + ) self.index += 1 return entry @@ -319,6 +322,10 @@ class BlockStackEntry: return self.with_context.exit(tx) +class SpeculationLogDivergence(AssertionError): + pass + + class ReturnValueOp(Exception): pass